blob: eb95f678e5750e6c77a90e7f1ad899bb23f1228a [file] [log] [blame]
// Copyright 2015 The LUCI Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package authdbimpl
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"strings"
"time"
ds "go.chromium.org/luci/gae/service/datastore"
"go.chromium.org/luci/common/clock"
"go.chromium.org/luci/common/errors"
"go.chromium.org/luci/common/logging"
"go.chromium.org/luci/common/retry/transient"
"go.chromium.org/luci/server/auth/service"
"go.chromium.org/luci/server/auth/service/protocol"
)
// maxShardSize is an upper bound, in bytes, on the size of a blob stored in a
// single datastore entity: 1020 KiB (1020 << 10). AuthDB blobs that are not
// stored inline are split into SnapshotShard entities of at most this size.
const maxShardSize = 1020 << 10
// SnapshotInfo identifies some concrete AuthDB snapshot.
//
// Singleton entity. Serves as a pointer to a blob with corresponding AuthDB
// proto message (stored in a separate Snapshot entity whose ID is derived by
// GetSnapshotID).
type SnapshotInfo struct {
// AuthServiceURL is the URL of the auth service the AuthDB is fetched from.
// It must not contain ',' (GetSnapshotID panics otherwise).
AuthServiceURL string `gae:",noindex"`
// Rev is the AuthDB revision this pointer currently refers to.
Rev int64 `gae:",noindex"`
// _kind sets the datastore kind name of this entity.
_kind string `gae:"$kind,gaeauth.SnapshotInfo"`
// _id pins the entity ID to 1, which is what makes this entity a singleton.
_id int64 `gae:"$id,1"`
}
// GetSnapshotID returns the datastore ID of the Snapshot entity this info
// points to, formatted as "v1,<AuthServiceURL>,<Rev>".
//
// Panics if AuthServiceURL contains ',', because ',' is the field separator
// in the returned ID.
func (si *SnapshotInfo) GetSnapshotID() string {
	url := si.AuthServiceURL
	if strings.ContainsRune(url, ',') {
		panic(fmt.Errorf("forbidden symbol ',' in URL %q", url))
	}
	return strings.Join([]string{"v1", url, fmt.Sprint(si.Rev)}, ",")
}
// Snapshot is a serialized, deflated AuthDB blob with some minimal metadata.
//
// Root entity. Immutable. Key has the form "v1,<AuthServiceURL>,<Revision>",
// it's generated by SnapshotInfo.GetSnapshotID(). It is a globally unique
// version identifier, since it includes the URL of an auth service.
// AuthServiceURL should not be very long (~< 250 chars) for this to work.
//
// Currently does not get garbage collected.
type Snapshot struct {
// ID is the entity key, "v1,<AuthServiceURL>,<Revision>".
ID string `gae:"$id"`
// AuthDBDeflated is zlib-compressed serialized AuthDB protobuf message.
//
// If it is too big, it is stored in a bunch of SnapshotShard entities
// referenced by ShardIDs field below.
//
// Note: if the old version of this code tries to load a new Snapshot entity
// with ShardIDs field populated, it would abort with an error because old
// code doesn't know about ShardIDs field (it is not in the old Snapshot
// entity struct). This is desirable: the new sharded data structure is not
// (and can't be made) compatible with old code, so it is good that it breaks
// as soon as possible.
AuthDBDeflated []byte `gae:",noindex"`
// ShardIDs is a list of IDs of SnapshotShard entities to fetch. Their
// contents, concatenated in this order, form the deflated AuthDB.
ShardIDs []string `gae:",noindex"`
CreatedAt time.Time // when it was created on Auth service (stored in UTC)
FetchedAt time.Time // when it was fetched and put into the datastore (stored in UTC)
// _kind sets the datastore kind name of this entity.
_kind string `gae:"$kind,gaeauth.Snapshot"`
}
// SnapshotShard holds a shard of a deflated AuthDB.
//
// Shards are written by shardAuthDB and read back by unshardAuthDB; the
// owning Snapshot lists their IDs in its ShardIDs field.
type SnapshotShard struct {
// ID is "<Snapshot ID>:<shard hash>", where the hash is the hex-encoded
// SHA-256 of Shard.
ID string `gae:"$id"`
// Shard is the actual data: a contiguous piece of the deflated AuthDB blob.
Shard []byte `gae:",noindex"`
// _kind sets the datastore kind name of this entity.
_kind string `gae:"$kind,gaeauth.SnapshotShard"`
}
// GetLatestSnapshotInfo loads the SnapshotInfo singleton, using the namespace
// selected by defaultNS and bypassing any transaction present in ctx.
//
// Returns (nil, nil) when the entity does not exist. Any other datastore
// error is returned tagged as transient.
func GetLatestSnapshotInfo(ctx context.Context) (*SnapshotInfo, error) {
	report := durationReporter(ctx, latestSnapshotInfoDuration)
	logging.Debugf(ctx, "Fetching AuthDB snapshot info from the datastore")

	info := &SnapshotInfo{}
	err := ds.Get(ds.WithoutTransaction(defaultNS(ctx)), info)
	if err == ds.ErrNoSuchEntity {
		// Not configured yet: this is a normal state, not a failure.
		report("SUCCESS")
		return nil, nil
	}
	if err != nil {
		report("ERROR_TRANSIENT")
		return nil, transient.Tag.Apply(err)
	}
	report("SUCCESS")
	return info, nil
}
// deleteSnapshotInfo removes the SnapshotInfo singleton from the datastore.
//
// Used to detach the service from auth_service. The delete runs outside of
// any transaction and in the namespace selected by defaultNS, which is the
// same namespace GetLatestSnapshotInfo reads from. This way it always targets
// the same entity, whatever namespace the caller's context carries.
func deleteSnapshotInfo(ctx context.Context) error {
	ctx = ds.WithoutTransaction(defaultNS(ctx))
	return ds.Delete(ctx, ds.KeyForObj(ctx, &SnapshotInfo{}))
}
// GetAuthDBSnapshot loads Snapshot entity `id` (reassembling it from shards
// when needed), inflates the blob and deserializes it into an AuthDB message.
//
// The outcome is passed to the reporter obtained from
// durationReporter(ctx, getSnapshotDuration): "SUCCESS", "ERROR_INFLATION",
// or the status code returned by fetchDeflated.
func GetAuthDBSnapshot(ctx context.Context, id string) (*protocol.AuthDB, error) {
	report := durationReporter(ctx, getSnapshotDuration)
	logging.Debugf(ctx, "Fetching AuthDB snapshot from the datastore")
	// Deferred, so this line is logged on every return path, errors included.
	defer logging.Debugf(ctx, "AuthDB snapshot fetched")

	deflated, status, fetchErr := fetchDeflated(ctx, id)
	if fetchErr != nil {
		report(status)
		return nil, fetchErr
	}

	authDB, inflateErr := service.InflateAuthDB(deflated)
	if inflateErr != nil {
		report("ERROR_INFLATION")
		return nil, inflateErr
	}

	report("SUCCESS")
	return authDB, nil
}
// fetchDeflated loads the deflated AuthDB blob of Snapshot `id`, using the
// namespace selected by defaultNS and bypassing any transaction in ctx. If the
// snapshot was stored sharded, the blob is reassembled from its
// SnapshotShard entities.
//
// `code` is a short status string for reporting: "SUCCESS",
// "ERROR_NO_SNAPSHOT", "ERROR_TRANSIENT", "ERROR_SHARDS_TRANSIENT" or
// "ERROR_SHARDS_MISSING".
//
// See also storeDeflated.
func fetchDeflated(ctx context.Context, id string) (blob []byte, code string, err error) {
	ctx = ds.WithoutTransaction(defaultNS(ctx))

	snap := Snapshot{ID: id}
	if err = ds.Get(ctx, &snap); err != nil {
		if err == ds.ErrNoSuchEntity {
			return nil, "ERROR_NO_SNAPSHOT", err // not transient
		}
		return nil, "ERROR_TRANSIENT", transient.Tag.Apply(err)
	}

	// Stored inline: nothing to reassemble.
	if len(snap.ShardIDs) == 0 {
		return snap.AuthDBDeflated, "SUCCESS", nil
	}

	logging.Infof(ctx, "Reconstructing from %d shards", len(snap.ShardIDs))
	blob, err = unshardAuthDB(ctx, snap.ShardIDs)
	if transient.Tag.In(err) {
		return nil, "ERROR_SHARDS_TRANSIENT", err
	}
	if err != nil {
		// We apply the transient tag here to return Internal code
		// instead of Unauthenticated code. The Unauthenticated code
		// is misleading when we encountered an error in unshardAuthDB.
		// https://source.chromium.org/chromium/infra/infra/+/main:go/src/go.chromium.org/luci/server/auth/auth.go;l=272
		return nil, "ERROR_SHARDS_MISSING", transient.Tag.Apply(err)
	}
	return blob, "SUCCESS", nil
}
// ConfigureAuthService makes initial fetch of AuthDB snapshot from the auth
// service and sets up PubSub subscription.
//
// `baseURL` is root URL of currently running service, will be used to derive
// PubSub push endpoint URL.
//
// If `authServiceURL` is blank, disables the fetching: unsubscribes from the
// previously configured service (if any) and deletes the SnapshotInfo entity.
//
// Otherwise the latest snapshot is fetched and stored, the PubSub
// subscription is set up, and only then the SnapshotInfo pointer is switched
// to `authServiceURL`. After that, the subscription of a previously configured
// (different) auth service is removed.
//
// Returns an error, without touching any state, if `authServiceURL` contains
// ',' (such URLs can't be encoded into Snapshot IDs).
func ConfigureAuthService(ctx context.Context, baseURL, authServiceURL string) error {
	logging.Infof(ctx, "Reconfiguring AuthDB to be fetched from %q", authServiceURL)

	// ',' is the field separator in Snapshot IDs, and SnapshotInfo.GetSnapshotID
	// panics on URLs containing it. Reject such input with a regular error
	// before any datastore or PubSub state is read or modified.
	if strings.Contains(authServiceURL, ",") {
		return fmt.Errorf("forbidden symbol ',' in auth service URL %q", authServiceURL)
	}

	ctx = defaultNS(ctx)

	// If switching auth services, need to grab URL of a currently configured
	// auth service to unsubscribe from its PubSub stream.
	prevAuthServiceURL := ""
	switch existing, err := GetLatestSnapshotInfo(ctx); {
	case err != nil:
		return err
	case existing != nil:
		prevAuthServiceURL = existing.AuthServiceURL
	}

	// Stopping synchronization completely?
	if authServiceURL == "" {
		if prevAuthServiceURL != "" {
			if err := killPubSub(ctx, prevAuthServiceURL); err != nil {
				return err
			}
		}
		return deleteSnapshotInfo(ctx)
	}

	// Fetch latest AuthDB snapshot and store it in the datastore, thus verifying
	// authServiceURL works end-to-end.
	srv := getAuthService(ctx, authServiceURL)
	latestRev, err := srv.GetLatestSnapshotRevision(ctx)
	if err != nil {
		return err
	}
	info := &SnapshotInfo{
		AuthServiceURL: authServiceURL,
		Rev:            latestRev,
	}
	if err := fetchSnapshot(ctx, info); err != nil {
		logging.Errorf(ctx, "Failed to fetch latest snapshot from %s - %s", authServiceURL, err)
		return err
	}

	// Configure PubSub subscription to receive future updates.
	if err := setupPubSub(ctx, baseURL, authServiceURL); err != nil {
		logging.Errorf(ctx, "Failed to configure pubsub subscription - %s", err)
		return err
	}

	// All is configured. Switch SnapshotInfo entity to point to new snapshot.
	// It makes syncAuthDB fetch changes from `authServiceURL`, thus promoting
	// `authServiceURL` to the status of main auth service.
	if err := ds.Put(ds.WithoutTransaction(ctx), info); err != nil {
		return transient.Tag.Apply(err)
	}

	// Stop getting notifications from previously used auth service.
	if prevAuthServiceURL != "" && prevAuthServiceURL != authServiceURL {
		return killPubSub(ctx, prevAuthServiceURL)
	}
	return nil
}
// fetchSnapshot downloads the AuthDB snapshot identified by `info` from its
// auth service, deflates it, and stores it in the datastore under
// info.GetSnapshotID().
//
// Idempotent. It never reads or writes the SnapshotInfo entity itself, and
// thus is always safe to call.
func fetchSnapshot(ctx context.Context, info *SnapshotInfo) error {
	snap, err := getAuthService(ctx, info.AuthServiceURL).GetSnapshot(ctx, info.Rev)
	if err != nil {
		return err
	}

	deflated, err := service.DeflateAuthDB(snap.AuthDB)
	if err != nil {
		return err
	}

	err = storeDeflated(ctx, info.GetSnapshotID(), deflated, snap.Created, maxShardSize)
	if err != nil {
		return err
	}

	// How far behind the auth service this snapshot was when we stored it.
	logging.Infof(ctx, "Lag: %s", clock.Now(ctx).Sub(snap.Created))
	return nil
}
// storeDeflated stores a deflated AuthDB blob as Snapshot entity `id`, using
// the namespace selected by defaultNS and bypassing any transaction in ctx,
// perhaps splitting the blob into shards.
//
// `created` is recorded as CreatedAt and the current time as FetchedAt (both
// converted to UTC). A blob of at most maxShardSize bytes is stored inline in
// AuthDBDeflated. A larger one is first written as SnapshotShard entities via
// shardAuthDB, and the Snapshot then only records their IDs in ShardIDs.
// Errors from the final Put are tagged as transient.
//
// See also fetchDeflated.
func storeDeflated(ctx context.Context, id string, blob []byte, created time.Time, maxShardSize int) error {
	ctx = ds.WithoutTransaction(defaultNS(ctx))
	snapshot := Snapshot{
		ID:        id,
		CreatedAt: created.UTC(),
		FetchedAt: clock.Now(ctx).UTC(),
	}

	// If we are able to store AuthDB inline in the Snapshot, do it. That way
	// older versions of this code can still successfully read it. A blob of
	// exactly maxShardSize bytes fits as well: shards themselves are allowed to
	// be that large, and an inline Snapshot carries no more overhead than a
	// shard entity. If it doesn't fit, there's nothing we can do other than to
	// store it separately in shards. The old code will see the unrecognized
	// ShardIDs field and will fail.
	if len(blob) <= maxShardSize {
		snapshot.AuthDBDeflated = blob
	} else {
		var err error
		if snapshot.ShardIDs, err = shardAuthDB(ctx, id, blob, maxShardSize); err != nil {
			return err
		}
		logging.Infof(ctx, "Split into %d shards", len(snapshot.ShardIDs))
	}
	return transient.Tag.Apply(ds.Put(ctx, &snapshot))
}
// syncAuthDB fetches latest AuthDB snapshot from the configured auth service,
// puts it into the datastore and updates SnapshotInfo entity to point to it.
//
// Expects authenticating transport to be in the context. Called when receiving
// PubSub notifications.
//
// Returns the SnapshotInfo that is current once the sync is done. This is the
// freshly stored pointer, or, if another writer got there first, whatever
// newer or different pointer is in the datastore. If the SnapshotInfo entity
// was deleted while the snapshot was being fetched, an otherwise empty
// SnapshotInfo is returned.
func syncAuthDB(ctx context.Context) (*SnapshotInfo, error) {
	report := durationReporter(ctx, syncAuthDBDuration)

	// `info` is what we have in the datastore now.
	info, err := GetLatestSnapshotInfo(ctx)
	if err != nil {
		report("ERROR_GET_LATEST_INFO")
		return nil, err
	}
	if info == nil {
		report("ERROR_NOT_CONFIGURED")
		return nil, errors.New("auth_service URL is not configured")
	}

	// Grab revision number of the latest snapshot on the server.
	srv := getAuthService(ctx, info.AuthServiceURL)
	latestRev, err := srv.GetLatestSnapshotRevision(ctx)
	if err != nil {
		report("ERROR_GET_LATEST_REVISION")
		return nil, err
	}

	// Nothing new?
	if info.Rev == latestRev {
		logging.Infof(ctx, "AuthDB is up-to-date at revision %d", latestRev)
		report("SUCCESS_UP_TO_DATE")
		return info, nil
	}

	// Auth service traveled back in time?
	if info.Rev > latestRev {
		logging.Errorf(
			ctx, "Latest AuthDB revision on server is %d, we have %d. It should not happen",
			latestRev, info.Rev)
		report("SUCCESS_NEWER_ALREADY")
		return info, nil
	}

	// Fetch the actual snapshot from the server and put it into the datastore.
	info.Rev = latestRev
	if err = fetchSnapshot(ctx, info); err != nil {
		logging.Errorf(ctx, "Failed to fetch snapshot %d from %q - %s", info.Rev, info.AuthServiceURL, err)
		report("ERROR_FETCHING")
		return nil, err
	}

	// Move pointer to the latest snapshot only if it is more recent than what is
	// already in the datastore.
	//
	// The transaction must use the same namespace GetLatestSnapshotInfo read
	// from (defaultNS). Otherwise a caller context carrying another namespace
	// would make it miss the SnapshotInfo entity and never advance the pointer.
	var latest *SnapshotInfo
	err = ds.RunInTransaction(ds.WithoutTransaction(defaultNS(ctx)), func(ctx context.Context) error {
		latest = &SnapshotInfo{}
		switch err := ds.Get(ctx, latest); {
		case err == ds.ErrNoSuchEntity:
			logging.Warningf(ctx, "No longer need to fetch AuthDB, not configured anymore")
			return nil
		case err != nil:
			return err
		case latest.AuthServiceURL != info.AuthServiceURL:
			logging.Warningf(
				ctx, "No longer need to fetch AuthDB from %q, %q is primary now",
				info.AuthServiceURL, latest.AuthServiceURL)
			return nil
		case latest.Rev >= info.Rev:
			logging.Warningf(ctx, "Already have rev %d", info.Rev)
			return nil
		}
		latest = info
		return ds.Put(ctx, info)
	}, nil)
	if err != nil {
		report("ERROR_COMMITTING")
		return nil, transient.Tag.Apply(err)
	}

	// NOTE: also reported when the transaction decided not to write anything.
	report("SUCCESS_UPDATED")
	return latest, nil
}
// shardAuthDB splits blob into consecutive chunks of at most maxSize bytes and
// stores each chunk as a SnapshotShard entity with ID
// "<id>:<hex-encoded SHA-256 of the chunk>".
//
// Returns the shard IDs in blob order; concatenating the shards in that order
// restores blob (see unshardAuthDB). An empty blob produces no shards.
// Datastore errors are tagged as transient. A non-positive maxSize with a
// non-empty blob is rejected with an error, because the loop could otherwise
// never make progress.
func shardAuthDB(ctx context.Context, id string, blob []byte, maxSize int) ([]string, error) {
	if maxSize <= 0 && len(blob) != 0 {
		return nil, fmt.Errorf("shard size must be positive, got %d", maxSize)
	}

	var ids []string
	for len(blob) != 0 {
		shardSize := maxSize
		if shardSize > len(blob) {
			shardSize = len(blob)
		}
		var shard []byte
		shard, blob = blob[:shardSize], blob[shardSize:]

		digest := sha256.Sum256(shard)
		shardID := fmt.Sprintf("%s:%s", id, hex.EncodeToString(digest[:]))
		ids = append(ids, shardID)

		// Store shards sequentially to avoid allocating RAM to store full `blob` in
		// RPC buffers. There's no requirement for this code to be performant, it
		// executes in a background job.
		if err := ds.Put(ctx, &SnapshotShard{ID: shardID, Shard: shard}); err != nil {
			return nil, transient.Tag.Apply(err)
		}
	}
	return ids, nil
}
// unshardAuthDB loads the SnapshotShard entities listed in shardIDs (with one
// ds.Get call over the whole slice) and concatenates their contents in the
// given order.
//
// If any shard is missing, the error is returned untagged (fatal). Any other
// failure is tagged as transient.
func unshardAuthDB(ctx context.Context, shardIDs []string) ([]byte, error) {
	shards := make([]SnapshotShard, len(shardIDs))
	for i := range shards {
		shards[i].ID = shardIDs[i]
	}

	if err := ds.Get(ctx, shards); err != nil {
		// Per-entity failures come back as a MultiError. A missing shard won't
		// appear by retrying, so treat that case as fatal.
		if merr, ok := err.(errors.MultiError); ok {
			for _, inner := range merr {
				if inner == ds.ErrNoSuchEntity {
					return nil, err // fatal
				}
			}
		}
		// Either other per-entity failures, or an overall RPC error.
		return nil, transient.Tag.Apply(err)
	}

	chunks := make([][]byte, 0, len(shards))
	for _, s := range shards {
		chunks = append(chunks, s.Shard)
	}
	return bytes.Join(chunks, nil), nil
}