blob: de012405022eda4e11c004daf1fe2b1eac84b2d2 [file] [log] [blame]
package endpoints
import (
"bytes"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"math/big"
"net/http"
"regexp"
"strconv"
"strings"
"time"
"golang.org/x/net/context"
"google.golang.org/appengine"
"google.golang.org/appengine/log"
"google.golang.org/appengine/memcache"
"google.golang.org/appengine/user"
)
const (
	// APIExplorerClientID is the OAuth 2.0 client ID of the Google APIs
	// Explorer.
	APIExplorerClientID = "292824132082.apps.googleusercontent.com"

	// DefaultCertURI is Google's public URL serving the certificates used to
	// verify signed JWTs.
	DefaultCertURI = "https://www.googleapis.com/service_accounts/" +
		"v1/metadata/raw/federated-signon@system.gserviceaccount.com"

	// EmailScope is Google's OAuth 2.0 email scope.
	EmailScope = "https://www.googleapis.com/auth/userinfo.email"

	// TokeninfoURL is Google's OAuth 2.0 access token verification URL.
	TokeninfoURL = "https://www.googleapis.com/oauth2/v1/tokeninfo"
)
var (
	// allowedAuthSchemesUpper lists the upper-cased Authorization header
	// schemes whose credentials parseToken will return.
	allowedAuthSchemesUpper = [...]string{"OAUTH", "BEARER"}

	// certNamespace is the memcache namespace under which certs are cached.
	certNamespace = "__verify_jwt"

	// clockSkewSecs is the tolerance, in seconds, applied around the "iat"
	// and "exp" token checks (5 minutes).
	clockSkewSecs int64 = 5 * 60

	// maxTokenLifetimeSecs bounds how far in the future a token's "exp" may
	// lie, in seconds (1 day).
	maxTokenLifetimeSecs int64 = 24 * 60 * 60

	// maxAgePattern captures the numeric value of a single Cache-Control
	// max-age directive.
	maxAgePattern = regexp.MustCompile(`\s*max-age\s*=\s*(\d+)\s*`)

	// jwtParser decodes and verifies a signed JWT. It is a variable so tests
	// can substitute a fake implementation; package code must always go
	// through jwtParser() rather than calling verifySignedJWT() directly.
	jwtParser = verifySignedJWT

	// currentUTC reports the current time in UTC. It is a variable so tests
	// can stub the clock.
	currentUTC = func() time.Time { return time.Now().UTC() }

	// AuthenticatorFactory creates a new Authenticator.
	//
	// It is deliberately a variable: tests may replace it with a stub
	// implementation.
	AuthenticatorFactory func() Authenticator
)
// Authenticator identifies the current user of a request.
type Authenticator interface {
	// CurrentOAuthUser returns the user of this request for the given scope.
	// OAuth info is expected to be cached on the first call so that later
	// invocations can reuse it.
	//
	// An error is returned when no data is available for scope.
	CurrentOAuthUser(ctx context.Context, scope string) (*user.User, error)

	// CurrentOAuthClientID returns the client ID associated with scope.
	CurrentOAuthClientID(ctx context.Context, scope string) (string, error)
}
// contextKey is the private key type for values this package stores on a
// context; a dedicated type keeps them from colliding with other packages.
type contextKey int

// Keys under which NewContext stores its values.
const (
	invalidKey       contextKey = 0 // never used to store a value
	requestKey       contextKey = 1 // the *http.Request being served
	authenticatorKey contextKey = 2 // the request's Authenticator
)
// HTTPRequest returns the *http.Request that NewContext attached to c, or nil
// when c carries no request.
func HTTPRequest(c context.Context) *http.Request {
	if req, ok := c.Value(requestKey).(*http.Request); ok {
		return req
	}
	return nil
}
// authenticator fetches the Authenticator stored on c by NewContext.
// It yields nil when c has no value of that type.
func authenticator(c context.Context) Authenticator {
	if auth, ok := c.Value(authenticatorKey).(Authenticator); ok {
		return auth
	}
	return nil
}
// errNoAuthenticator is returned when a context lacks an Authenticator,
// i.e. it was not built with NewContext.
var errNoAuthenticator = errors.New("context has no authenticator (use endpoints.NewContext to create a context)")

// errNoRequest is returned when a context lacks the HTTP request, i.e. it
// was not built with NewContext.
var errNoRequest = errors.New("no request for context (use endpoints.NewContext to create a context)")
// NewContext derives a context for an in-flight API (HTTP) request.
//
// Starting from appengine.NewContext(r), it attaches r itself (retrievable
// with HTTPRequest) and a fresh Authenticator produced by
// AuthenticatorFactory.
func NewContext(r *http.Request) context.Context {
	withRequest := context.WithValue(appengine.NewContext(r), requestKey, r)
	return context.WithValue(withRequest, authenticatorKey, AuthenticatorFactory())
}
// parseToken extracts the credentials from r's Authorization header.
//
// The header must consist of exactly two whitespace-separated fields, and the
// first (compared case-insensitively via upper-casing) must be one of
// allowedAuthSchemesUpper. Otherwise the empty string is returned.
func parseToken(r *http.Request) string {
	// TODO(dhermes): Allow a struct with access_token and bearer_token
	// fields here as well.
	fields := strings.Fields(r.Header.Get("Authorization"))
	if len(fields) == 2 {
		scheme := strings.ToUpper(fields[0])
		for i := range allowedAuthSchemesUpper {
			if allowedAuthSchemesUpper[i] == scheme {
				return fields[1]
			}
		}
	}
	return ""
}
// certInfo is one public-key entry of the JSON document served at
// DefaultCertURI (decoded in cachedCerts).
//
// Exponent and Modulus are turned into big integers by base64ToBig in
// verifySignedJWT, i.e. they are standard-base64 strings (padding optional)
// holding big-endian integers. Algorithm and KeyID are decoded but are not
// consulted anywhere in the code visible in this file.
type certInfo struct {
Algorithm string `json:"algorithm"`
Exponent string `json:"exponent"`
KeyID string `json:"keyid"`
Modulus string `json:"modulus"`
}
// certsList is the top-level JSON object served at DefaultCertURI; each
// entry of KeyValues is tried in turn when verifying a token signature.
type certsList struct {
KeyValues []*certInfo `json:"keyvalues"`
}
// maxAge returns the max-age value, in seconds, found in the Cache-Control
// directive s. It returns 0 when s has no max-age or the number does not fit
// in an int.
func maxAge(s string) int {
	groups := maxAgePattern.FindStringSubmatch(s)
	if len(groups) == 2 {
		if seconds, err := strconv.Atoi(groups[1]); err == nil {
			return seconds
		}
	}
	return 0
}
// certExpirationTime reports how much longer certs described by header h may
// be cached, using the first positive max-age directive in Cache-Control
// minus the Age header.
//
// It returns 0 when Cache-Control has no positive max-age, when Age is absent
// or not an integer, or when the remaining lifetime is not positive.
func certExpirationTime(h http.Header) time.Duration {
	// RFC 2616 section 4.2 only allows comma-separated multi-value headers,
	// so splitting Cache-Control on commas yields its individual directives.
	maxAgeSecs := 0
	for _, directive := range strings.Split(h.Get("Cache-Control"), ",") {
		if maxAgeSecs = maxAge(directive); maxAgeSecs > 0 {
			break
		}
	}
	if maxAgeSecs <= 0 {
		return 0
	}

	ageSecs, err := strconv.Atoi(h.Get("Age"))
	if err != nil {
		return 0
	}

	if remaining := maxAgeSecs - ageSecs; remaining > 0 {
		return time.Duration(remaining) * time.Second
	}
	return 0
}
// cachedCerts returns Google's public certificates as published at
// DefaultCertURI.
//
// Memcache (namespace certNamespace, key DefaultCertURI) is consulted first.
// On a cache miss the certs are downloaded and, when the response carries a
// positive remaining lifetime (see certExpirationTime), the raw response body
// is stored in memcache for that long. If memcache reports any error other
// than a cache miss, fetched certs are not written back.
//
// A non-nil error is returned when the certs cannot be fetched or decoded,
// including when the payload decodes to no certs list at all (e.g. a JSON
// null), so callers never receive (nil, nil).
func cachedCerts(c context.Context) (*certsList, error) {
	namespacedContext, err := appengine.Namespace(c, certNamespace)
	if err != nil {
		return nil, err
	}

	var cached *certsList
	_, err = memcache.JSON.Get(namespacedContext, DefaultCertURI, &cached)

	cacheResults := true
	switch {
	case err == nil && cached != nil:
		return cached, nil
	case err == nil:
		// The cached value decoded to nil (e.g. JSON null); refetch and
		// overwrite it rather than handing a nil list to the caller.
		log.Debugf(c, "Cached provider certs were empty; refetching.")
	case err != memcache.ErrCacheMiss:
		// Any error other than a cache miss suggests memcache is probably
		// unhealthy right now, so don't try to write to it either.
		cacheResults = false
		log.Debugf(c, "%s", err.Error())
	}

	log.Debugf(c, "Fetching provider certs from: %s", DefaultCertURI)
	resp, err := newHTTPClient(c).Get(DefaultCertURI)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, errors.New("Could not reach Cert URI or bad response.")
	}
	certBytes, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	// Decode into a fresh variable so nothing left over from a partially
	// decoded memcache value can leak into the result.
	var certs *certsList
	if err := json.Unmarshal(certBytes, &certs); err != nil {
		return nil, err
	}
	if certs == nil {
		return nil, errors.New("Cert URI response did not contain a certs list.")
	}

	if cacheResults {
		if expiration := certExpirationTime(resp.Header); expiration > 0 {
			item := &memcache.Item{
				Key:        DefaultCertURI,
				Value:      certBytes,
				Expiration: expiration,
			}
			if err := memcache.Set(namespacedContext, item); err != nil {
				log.Errorf(c, "Error adding Certs to memcache: %v", err)
			}
		}
	}
	return certs, nil
}
// signedJWTHeader is the decoded JSON header (first segment) of a signed JWT.
// verifySignedJWT rejects tokens whose "alg" is anything other than "RS256".
type signedJWTHeader struct {
Algorithm string `json:"alg"`
}
// signedJWT holds the claims decoded from the payload (second segment) of a
// signed JWT.
type signedJWT struct {
// Audience ("aud") and ClientID ("azp") are checked against the allowed
// audiences and client IDs in verifyParsedToken; both must be non-empty.
Audience string `json:"aud"`
ClientID string `json:"azp"`
// Subject ("sub") becomes user.User.ID in currentIDTokenUser.
Subject string `json:"sub"`
// Email must be non-empty (verifyParsedToken) and becomes user.User.Email.
Email string `json:"email"`
// Expires ("exp") and IssuedAt ("iat") are compared by verifySignedJWT with
// the "now" value it is given (CurrentUser passes currentUTC().Unix(), i.e.
// seconds since the Unix epoch); a zero value is rejected.
Expires int64 `json:"exp"`
IssuedAt int64 `json:"iat"`
// Issuer ("iss") is checked by verifyParsedToken.
Issuer string `json:"iss"`
}
// addBase64Pad appends the '=' padding that an unpadded base64 string needs.
//
// Strings whose length is 2 or 3 modulo 4 get "==" or "=" respectively; all
// other strings (already padded, or length 1 mod 4, which no padding can
// repair) are returned unchanged.
func addBase64Pad(s string) string {
	if rem := len(s) % 4; rem >= 2 {
		return s + strings.Repeat("=", 4-rem)
	}
	return s
}
// base64ToBig decodes s as standard base64 (padding optional) and interprets
// the bytes as a big-endian unsigned integer.
//
// The decoding error is returned unchanged when s is not valid base64.
func base64ToBig(s string) (*big.Int, error) {
	raw, err := base64.StdEncoding.DecodeString(addBase64Pad(s))
	if err != nil {
		return nil, err
	}
	return new(big.Int).SetBytes(raw), nil
}
// zeroPad returns a new slice holding b left-padded with zero bytes to size
// bytes.
//
// If b is already size bytes or longer, a copy of b is returned unpadded
// instead of panicking on a negative padding length. The result never aliases
// b.
func zeroPad(b []byte, size int) []byte {
	pad := size - len(b)
	if pad < 0 {
		pad = 0
	}
	padded := make([]byte, pad, pad+len(b))
	return append(padded, b...)
}
// contains reports whether value occurs anywhere in strList.
func contains(strList []string, value string) bool {
	for i := range strList {
		if strList[i] == value {
			return true
		}
	}
	return false
}
// verifySignedJWT decodes and verifies JWT token string.
//
// Verification is based on
//   - an RS256 signature check (RSASSA-PKCS1-v1_5 with SHA-256) against the
//     exponent and modulus of each cert returned by cachedCerts; the token is
//     accepted if any cert verifies it
//   - expiration and issue timestamps ("exp" and "iat" fields), allowing
//     clockSkewSecs of skew and rejecting "exp" values maxTokenLifetimeSecs
//     or more past now
//
// now is the current time in seconds since the Unix epoch.
//
// This method expects JWT token string to be in the standard format, e.g. as
// read from Authorization request header: "<header>.<payload>.<signature>",
// where all segments are encoded with URL-base64.
//
// The caller is responsible for performing further token verification.
// (Issuer, Audience, ClientID, etc.)
//
// NOTE: do not call this function directly, use jwtParser() instead.
func verifySignedJWT(c context.Context, jwt string, now int64) (*signedJWT, error) {
	segments := strings.Split(jwt, ".")
	if len(segments) != 3 {
		return nil, fmt.Errorf("Wrong number of segments in token: %s", jwt)
	}

	// Check that header (first segment) is valid
	headerBytes, err := base64.URLEncoding.DecodeString(addBase64Pad(segments[0]))
	if err != nil {
		return nil, err
	}
	var header signedJWTHeader
	if err := json.Unmarshal(headerBytes, &header); err != nil {
		return nil, err
	}
	if header.Algorithm != "RS256" {
		return nil, fmt.Errorf("Unexpected encryption algorithm: %s", header.Algorithm)
	}

	// Check that token (second segment) is valid
	tokenBytes, err := base64.URLEncoding.DecodeString(addBase64Pad(segments[1]))
	if err != nil {
		return nil, err
	}
	var token signedJWT
	if err := json.Unmarshal(tokenBytes, &token); err != nil {
		return nil, err
	}

	// Get current certs
	certs, err := cachedCerts(c)
	if err != nil {
		return nil, err
	}
	if certs == nil {
		return nil, errors.New("No certs available to verify token signature.")
	}

	signatureBytes, err := base64.URLEncoding.DecodeString(addBase64Pad(segments[2]))
	if err != nil {
		return nil, err
	}
	signature := new(big.Int).SetBytes(signatureBytes)
	// The signing input is the still-encoded "<header>.<payload>" text.
	digest := sha256.Sum256([]byte(segments[0] + "." + segments[1]))

	verified := false
	for _, cert := range certs.KeyValues {
		if cert == nil {
			continue
		}
		exponent, err := base64ToBig(cert.Exponent)
		if err != nil {
			return nil, err
		}
		modulus, err := base64ToBig(cert.Modulus)
		if err != nil {
			return nil, err
		}
		if verifyRS256Signature(signature, digest[:], exponent, modulus) {
			verified = true
			break
		}
	}
	if !verified {
		return nil, fmt.Errorf("Invalid token signature: %s", jwt)
	}

	// Check time
	if token.IssuedAt == 0 {
		return nil, fmt.Errorf("Invalid iat value in token: %s", tokenBytes)
	}
	earliest := token.IssuedAt - clockSkewSecs
	if now < earliest {
		return nil, fmt.Errorf("Token used too early, %d < %d: %s", now, earliest, tokenBytes)
	}
	if token.Expires == 0 {
		return nil, fmt.Errorf("Invalid exp value in token: %s", tokenBytes)
	} else if token.Expires >= now+maxTokenLifetimeSecs {
		return nil, fmt.Errorf("exp value is too far in the future: %s", tokenBytes)
	}
	latest := token.Expires + clockSkewSecs
	if now > latest {
		return nil, fmt.Errorf("Token used too late, %d > %d: %s", now, latest, tokenBytes)
	}
	return &token, nil
}

// sha256DigestInfoPrefix is the DER-encoded DigestInfo header that
// EMSA-PKCS1-v1_5 places in front of a SHA-256 digest (RFC 8017, section 9.2).
var sha256DigestInfoPrefix = []byte{
	0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01,
	0x65, 0x03, 0x04, 0x02, 0x01, 0x05, 0x00, 0x04, 0x20,
}

// verifyRS256Signature reports whether signature is a valid RSASSA-PKCS1-v1_5
// signature (RFC 8017, section 8.2.2) of the SHA-256 digest under the RSA
// public key (exponent, modulus).
//
// The whole recovered encoded message is compared against
// 0x00 0x01 || 0xFF... || 0x00 || DigestInfo(SHA-256) || digest, not merely
// its trailing digest bytes, so signatures with malformed padding are
// rejected.
func verifyRS256Signature(signature *big.Int, digest []byte, exponent, modulus *big.Int) bool {
	// A zero modulus would make Exp compute an unreduced (huge) power.
	if modulus.Sign() <= 0 || exponent.Sign() <= 0 || signature.Sign() < 0 {
		return false
	}
	// RSAVP1: the signature representative must be smaller than the modulus.
	if signature.Cmp(modulus) >= 0 {
		return false
	}
	k := (modulus.BitLen() + 7) / 8
	tLen := len(sha256DigestInfoPrefix) + len(digest)
	// EMSA-PKCS1-v1_5 requires at least 8 bytes of 0xFF padding.
	if k < tLen+11 {
		return false
	}

	recovered := new(big.Int).Exp(signature, exponent, modulus).Bytes()
	if len(recovered) > k {
		return false
	}

	expected := make([]byte, k)
	expected[1] = 0x01
	sep := k - tLen - 1 // index of the 0x00 byte separating padding and T
	for i := 2; i < sep; i++ {
		expected[i] = 0xff
	}
	copy(expected[sep+1:], sha256DigestInfoPrefix)
	copy(expected[sep+1+len(sha256DigestInfoPrefix):], digest)

	return bytes.Equal(zeroPad(recovered, k), expected)
}
// verifyParsedToken performs further verification of a parsed JWT token and
// checks for the validity of Issuer, Audience, ClientID and Email fields.
//
// The token is accepted only when all of the following hold:
//   - Issuer is "accounts.google.com" or "https://accounts.google.com"
//     (Google issues ID tokens with either form of the "iss" claim);
//   - Audience and ClientID are both non-empty;
//   - Audience is listed in audiences, unless it equals ClientID;
//   - clientIDs is non-empty and contains ClientID;
//   - Email is non-empty.
//
// Every rejection is logged as a warning on c. Returns true if token passes
// verification and can be accepted as indicated by audiences and clientIDs.
func verifyParsedToken(c context.Context, token signedJWT, audiences []string, clientIDs []string) bool {
	// Verify the issuer.
	switch token.Issuer {
	case "accounts.google.com", "https://accounts.google.com":
		// Both forms are valid Google issuers.
	default:
		log.Warningf(c, "Issuer was not valid: %s", token.Issuer)
		return false
	}

	// Check audiences.
	if token.Audience == "" {
		log.Warningf(c, "Invalid aud value in token")
		return false
	}
	if token.ClientID == "" {
		log.Warningf(c, "Invalid azp value in token")
		return false
	}
	// This is only needed if Audience and ClientID differ, which (currently) only
	// happens on Android. In the case they are equal, we only need the ClientID to
	// be in the list of accepted Client IDs.
	if token.ClientID != token.Audience && !contains(audiences, token.Audience) {
		log.Warningf(c, "Audience not allowed: %s", token.Audience)
		return false
	}

	// Check allowed client IDs.
	if len(clientIDs) == 0 {
		log.Warningf(c, "No allowed client IDs specified. ID token cannot be verified.")
		return false
	}
	if !contains(clientIDs, token.ClientID) {
		log.Warningf(c, "Client ID is not allowed: %s", token.ClientID)
		return false
	}

	if token.Email == "" {
		log.Warningf(c, "Invalid email value in token")
		return false
	}
	return true
}
// currentIDTokenUser decodes jwt through jwtParser (using now, in Unix
// seconds) and, if the parsed token also passes verifyParsedToken for the
// given audiences and clientIDs, returns an "appengine/user".User with only
// ID (from "sub"), Email and ClientID (from "azp") populated.
//
// Parser errors are returned as-is; a token that fails verifyParsedToken
// yields a generic "No ID token user found." error.
func currentIDTokenUser(c context.Context, jwt string, audiences []string, clientIDs []string, now int64) (*user.User, error) {
	parsed, err := jwtParser(c, jwt, now)
	if err != nil {
		return nil, err
	}
	if !verifyParsedToken(c, *parsed, audiences, clientIDs) {
		return nil, errors.New("No ID token user found.")
	}
	u := &user.User{
		ID:       parsed.Subject,
		Email:    parsed.Email,
		ClientID: parsed.ClientID,
	}
	return u, nil
}
// CurrentBearerTokenScope compares given scopes and clientIDs with those in c.
//
// Both scopes and clientIDs args must have at least one element.
//
// Scopes are tried in order, skipping any for which the context's
// Authenticator reports an error. The first scope that yields a client ID
// decides the result: if that client ID is one of clientIDs the scope is
// returned, otherwise a "Mismatched Client ID" error is returned without
// trying the remaining scopes. If no scope yields a client ID, the error is
// "No valid scope". errNoAuthenticator is returned when c has no
// Authenticator.
func CurrentBearerTokenScope(c context.Context, scopes []string, clientIDs []string) (string, error) {
	auth := authenticator(c)
	if auth == nil {
		return "", errNoAuthenticator
	}
	for _, scope := range scopes {
		clientID, err := auth.CurrentOAuthClientID(c, scope)
		if err != nil {
			// No client ID available for this scope; try the next one.
			continue
		}
		matched := false
		for _, allowed := range clientIDs {
			if allowed == clientID {
				matched = true
				break
			}
		}
		if !matched {
			log.Debugf(c, "Couldn't find current client ID %q in %v", clientID, clientIDs)
			return "", errors.New("Mismatched Client ID")
		}
		return scope, nil
	}
	return "", errors.New("No valid scope")
}
// CurrentBearerTokenUser returns a user associated with the request which is
// expected to have a Bearer token.
//
// Both scopes and clientIDs must have at least one element.
//
// The scope is chosen by CurrentBearerTokenScope and the user is then looked
// up through the context's Authenticator. Returns an error if the client did
// not make a valid request, or none of clientIDs are allowed to make
// requests, or user did not authorize any of the scopes.
func CurrentBearerTokenUser(c context.Context, scopes []string, clientIDs []string) (*user.User, error) {
	var (
		auth  = authenticator(c)
		scope string
		err   error
	)
	if auth == nil {
		return nil, errNoAuthenticator
	}
	if scope, err = CurrentBearerTokenScope(c, scopes, clientIDs); err != nil {
		return nil, err
	}
	return auth.CurrentOAuthUser(c, scope)
}
// CurrentUser checks for both JWT and Bearer tokens.
//
// When the only requested scope is EmailScope and clientIDs is non-empty, it
// first tries to decode and verify the Authorization token as a JWT (ID
// token); on any failure there it falls back to Bearer token handling via
// CurrentBearerTokenUser.
//
// The returned user will have only ID, Email and ClientID fields set.
// User.ID is a Google Account ID, which is different from GAE user ID.
// For more info on User.ID see 'sub' claim description on
// https://developers.google.com/identity/protocols/OpenIDConnect#obtainuserinfo
func CurrentUser(c context.Context, scopes []string, audiences []string, clientIDs []string) (*user.User, error) {
	// Without any scope, audience or client ID there is nothing to check
	// either kind of token against.
	nothingToCheck := len(scopes) == 0 && len(audiences) == 0 && len(clientIDs) == 0
	if nothingToCheck {
		return nil, errors.New("no client ID or scope info provided.")
	}

	req := HTTPRequest(c)
	if req == nil {
		return nil, errNoRequest
	}
	token := parseToken(req)
	if token == "" {
		return nil, errors.New("No token in the current context.")
	}

	// Only an email-only scope request is treated as a possible ID token.
	// An alternative would be checking for the "ya29." or "1/" prefixes that
	// mark Bearer tokens, which is what the Java implementation does.
	onlyEmailScope := len(scopes) == 1 && scopes[0] == EmailScope
	if onlyEmailScope && len(clientIDs) > 0 {
		log.Debugf(c, "Checking for ID token.")
		if u, err := currentIDTokenUser(c, token, audiences, clientIDs, currentUTC().Unix()); err == nil {
			return u, nil
		}
		// Not a usable ID token: fall through and try it as a Bearer token.
	}

	log.Debugf(c, "Checking for Bearer token.")
	return CurrentBearerTokenUser(c, scopes, clientIDs)
}
func init() {
if appengine.IsDevAppServer() {
AuthenticatorFactory = tokeninfoAuthenticatorFactory
} else {
AuthenticatorFactory = cachingAuthenticatorFactory
}
}