blob: aed74fc0e2fe94284224927dbe2f05029871a60e [file] [log] [blame]
// Copyright 2017 The Goma Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package auth
import (
"context"
"errors"
"fmt"
"reflect"
"strings"
"testing"
"time"
"golang.org/x/oauth2"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
authpb "go.chromium.org/goma/server/proto/auth"
)
// TestServiceExpire exercises the lifetime of a tokenCache entry.
//
// The first Auth call fetches token info once, a second call half an hour
// later is answered without another fetch, and once the callback handed to
// runAt has run the entry is gone so the next Auth call fetches again.
func TestServiceExpire(t *testing.T) {
	ctx := context.Background()
	wantToken := &oauth2.Token{
		AccessToken: "token-value",
		TokenType:   "Bearer",
	}

	// fire is how the test releases the fake runAt: the test sends a channel,
	// and the fake closes that channel after the scheduled callback returns.
	fire := make(chan chan bool)

	// clock is the fake "current time"; the fetched expiry is always one hour
	// past it.
	clock := time.Now()
	tokenExpiry := func() time.Time {
		return clock.Add(1 * time.Hour)
	}

	var (
		scheduledAt time.Time // time most recently passed to runAt
		fetches     int       // number of fetchInfo invocations
	)
	svc := &Service{
		CheckToken: func(ctx context.Context, token *oauth2.Token, tokenInfo *TokenInfo) (string, *oauth2.Token, error) {
			return "", token, nil
		},
		fetchInfo: func(ctx context.Context, got *oauth2.Token) (*TokenInfo, error) {
			fetches++
			if reflect.DeepEqual(got, wantToken) {
				return &TokenInfo{
					Email:     "foo@example.com",
					ExpiresAt: tokenExpiry(),
				}, nil
			}
			return nil, errors.New("wrong token")
		},
		// The fake scheduler records the requested time and then blocks until
		// the test decides the deadline has been reached.
		runAt: func(at time.Time, callback func()) {
			scheduledAt = at
			done := <-fire
			callback()
			close(done)
		},
	}

	firstExpiry := tokenExpiry()
	want := &authpb.AuthResp{
		Email:     "foo@example.com",
		ExpiresAt: timestamppb.New(firstExpiry),
		Quota:     -1,
		Token: &authpb.Token{
			AccessToken: "token-value",
			TokenType:   "Bearer",
		},
	}

	// auth issues the bearer request and aborts the test on an RPC error.
	auth := func() *authpb.AuthResp {
		resp, err := svc.Auth(ctx, &authpb.AuthReq{
			Authorization: "Bearer token-value",
		})
		if err != nil {
			t.Fatalf("Auth failed: %v", err)
		}
		return resp
	}
	// verify compares a response against the current expectation.
	verify := func(got *authpb.AuthResp) {
		if !reflect.DeepEqual(got, want) {
			t.Errorf("s.Auth(ctx, req)=%v; want=%v", got, want)
		}
	}

	t.Logf("initial fetch")
	verify(auth())
	key := tokenKey(wantToken)
	// inCache reports whether key is currently present in the token cache.
	inCache := func() bool {
		svc.mu.Lock()
		defer svc.mu.Unlock()
		_, found := svc.tokenCache[key]
		return found
	}
	if !inCache() {
		t.Errorf("%q must exist in tokenCache", key)
	}
	if fetches != 1 {
		t.Errorf("fetch count=%d; want=1", fetches)
	}

	t.Logf("30 minutes later")
	clock = clock.Add(30 * time.Minute)
	verify(auth())
	if fetches != 1 {
		t.Errorf("fetch count=%d; want=1", fetches)
	}

	t.Logf("1 hours later")
	clock = clock.Add(30 * time.Minute)
	done := make(chan bool)
	// Release the fake runAt so it runs the scheduled callback.
	fire <- done
	if wantAt := expiryTime(firstExpiry); !scheduledAt.Equal(wantAt) {
		t.Errorf("deadline %s != %s (expiresAt %s)", scheduledAt, wantAt, firstExpiry)
	}
	// Block until the callback has returned.
	<-done
	if inCache() {
		t.Errorf("%q must be removed from tokenCache", key)
	}

	// The clock has moved, so a fresh fetch must report a different expiry.
	secondExpiry := tokenExpiry()
	if firstExpiry.Equal(secondExpiry) {
		t.Fatalf("expiresAt: %s == %s", firstExpiry, secondExpiry)
	}
	want.ExpiresAt = timestamppb.New(secondExpiry)
	verify(auth())
	if !inCache() {
		t.Errorf("%q must exist in tokenCache", key)
	}
	if fetches != 2 {
		t.Errorf("fetch count=%d; want=2", fetches)
	}
}
// TestAuth walks Service.Auth through five scenarios:
//
//  0. the token is already present in tokenCache;
//  1. the token info is obtained through fetchInfo and then cached;
//  2. fetchInfo fails;
//  3. CheckToken rejects the user with PermissionDenied;
//  4. CheckToken fails with DeadlineExceeded.
//
// Scenarios 2 and 3 expect a nil error with the failure reported inside the
// response (zero Quota, non-empty ErrorDescription); scenario 4 expects Auth
// itself to return an error.
//
// TODO: revise this when we implement ACL
func TestAuth(t *testing.T) {
	t.Logf("0. access succeeds by cache")
	email := "example@google.com"
	expiresAt := time.Now().Add(10 * time.Second)
	ptypeExpiresAt := timestamppb.New(expiresAt)
	ti := &TokenInfo{
		Email:     email,
		Audience:  "test-audience.apps.googleusercontent.com",
		ExpiresAt: expiresAt,
	}
	// checkToken accepts the token only when Auth hands over exactly ti.
	checkToken := func(ctx context.Context, token *oauth2.Token, tokenInfo *TokenInfo) (string, *oauth2.Token, error) {
		if !reflect.DeepEqual(tokenInfo, ti) {
			return "", nil, fmt.Errorf("wrong token info: got=%v; want=%v", tokenInfo, ti)
		}
		return "", token, nil
	}
	want := &authpb.AuthResp{
		Email:     email,
		ExpiresAt: ptypeExpiresAt,
		Quota:     -1,
		Token: &authpb.Token{
			AccessToken: "test",
			TokenType:   "Bearer",
		},
	}
	token := &oauth2.Token{
		AccessToken: "test",
		TokenType:   "Bearer",
	}
	key := tokenKey(token)
	req := &authpb.AuthReq{
		Authorization: "Bearer test",
	}

	// checkRejected asserts that resp reports a rejection in-band: it must be
	// non-nil, carry zero quota and a non-empty error description.
	checkRejected := func(resp *authpb.AuthResp) {
		t.Helper()
		if resp == nil {
			t.Errorf("Auth(%q)=nil, want non-nil resp", req)
			return
		}
		if resp.Quota != 0 {
			t.Errorf("Auth(%q).Quota=%d; want 0", req, resp.Quota)
		}
		if resp.ErrorDescription == "" {
			t.Errorf("Auth(%q).ErrorDescription=%q; want non empty", req, resp.ErrorDescription)
		}
	}

	// The cache is pre-seeded and no fetchInfo is configured for this case.
	s := &Service{
		CheckToken: checkToken,
		tokenCache: map[string]*tokenCacheEntry{
			key: {
				TokenInfo: ti,
				Token:     token,
			},
		},
	}
	resp, err := s.Auth(context.Background(), req)
	if err != nil {
		t.Errorf("Auth(%q) error %v; want nil error", req, err)
	}
	if !reflect.DeepEqual(resp, want) {
		t.Errorf("Auth(%q)=%q; want %q", req, resp, want)
	}

	t.Logf("1. access succeeds by fetching token info.")
	s1 := &Service{
		CheckToken: checkToken,
		fetchInfo: func(ctx context.Context, token *oauth2.Token) (*TokenInfo, error) {
			return ti, nil
		},
	}
	resp, err = s1.Auth(context.Background(), req)
	if err != nil {
		t.Errorf("Auth(%q) error %v; want nil error", req, err)
	}
	if !reflect.DeepEqual(resp, want) {
		t.Errorf("Auth(%q)=%q; want %q", req, resp, want)
	}
	// and confirm it is stored in cache. The lookup must target s1 (the
	// service that just fetched), not s, whose cache was seeded by hand and
	// would make this check pass unconditionally.
	s1.mu.Lock()
	entry, ok := s1.tokenCache[key]
	s1.mu.Unlock()
	switch {
	case !ok || entry == nil:
		// Report the miss without dereferencing the nil entry.
		t.Errorf("tokenCache[%q] not found; want TokenInfo=%q", key, ti)
	case !reflect.DeepEqual(entry.TokenInfo, ti):
		t.Errorf(`tokenCache[%q].TokenInfo=%q; want %q`, key, entry.TokenInfo, ti)
	}

	t.Logf("2. failed to fetch token info.")
	s2 := &Service{
		CheckToken: checkToken,
		fetchInfo: func(ctx context.Context, token *oauth2.Token) (*TokenInfo, error) {
			return nil, errors.New("fetch failed")
		},
	}
	resp, err = s2.Auth(context.Background(), req)
	if err != nil {
		t.Errorf("Auth(%q) error %v; want nil error", req, err)
	}
	checkRejected(resp)

	t.Logf("3. non-allowed user access")
	s3 := &Service{
		CheckToken: func(ctx context.Context, token *oauth2.Token, tokenInfo *TokenInfo) (string, *oauth2.Token, error) {
			if !strings.HasSuffix(tokenInfo.Email, "@google.com") {
				return "", token, status.Errorf(codes.PermissionDenied, "not allowed")
			}
			return "", token, nil
		},
		fetchInfo: func(ctx context.Context, token *oauth2.Token) (*TokenInfo, error) {
			return &TokenInfo{
				Email:     "not_allowed@example.com",
				ExpiresAt: expiresAt,
			}, nil
		},
	}
	resp, err = s3.Auth(context.Background(), req)
	if err != nil {
		t.Errorf("Auth(%q) error %v; want nil error", req, err)
	}
	checkRejected(resp)

	t.Logf("4. failed membership check")
	s4 := &Service{
		CheckToken: func(ctx context.Context, token *oauth2.Token, tokenInfo *TokenInfo) (string, *oauth2.Token, error) {
			if tokenInfo.Email == "misfortune@example.com" {
				return "", token, status.Errorf(codes.DeadlineExceeded, "deadline exceeded")
			}
			return "", token, nil
		},
		fetchInfo: func(ctx context.Context, token *oauth2.Token) (*TokenInfo, error) {
			return &TokenInfo{
				Email:     "misfortune@example.com",
				ExpiresAt: expiresAt,
			}, nil
		},
	}
	resp, err = s4.Auth(context.Background(), req)
	if err == nil {
		t.Errorf("Auth(%q)=%v, %v; want error", req, resp, err)
	}
}