blob: ff3026348dcfcda2ba2ff89d212be7686e08032a [file] [log] [blame]
// Copyright 2018 The Goma Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package auth
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"golang.org/x/oauth2"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"go.chromium.org/goma/server/rpc"
)
// expiryDelta is how much earlier than its actual expiration time
// a token is considered expired. It is the same value as used in
// https://github.com/golang/oauth2/blob/master/token.go
const expiryDelta = 10 * time.Second

// expiryTime returns the time at which a token whose actual expiration
// time is t should be treated as expired, i.e. expiryDelta before t.
// Any monotonic clock reading in t is stripped first.
func expiryTime(t time.Time) time.Time {
	wallClock := t.Round(0)
	return wallClock.Add(-expiryDelta)
}
// TokenInfo represents information about an access token, as reported
// by the tokeninfo endpoint.
type TokenInfo struct {
// Email is the email address associated with the access token
// (the "email" field of the tokeninfo response).
Email string
// Audience is the OAuth2 client_id of the access token
// (the "aud" field of the tokeninfo response).
Audience string
// ExpiresAt is the expiry timestamp of the access token, parsed from
// the "exp" field (Unix seconds). parseResp leaves it as the zero
// time when Err is set.
ExpiresAt time.Time
// Err represents an error of the access token. parseResp sets it when
// the tokeninfo response carries "error"/"error_description", or when
// the "exp" field can not be parsed.
Err error
}
// parseToken extracts an OAuth2 bearer token from the value of an
// authorization header.
//
// auth must start with "Bearer " (case sensitive); otherwise an error
// is returned. The remainder of the header, with surrounding white space
// removed, becomes the access token of the returned token, whose type is
// always "Bearer".
func parseToken(auth string) (*oauth2.Token, error) {
	const bearerPrefix = "Bearer "
	if !strings.HasPrefix(auth, bearerPrefix) {
		return nil, fmt.Errorf("not bearer authorization header: %q", auth)
	}
	tok := new(oauth2.Token)
	tok.TokenType = "Bearer"
	tok.AccessToken = strings.TrimSpace(auth[len(bearerPrefix):])
	return tok, nil
}
// tokenKey returns a string key for token, made of its token type and
// its access token separated by a single space.
func tokenKey(token *oauth2.Token) string {
	return token.TokenType + " " + token.AccessToken
}
// tokeninfoURL is the URL prefix of the OAuth2 tokeninfo endpoint.
// fetch appends the query-escaped access token to it to build
// the request URL.
const tokeninfoURL = "https://oauth2.googleapis.com/tokeninfo?access_token="
// trimAccessTokenInErr returns err with every occurrence of accessToken
// in its message truncated to the first third of the token, so that
// a full access token is not exposed through error messages.
//
// It returns nil for a nil err. It returns err itself, unmodified, when
// accessToken is empty or when the error message does not contain
// accessToken. Otherwise it returns a new plain error (the type of err
// is not preserved) whose message has the token truncated and "..."
// appended at the end.
func trimAccessTokenInErr(err error, accessToken string) error {
	if err == nil {
		return nil
	}
	// An empty string is contained in every message, so without this
	// guard every error would be rewritten (losing its type) and get
	// a spurious "..." suffix although there is nothing to redact.
	if accessToken == "" {
		return err
	}
	msg := err.Error()
	if !strings.Contains(msg, accessToken) {
		return err
	}
	truncated := accessToken[:len(accessToken)/3]
	return errors.New(strings.Replace(msg, accessToken, truncated, -1) + "...")
}
// fetch fetches access token info for token from the tokeninfo endpoint.
//
// The request is issued through rpc.Retry, driven by ctx. Transport
// failures are reported as codes.Unavailable; failures to read or parse
// the response body are reported as codes.Internal. The access token is
// truncated in request/transport error messages.
func fetch(ctx context.Context, token *oauth2.Token) (*TokenInfo, error) {
	accessToken := token.AccessToken
	baseReq, err := http.NewRequest("GET", tokeninfoURL+url.QueryEscape(accessToken), nil)
	if err != nil {
		return nil, trimAccessTokenInErr(err, accessToken)
	}

	var info *TokenInfo
	attempt := func() error {
		// Every attempt has its own 1 minute timeout, independent of
		// the incoming context. The incoming context only controls
		// the retry.
		reqCtx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
		defer cancel()

		resp, err := http.DefaultClient.Do(baseReq.WithContext(reqCtx))
		if err != nil {
			return status.Errorf(codes.Unavailable, "fetch error: %v", trimAccessTokenInErr(err, accessToken))
		}
		defer resp.Body.Close()

		body, err := ioutil.ReadAll(resp.Body)
		if err != nil {
			return status.Errorf(codes.Internal, "read tokeninfo response: code=%d %v", resp.StatusCode, err)
		}
		info, err = parseResp(body)
		if err != nil {
			return status.Errorf(codes.Internal, "parse tokeninfo response: code=%d %v", resp.StatusCode, err)
		}
		return nil
	}

	err = rpc.Retry{}.Do(ctx, attempt)
	return info, err
}
// parseResp parses a JSON tokeninfo response body into a TokenInfo.
//
// It returns an error when data can not be decoded as JSON, or when
// none of email, exp, error and error_description is present.
// Otherwise it returns a TokenInfo with a nil error; problems with the
// token itself are reported in TokenInfo.Err:
//   - codes.PermissionDenied when error or error_description is set
//     (exp is not parsed in that case, since it may be absent),
//   - codes.Internal when exp is not a valid base-10 integer.
func parseResp(data []byte) (*TokenInfo, error) {
	var resp struct {
		Email            string `json:"email"`
		Audience         string `json:"aud"`
		ExpiresAt        string `json:"exp"`
		Error            string `json:"error"`
		ErrorDescription string `json:"error_description"`
	}
	if err := json.NewDecoder(bytes.NewReader(data)).Decode(&resp); err != nil {
		return nil, err
	}

	hasError := resp.Error != "" || resp.ErrorDescription != ""
	if !hasError && resp.Email == "" && resp.ExpiresAt == "" {
		return nil, status.Errorf(codes.Internal, "unexpected token info: %v", resp)
	}

	info := &TokenInfo{Email: resp.Email, Audience: resp.Audience}
	if hasError {
		info.Err = status.Errorf(codes.PermissionDenied, "%s: %q", resp.Error, resp.ErrorDescription)
		return info, nil
	}

	// exp is only parsed when the response carries no error.
	sec, err := strconv.ParseInt(resp.ExpiresAt, 10, 64)
	if err != nil {
		info.Err = status.Errorf(codes.Internal, "parse exp: %v", err)
		return info, nil
	}
	info.ExpiresAt = time.Unix(sec, 0)
	return info, nil
}
// runAt blocks the calling goroutine until time t, then calls f.
// If t is not in the future, f is called without waiting.
func runAt(t time.Time, f func()) {
	wait := time.Until(t)
	time.Sleep(wait)
	f()
}