blob: d707080e5dc1994926e818ec0a8692a3b2828fa8 [file] [log] [blame]
// Copyright 2016 The LUCI Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package auth
import (
"bytes"
"crypto/sha1"
"encoding/base64"
"encoding/gob"
"fmt"
"time"
"golang.org/x/net/context"
"github.com/luci/luci-go/common/clock"
"github.com/luci/luci-go/common/data/caching/lru"
"github.com/luci/luci-go/common/data/rand/mathrand"
"github.com/luci/luci-go/common/logging"
"github.com/luci/luci-go/common/sync/mutexpool"
)
// Cache implements a strongly consistent cache.
type Cache interface {
// Get returns a cached item or (nil, nil) if it's not in the cache.
//
// Any returned error is transient error.
Get(c context.Context, key string) ([]byte, error)
// Set unconditionally overwrites an item in the cache.
//
// If 'exp' is zero, the item will have no expiration time.
//
// Any returned error is transient error.
Set(c context.Context, key string, value []byte, exp time.Duration) error
// WithLocalMutex calls 'f' under a process-local mutex assigned to the 'key'.
//
// It is used to ensure only one goroutine is updating the item matching the
// given key.
WithLocalMutex(c context.Context, key string, f func())
}
// tokenCache knows how to keep tokens in a cache represented by Cache.
//
// It implements probabilistic early expiration to workaround cache stampede
// problem for hot items, see Fetch.
type tokenCache struct {
// Kind defines the token kind. Will be used as part of the cache key.
Kind string
// Version defines format of the data. Will be used as part of the cache key.
//
// If you change a type behind interface{} in Token field, you MUST bump the
// version. It will also "invalidate" all existing cached entries (they will
// just become inaccessible and eventually will be evicted from the cache).
Version int
// ExpRandPercent defines when expiration randomizations kicks in.
//
// For example, value of '10' means when token is within 10% of its
// expiration it may randomly (on each access) be considered already expired.
ExpRandPercent int
// MinAcceptedLifetime defines minimal lifetime of a token.
//
// If a token lives less than given time, Store(...) will refuse to save it.
MinAcceptedLifetime time.Duration
}
// cachedToken is stored in a tokenCache.
type cachedToken struct {
Key string // cache key, must be unique (no other restrictions)
Token interface{} // the actual token, must be Gob-serializable
Created time.Time // when it was created, required, UTC
Expiry time.Time // its expiration time, required, UTC
}
// Store unconditionally puts the token in the cache.
func (tc *tokenCache) Store(c context.Context, cache Cache, tok *cachedToken) error {
switch {
case tok.Created.IsZero() || tok.Created.Location() != time.UTC:
panic(fmt.Errorf("tok.Created is not a valid UTC time - %v", tok.Created))
case tok.Expiry.IsZero() || tok.Expiry.Location() != time.UTC:
panic(fmt.Errorf("tok.Expiry is not a valid UTC time - %v", tok.Expiry))
}
ttl := tok.Expiry.Sub(tok.Created)
if ttl < tc.MinAcceptedLifetime {
return fmt.Errorf("refusing to store a token that expires in %s", ttl)
}
blob, err := tc.marshal(tok)
if err != nil {
return err
}
return cache.Set(c, tc.itemKey(tok.Key), blob, ttl)
}
// Fetch grabs cached token if it hasn't expired yet.
//
// Returns (nil, nil) if no such token, it has expired already or its lifetime
// is shorter than 'minTTL'.
//
// If the cached token is close to expiration cutoff (defined by minTTL), this
// function will randomly return a cache miss. That way if multiple concurrent
// processes all constantly use the same token, only the most unlucky one will
// refresh it.
func (tc *tokenCache) Fetch(c context.Context, cache Cache, key string, minTTL time.Duration) (*cachedToken, error) {
blob, err := cache.Get(c, tc.itemKey(key))
switch {
case err != nil:
return nil, err
case blob == nil:
return nil, nil
}
tok, err := tc.unmarshal(blob)
if err != nil {
return nil, err
}
if tok.Key != key {
return nil, fmt.Errorf("SHA1 collision in the token cache: %q vs %q", key, tok.Key)
}
// Shift time by minTTL to simplify the math below.
now := clock.Now(c).Add(minTTL).UTC()
// Don't use expiration randomization if the item is far from expiration.
exp := tok.Expiry.Sub(tok.Created) * time.Duration(tc.ExpRandPercent) / 100
switch {
case now.After(tok.Expiry):
return nil, nil // already passed requested minTTL
case now.Add(exp).Before(tok.Expiry):
return tok, nil // far from requested minTTL
}
// The expiration is close enough. Do the randomization.
// TODO(vadimsh): The distribution was picked randomly. Maybe use exponential
// one, as some literature suggests?
rnd := time.Duration(mathrand.Int63n(c, int64(exp)))
if now.Add(rnd).After(tok.Expiry) {
return nil, nil
}
return tok, nil
}
// WithLocalMutex calls Cache.WithLocalMutex with correct cache key.
func (tc *tokenCache) WithLocalMutex(c context.Context, cache Cache, key string, f func()) {
cache.WithLocalMutex(c, tc.itemKey(key), f)
}
// itemKey derives the short key to use in the underlying cache.
func (tc *tokenCache) itemKey(key string) string {
digest := sha1.Sum([]byte(key))
asStr := base64.RawURLEncoding.EncodeToString(digest[:])
return fmt.Sprintf("%s/%d/%s", tc.Kind, tc.Version, asStr)
}
// marshal converts cachedToken struct to byte blob.
func (tc *tokenCache) marshal(t *cachedToken) ([]byte, error) {
buf := bytes.Buffer{}
enc := gob.NewEncoder(&buf)
if err := enc.Encode(t); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// unmarshal is reverse of marshal.
func (tc *tokenCache) unmarshal(blob []byte) (*cachedToken, error) {
dec := gob.NewDecoder(bytes.NewReader(blob))
out := &cachedToken{}
if err := dec.Decode(out); err != nil {
return nil, err
}
return out, nil
}
////////////////////////////////////////////////////////////////////////////////
type fetchOrMintTokenOp struct {
Config *Config
Cache *tokenCache
CacheKey string
MinTTL time.Duration
Mint func(context.Context) (tok *cachedToken, err error, label string)
}
// fetchOrMintToken implements high level logic of using token cache.
//
// It's basis or MintAccessTokenForServiceAccount and MintDelegationToken
// implementations.
func fetchOrMintToken(ctx context.Context, op *fetchOrMintTokenOp) (*cachedToken, error, string) {
// Attempt to grab a cached token before entering the critical section.
switch cached, err := op.Cache.Fetch(ctx, op.Config.Cache, op.CacheKey, op.MinTTL); {
case err != nil:
logging.WithError(err).Errorf(ctx, "Failed to fetch the token from cache")
return nil, err, "ERROR_CACHE"
case cached != nil:
return cached, nil, "SUCCESS_CACHE_HIT"
}
logging.Debugf(ctx, "The requested token is not in the cache")
var res struct {
tok *cachedToken
err error
label string
}
op.Cache.WithLocalMutex(ctx, op.Config.Cache, op.CacheKey, func() {
// Recheck the cache now that we have the lock, maybe someone updated the
// cache while we were waiting.
switch cached, err := op.Cache.Fetch(ctx, op.Config.Cache, op.CacheKey, op.MinTTL); {
case err != nil:
logging.WithError(err).Errorf(ctx, "Failed to fetch the token from cache (under the lock)")
res.err = err
res.label = "ERROR_CACHE"
return
case cached != nil:
logging.Debugf(ctx, "Someone put the token in the cache already")
res.tok = cached
res.label = "SUCCESS_CACHE_HIT"
return
}
// Minting a new token involves RPCs to remote services that should be fast.
// Abort the attempt if it gets stuck for longer than 10 sec, it's unlikely
// it'll succeed. Note that we setup the new context only on slow code path
// (on cache miss), since it involves some overhead we don't want to pay on
// the fast path. We assume memcache RPCs don't get stuck for a long time
// (unlike URL Fetch calls to GAE).
logging.Debugf(ctx, "Minting the new token")
ctx, cancel := clock.WithTimeout(ctx, op.Config.adjustedTimeout(10*time.Second))
defer func() {
if res.err != nil {
logging.WithError(res.err).Errorf(ctx, "Failed to mint new token")
}
cancel()
}()
if res.tok, res.err, res.label = op.Mint(ctx); res.err != nil {
res.tok = nil
if res.label == "" {
res.label = "ERROR_UNSPECIFIED"
}
return
}
res.tok.Key = op.CacheKey
// Cache the token. Ignore errors, it's not big deal, we have the token.
if err := op.Cache.Store(ctx, op.Config.Cache, res.tok); err != nil {
logging.WithError(err).Warningf(ctx, "Failed to store the token in the cache")
}
res.label = "SUCCESS_CACHE_MISS"
})
return res.tok, res.err, res.label
}
////////////////////////////////////////////////////////////////////////////////
type memoryCache struct {
cache *lru.Cache
locks mutexpool.P
}
// MemoryCache creates a new in-memory Cache instance that is built on
// top of an LRU cache of the specified size.
func MemoryCache(size int) Cache {
return &memoryCache{
cache: lru.New(size),
}
}
func (mc *memoryCache) Get(c context.Context, key string) ([]byte, error) {
var item *memoryCacheItem
now := clock.Now(c)
_ = mc.cache.Mutate(key, func(cur interface{}) interface{} {
if cur == nil {
return nil
}
item = cur.(*memoryCacheItem)
if now.After(item.exp) {
// Cache item is too old, so expire it.
item = nil
}
return item
})
if item == nil {
// Cache miss (or expired).
return nil, nil
}
return item.value, nil
}
func (mc *memoryCache) Set(c context.Context, key string, value []byte, exp time.Duration) error {
mc.cache.Put(key, &memoryCacheItem{
value: value,
exp: clock.Now(c).Add(exp),
})
return nil
}
func (mc *memoryCache) WithLocalMutex(c context.Context, key string, f func()) {
mc.locks.WithMutex(key, f)
}
type memoryCacheItem struct {
value []byte
exp time.Time
}