blob: 8b0b8dced310b36d72b1d970db481baa69ea9ce7 [file] [log] [blame]
package endpoints
import (
"bytes"
"fmt"
"io/ioutil"
"math/big"
"net/http"
"reflect"
"strings"
"testing"
"time"
"golang.org/x/net/context"
"google.golang.org/appengine"
"google.golang.org/appengine/memcache"
"appengine/aetest"
)
func TestParseToken(t *testing.T) {
tts := []struct {
header, value, want string
}{
{"Authorization", "Bearer token", "token"},
{"Authorization", "bearer foo", "foo"},
{"authorization", "bearer bar", "bar"},
{"Authorization", "OAuth baz", "baz"},
{"authorization", "oauth xxx", "xxx"},
{"Authorization", "Bearer", ""},
{"authorization", "Bearer ", ""},
{"Authorization", "", ""},
{"X-Other-Header", "Bearer token", ""},
{"x-header", "Bearer token", ""},
{"", "", ""},
}
for i, tt := range tts {
h := make(http.Header)
if tt.header != "" {
h.Set(tt.header, tt.value)
}
r := &http.Request{Header: h}
out := parseToken(r)
if out != tt.want {
t.Errorf("%d: parseToken(%v) = %q; want %q", i, h, out, tt.want)
}
}
}
func TestMaxAge(t *testing.T) {
verifyPairs(t,
maxAge("max-age=86400"), 86400,
maxAge("max-age = 7200, must-revalidate"), 7200,
maxAge("public, max-age= 3600"), 3600,
maxAge("max-age=-100"), 0,
maxAge("max-age = 0a1b, must-revalidate"), 0,
maxAge("public, max-age= short"), 0,
maxAge("s-maxage=86400"), 0,
maxAge("max=86400"), 0,
maxAge("public"), 0,
maxAge(""), 0,
)
}
func TestCertExpirationTime(t *testing.T) {
tts := []struct {
cacheControl, age string
want time.Duration
}{
{"max-age=3600", "600", 3000 * time.Second},
{"max-age=600", "", 0},
{"max-age=300", "301", 0},
{"max-age=0", "0", 0},
{"", "600", 0},
{"", "", 0},
}
for i, tt := range tts {
h := make(http.Header)
h.Set("cache-control", tt.cacheControl)
h.Set("age", tt.age)
if out := certExpirationTime(h); out != tt.want {
t.Errorf("%d: certExpirationTime(%v) = %d; want %d", i, h, out, tt.want)
}
}
}
func TestAddBase64Pad(t *testing.T) {
verifyPairs(t,
addBase64Pad("12"), "12==",
addBase64Pad("123"), "123=",
addBase64Pad("1234"), "1234",
addBase64Pad("12345"), "12345",
addBase64Pad(""), "")
}
func TestBase64ToBig(t *testing.T) {
tts := []struct {
in string
want *big.Int
error bool
}{
{"MTI=", new(big.Int).SetBytes([]byte("12")), false},
{"MTI", new(big.Int).SetBytes([]byte("12")), false},
{"", big.NewInt(0), false},
{" ", nil, true},
}
for i, tt := range tts {
out, err := base64ToBig(tt.in)
switch {
case err == nil && !tt.error && tt.want.Cmp(out) != 0:
t.Errorf("%d: base64ToBig(%q) = %v; want %v", i, tt.in, out, tt.want)
case err != nil && !tt.error:
t.Errorf("%d: base64ToBig(%q) = %v; want %v", i, tt.in, err, tt.want)
case err == nil && tt.error:
t.Errorf("%d: base64ToBig(%q) = %v; want error", i, tt.in, out)
}
}
}
func TestZeroPad(t *testing.T) {
in := []byte{1, 2, 3}
padded := zeroPad(in, 5)
want := []byte{0, 0, 1, 2, 3}
if !bytes.Equal(padded, want) {
t.Errorf("zeroPad(%#v, 5) = %#v; want %#v", in, padded, want)
}
}
func TestContains(t *testing.T) {
tts := []struct {
list []string
val string
want bool
}{
{[]string{"test"}, "test", true},
{[]string{"one", "test", "two"}, "test", true},
{[]string{"test"}, "xxx", false},
{[]string{"xxx"}, "test", false},
{[]string{}, "", false},
}
for i, tt := range tts {
res := contains(tt.list, tt.val)
if res != tt.want {
t.Errorf("%d: contains(%#v, %q) = %v; want %v",
i, tt.list, tt.val, res, tt.want)
}
}
}
func TestCachedCertsCacheHit(t *testing.T) {
origTransport := httpTransportFactory
defer func() { httpTransportFactory = origTransport }()
httpTransportFactory = func(c context.Context) http.RoundTripper {
return newTestRoundTripper()
}
req, _, closer := newTestRequest(t, "GET", "/", nil)
defer closer()
nc, err := appengine.Namespace(appengine.NewContext(req), certNamespace)
if err != nil {
t.Fatal(err)
}
tts := []struct {
cacheValue string
want *certsList
}{
{"", nil},
{"{}", &certsList{}},
{`{"keyvalues": [{}]}`, &certsList{[]*certInfo{{}}}},
{`{"keyvalues": [
{"algorithm": "RS256",
"exponent": "123",
"keyid": "some-id",
"modulus": "123"} ]}`,
&certsList{[]*certInfo{{"RS256", "123", "some-id", "123"}}}},
}
ec := NewContext(req)
for i, tt := range tts {
item := &memcache.Item{Key: DefaultCertURI, Value: []byte(tt.cacheValue)}
if err := memcache.Set(nc, item); err != nil {
t.Fatal(err)
}
out, err := cachedCerts(ec)
switch {
case err != nil && tt.want != nil:
t.Errorf("%d: cachedCerts() error %v", i, err)
case err == nil && tt.want == nil:
t.Errorf("%d: cachedCerts() = %#v; want error", i, out)
case err == nil && tt.want != nil && !reflect.DeepEqual(out, tt.want):
t.Errorf("cachedCerts() = %#+v (%T); want %#+v (%T)",
out, out, tt.want, tt.want)
}
}
}
func TestCachedCertsCacheMiss(t *testing.T) {
rt := newTestRoundTripper()
origTransport := httpTransportFactory
defer func() { httpTransportFactory = origTransport }()
httpTransportFactory = func(c context.Context) http.RoundTripper {
return rt
}
req, _, closer := newTestRequest(t, "GET", "/", nil)
defer closer()
nc, err := appengine.Namespace(appengine.NewContext(req), certNamespace)
if err != nil {
t.Fatal(err)
}
ec := NewContext(req)
tts := []*struct {
respStatus int
respContent, cacheControl, age string
want *certsList
shouldCache bool
}{
{200, `{"keyvalues":null}`, "max-age=3600", "600", &certsList{}, true},
{-1, "", "", "", nil, false},
{400, "", "", "", nil, false},
{200, `{"keyvalues":null}`, "", "", &certsList{}, false},
}
for i, tt := range tts {
if tt.respStatus > 0 {
resp := &http.Response{
Status: fmt.Sprintf("%d", tt.respStatus),
StatusCode: tt.respStatus,
Body: ioutil.NopCloser(strings.NewReader(tt.respContent)),
Header: make(http.Header),
}
resp.Header.Set("cache-control", tt.cacheControl)
resp.Header.Set("age", tt.age)
rt.Add(resp)
}
memcache.Delete(nc, DefaultCertURI)
out, err := cachedCerts(ec)
switch {
case err != nil && tt.want != nil:
t.Errorf("%d: cachedCerts() = %v", i, err)
case err == nil && tt.want == nil:
t.Errorf("%d: cachedCerts() = %#v; want error", i, out)
default:
if !reflect.DeepEqual(out, tt.want) {
t.Errorf("%d: cachedCerts() = %#v; want %#v", i, out, tt.want)
}
if !tt.shouldCache {
continue
}
item, err := memcache.Get(nc, DefaultCertURI)
if err != nil {
t.Errorf("%d: memcache.Get(%q) = %v", i, DefaultCertURI, err)
continue
}
cert := string(item.Value)
if tt.respContent != cert {
t.Errorf("%d: memcache.Get(%q) = %q; want %q",
i, DefaultCertURI, cert, tt.respContent)
}
}
}
}
func TestCurrentBearerTokenUser(t *testing.T) {
var empty = []string{}
const (
// Default values from user_service_stub.py of dev_appserver2.
validScope = "valid.scope"
validClientID = "123456789.apps.googleusercontent.com"
email = "example@example.com"
userID = "0"
authDomain = "gmail.com"
isAdmin = false
)
inst, err := aetest.NewInstance(nil)
if err != nil {
t.Fatalf("failed to create instance: %v", err)
}
defer inst.Close()
tts := []*struct {
scopes []string
clientIDs []string
success bool
}{
{empty, empty, false},
{empty, []string{validClientID}, false},
{[]string{validScope}, empty, false},
{[]string{validScope}, []string{validClientID}, true},
{[]string{"a", validScope, "b"}, []string{"c", validClientID, "d"}, true},
}
for i, tt := range tts {
r, err := inst.NewRequest("GET", "/", nil)
if err != nil {
t.Fatalf("failed to create req: %v", err)
}
c := newContext(r, cachingAuthenticatorFactory)
user, err := CurrentBearerTokenUser(c, tt.scopes, tt.clientIDs)
switch {
case tt.success && (err != nil || user == nil):
t.Errorf("%d: CurrentBearerTokenUser(%v, %v): err=%v, user=%+v; want ok",
i, tt.scopes, tt.clientIDs, err, user)
case !tt.success && err == nil:
t.Errorf("%d: CurrentBearerTokenUser(%v, %v) = %+v; want error",
i, tt.scopes, tt.clientIDs, user)
}
}
r, err := inst.NewRequest("GET", "/", nil)
if err != nil {
t.Fatalf("failed to create req: %v", err)
}
c := newContext(r, cachingAuthenticatorFactory)
scopes := []string{validScope}
clientIDs := []string{validClientID}
user, err := CurrentBearerTokenUser(c, scopes, clientIDs)
if err != nil {
t.Fatalf("CurrentBearerTokenUser(%v, %v) = %v", scopes, clientIDs, err)
}
if user.ID != userID {
t.Fatalf("CurrentBearerTokenUser(%v, %v) = %v; want ID=%v",
scopes, clientIDs, user, userID)
}
if user.Email != email {
t.Fatalf("CurrentBearerTokenUser(%v, %v) = %v; want email=%v",
scopes, clientIDs, user, email)
}
if user.AuthDomain != authDomain {
t.Fatalf("CurrentBearerTokenUser(%v, %v) = %v; want authDomain=%v",
scopes, clientIDs, user, authDomain)
}
if user.Admin != isAdmin {
t.Fatalf("CurrentBearerTokenUser(%v, %v) = %v; want isAdmin=%v",
scopes, clientIDs, user, isAdmin)
}
}
func TestCurrentUser(t *testing.T) {
const (
// Default values from user_service_stub.py of dev_appserver2.
clientID = "123456789.apps.googleusercontent.com"
bearerEmail = "example@example.com"
validScope = "valid.scope"
)
inst, err := aetest.NewInstance(nil)
if err != nil {
t.Fatalf("failed to create instance: %v", err)
}
defer inst.Close()
req, err := inst.NewRequest("GET", "/", nil)
nc, err := appengine.Namespace(appengine.NewContext(req), certNamespace)
if err != nil {
t.Fatal(err)
}
// googCerts are provided in jwt_test.go
item := &memcache.Item{Key: DefaultCertURI, Value: []byte(googCerts)}
if err := memcache.Set(nc, item); err != nil {
t.Fatal(err)
}
origCurrentUTC := currentUTC
defer func() { currentUTC = origCurrentUTC }()
currentUTC = func() time.Time {
return jwtValidTokenTime
}
jwtStr, jwt := jwtValidTokenString, jwtValidTokenObject
tts := []struct {
token string
scopes, audiences, clientIDs []string
wantEmail string
}{
// success
{jwtStr, []string{EmailScope}, []string{jwt.Audience}, []string{jwt.ClientID}, jwt.Email},
{"ya29.token", []string{EmailScope}, []string{clientID}, []string{clientID}, bearerEmail},
{"ya29.token", []string{EmailScope, validScope}, []string{clientID}, []string{clientID}, bearerEmail},
{"1/token", []string{validScope}, []string{clientID}, []string{clientID}, bearerEmail},
// failure
{jwtStr, []string{EmailScope}, []string{"other-client"}, []string{"other-client"}, ""},
{"some.invalid.jwt", []string{EmailScope}, []string{jwt.Audience}, []string{jwt.ClientID}, ""},
{"", []string{validScope}, []string{clientID}, []string{clientID}, ""},
// The following test is commented for now because default implementation
// of UserServiceStub in dev_appserver2 allows any scope.
// TODO: figure out how to test this.
//{"ya29.invalid", []string{"invalid.scope"}, []string{clientID}, []string{clientID}, ""},
{"doesn't matter", nil, []string{clientID}, []string{clientID}, ""},
{"doesn't matter", []string{EmailScope}, nil, []string{clientID}, ""},
{"doesn't matter", []string{EmailScope}, []string{clientID}, nil, ""},
}
for i, tt := range tts {
r, err := inst.NewRequest("GET", "/", nil)
c := newContext(r, cachingAuthenticatorFactory)
if tt.token != "" {
r.Header.Set("authorization", "oauth "+tt.token)
}
user, err := CurrentUser(c, tt.scopes, tt.audiences, tt.clientIDs)
switch {
case tt.wantEmail == "" && err == nil:
t.Errorf("%d: CurrentUser(%v, %v, %v) = %v; want error",
i, tt.scopes, tt.audiences, tt.clientIDs, user)
case tt.wantEmail != "" && user == nil:
t.Errorf("%d: CurrentUser(%v, %v, %v) = %v; want email = %q",
i, tt.scopes, tt.audiences, tt.clientIDs, err, tt.wantEmail)
case tt.wantEmail != "" && tt.wantEmail != user.Email:
t.Errorf("%d: CurrentUser(%v, %v, %v) = %v; want email = %q",
i, tt.scopes, tt.audiences, tt.clientIDs, user, tt.wantEmail)
}
}
}
func newContext(r *http.Request, factory func() Authenticator) context.Context {
defer func(old func() Authenticator) {
AuthenticatorFactory = old
}(AuthenticatorFactory)
AuthenticatorFactory = factory
return NewContext(r)
}