blob: b4f913c28956428860cc0ac674882f72199fcd55 [file] [log] [blame]
// Copyright 2024 The LUCI Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package model
// This file contains the functionality used to validate
// the Go implementation of Auth Service (v2) by comparing its generated
// entities to those generated by the Python implementation (v1).
//
// TODO: Remove dry run and comparison code once we have fully rolled
// out Auth Service v2 (b/321019030).
import (
"bytes"
"context"
"fmt"
"os"
"strings"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"golang.org/x/exp/slices"
"google.golang.org/protobuf/proto"
"go.chromium.org/luci/common/errors"
"go.chromium.org/luci/common/logging"
"go.chromium.org/luci/gae/service/datastore"
"go.chromium.org/luci/server/auth/service/protocol"
"go.chromium.org/luci/auth_service/impl/util/zlib"
)
// Names of environment variables which control component functionality
// while transitioning from Auth Service v1 (Python) to
// Auth Service v2 (Go).
//
// Each environment variable should be either "true" or "false" (case
// insensitive).
//
// NOTE(review): presumably the DRY_RUN_* names are meant to be read with
// ParseDryRunEnvVar (defaults to true) and the ENABLE_* name with
// ParseEnableEnvVar (defaults to false) - confirm against the call sites.
const (
// Dry run flag names.
DryRunAPIChangesEnvVar = "DRY_RUN_API_CHANGES"
DryRunCronConfigEnvVar = "DRY_RUN_CRON_CONFIG"
DryRunCronRealmsEnvVar = "DRY_RUN_CRON_REALMS"
DryRunCronStaleAuthEnvVar = "DRY_RUN_CRON_STALE_AUTH"
DryRunTQChangelogEnvVar = "DRY_RUN_TQ_CHANGELOG"
DryRunTQReplicationEnvVar = "DRY_RUN_TQ_REPLICATION"
// Enable flag name.
EnableGroupImportsEnvVar = "ENABLE_GROUP_IMPORTS"
)
// ParseDryRunEnvVar reports whether dry run mode is on according to the
// environment variable named envVar.
//
// Dry run is the default: only a value that lowercases to exactly "false"
// turns it off. An unset, empty or otherwise unrecognized value yields
// true.
func ParseDryRunEnvVar(envVar string) bool {
	return strings.ToLower(os.Getenv(envVar)) != "false"
}
// ParseEnableEnvVar reports whether the feature guarded by the environment
// variable named envVar is enabled.
//
// Disabled is the default: only a value that lowercases to exactly "true"
// enables it. An unset, empty or otherwise unrecognized value yields
// false.
func ParseEnableEnvVar(envVar string) bool {
	switch strings.ToLower(os.Getenv(envVar)) {
	case "true":
		return true
	default:
		return false
	}
}
// CompareV2Entities checks the entities written by Auth Service v2 (Go)
// against those written by Auth Service v1 (Python) for the latest AuthDB
// revision. The revision is taken from the latest snapshot fetched with
// GetAuthDBSnapshotLatest; snapshots are compared first, then changelogs.
//
// The outcome of each comparison is logged rather than returned. The
// returned error, if any, is annotated with the step that failed.
func CompareV2Entities(ctx context.Context) error {
	latest, err := GetAuthDBSnapshotLatest(ctx, false)
	if err != nil {
		return errors.Annotate(err, "error getting latest revision").Err()
	}
	rev := latest.AuthDBRev

	err = compareSnapshots(ctx, rev)
	if err != nil {
		return errors.Annotate(err, "error comparing snapshots").Err()
	}

	err = compareChangelogs(ctx, rev)
	if err != nil {
		return errors.Annotate(err, "error comparing changelogs").Err()
	}

	return nil
}
// compareChangelogs compares the AuthDBChange entities recorded by Auth
// Service v1 and v2 for the given AuthDB revision and logs the result.
//
// If either service has not yet processed the changelog for the revision,
// this is logged and nil is returned without comparing anything. A
// mismatch between the changelogs is logged (the summary at error level,
// each individual difference at debug level) but is not reported as an
// error; only failures to read from datastore are returned.
func compareChangelogs(ctx context.Context, authDBRev int64) error {
	// Check if both Auth Service v1 and v2 have processed the changes for
	// the given revision.
	v1LogRev, err := getAuthDBLogRev(ctx, authDBRev, false)
	if err != nil {
		return errors.Annotate(err, "error checking v1 changelog was processed").Err()
	}
	v2LogRev, err := getAuthDBLogRev(ctx, authDBRev, true)
	if err != nil {
		return errors.Annotate(err, "error checking v2 changelog was processed").Err()
	}
	if v1LogRev == nil || v2LogRev == nil {
		logging.Infof(ctx, "changelogs are not yet processed for Rev %d", authDBRev)
		return nil
	}

	// Get the AuthDBChanges created by Auth Service v1 and v2 for the
	// given revision.
	v1Changes, err := getChangesForRevision(ctx, authDBRev, false)
	if err != nil {
		return errors.Annotate(err, "error getting all v1 AuthDBChanges").Err()
	}
	v2Changes, err := getChangesForRevision(ctx, authDBRev, true)
	if err != nil {
		return errors.Annotate(err, "error getting all v2 AuthDBChanges").Err()
	}

	// Compare the changes.
	diffs := diffChangelogs(v1Changes, v2Changes)
	if len(diffs) == 0 {
		logging.Infof(ctx, "AuthDBChange entities for Rev %d are equivalent", authDBRev)
		return nil
	}

	logging.Errorf(ctx, "AuthDBChange entities for Rev %d do not match", authDBRev)
	// Log the differences for debugging. The diff text is data, not a
	// format string: it may contain '%' characters, so it must be passed
	// as an argument to avoid being mangled by the formatter.
	for _, diff := range diffs {
		logging.Debugf(ctx, "%s", diff)
	}
	return nil
}
// getChangesForRevision fetches every AuthDBChange entity stored under the
// log revision key built for authDBRev, ordered by descending key.
//
// dryRun is forwarded to both entityKind and constructLogRevisionKey, so
// it selects which set of entities (kind and ancestor) is queried.
func getChangesForRevision(ctx context.Context, authDBRev int64, dryRun bool) ([]*AuthDBChange, error) {
	kind := entityKind("AuthDBChange", dryRun)
	parent := constructLogRevisionKey(ctx, authDBRev, dryRun)
	q := datastore.NewQuery(kind).Ancestor(parent).Order("-__key__")

	var found []*AuthDBChange
	err := datastore.GetAll(ctx, q, &found)
	if err != nil {
		return nil, err
	}
	return found, nil
}
// diffChangelogs returns the "functional" differences between the given
// slices of AuthDBChanges, comparing them pairwise by index.
//
// Fields ignored include:
//   - Kind - there will be a V2 prefix; and
//   - Parent - there will be V2 prefixes.
//
// When the slices differ in length, the two lengths are reported and each
// trailing change of the longer slice is listed as missing. The result is
// empty (but non-nil) when there are no differences.
func diffChangelogs(changelogA, changelogB []*AuthDBChange) []string {
	skipFields := cmpopts.IgnoreFields(AuthDBChange{}, "Kind", "Parent")
	lenA, lenB := len(changelogA), len(changelogB)

	// Work out which changelog has the extra tail and how many leading
	// entries exist in both.
	longer, shared := changelogA, lenB
	if lenB > lenA {
		longer, shared = changelogB, lenA
	}

	result := []string{}
	for i := 0; i < shared; i++ {
		if d := cmp.Diff(changelogA[i], changelogB[i], skipFields); d != "" {
			result = append(result, d)
		}
	}
	if lenA == lenB {
		return result
	}

	// Record the missing changes.
	result = append(result, fmt.Sprintf("Total changes count: %d vs %d",
		lenA, lenB))
	for _, change := range longer[shared:] {
		result = append(result, fmt.Sprintf("missing AuthDBChange: %+v", change))
	}
	return result
}
// compareSnapshots compares the AuthDBSnapshot entities created by Auth
// Service v1 and v2 for the given AuthDB revision and logs the result.
//
// If either snapshot does not exist yet (datastore.ErrNoSuchEntity), this
// is logged and nil is returned without comparing anything. A mismatch
// between the snapshots is logged (the summary at error level, each
// individual difference at debug level) but is not reported as an error;
// only failures to fetch or decode the snapshots are returned.
func compareSnapshots(ctx context.Context, authDBRev int64) error {
	// Get the AuthDBSnapshots created by Auth Service v1 and v2 for the
	// given revision.
	v1Snapshot, err := GetAuthDBSnapshot(ctx, authDBRev, false, false)
	if err != nil {
		if errors.Is(err, datastore.ErrNoSuchEntity) {
			logging.Infof(ctx, "AuthDBSnapshot for Rev %d not yet created", authDBRev)
			return nil
		}
		return errors.Annotate(err, "failed to get v1 AuthDBSnapshot").Err()
	}
	v2Snapshot, err := GetAuthDBSnapshot(ctx, authDBRev, false, true)
	if err != nil {
		if errors.Is(err, datastore.ErrNoSuchEntity) {
			logging.Infof(ctx, "V2AuthDBSnapshot for Rev %d not yet created", authDBRev)
			return nil
		}
		return errors.Annotate(err, "failed to get v2 AuthDBSnapshot").Err()
	}

	// Compare the snapshots.
	diffs, err := diffSnapshots(v1Snapshot, v2Snapshot)
	if err != nil {
		return errors.Annotate(err, "error comparing snapshots").Err()
	}
	if len(diffs) == 0 {
		logging.Infof(ctx, "AuthDBSnapshots for Rev %d are equivalent", authDBRev)
		return nil
	}

	logging.Errorf(ctx, "AuthDBSnapshots for Rev %d do not match", authDBRev)
	// Log the differences for debugging. The diff text is data, not a
	// format string: it may contain '%' characters, so it must be passed
	// as an argument to avoid being mangled by the formatter.
	for _, diff := range diffs {
		logging.Debugf(ctx, "%s", diff)
	}
	return nil
}
// processSnapshot extracts the ReplicationPushRequest stored in the given
// AuthDBSnapshot by decompressing its AuthDBDeflated blob and unmarshalling
// the result.
//
// Returns an annotated error if either decompression or unmarshalling
// fails.
func processSnapshot(authDBSnapshot *AuthDBSnapshot) (*protocol.ReplicationPushRequest, error) {
	raw, err := zlib.Decompress(authDBSnapshot.AuthDBDeflated)
	if err != nil {
		return nil, errors.Annotate(err, "error decompressing AuthDBDeflated").Err()
	}

	pushReq := new(protocol.ReplicationPushRequest)
	err = proto.Unmarshal(raw, pushReq)
	if err != nil {
		return nil, errors.Annotate(err, "error unmarshalling AuthDB blob").Err()
	}
	return pushReq, nil
}
// diffSnapshots returns the "functional" differences between the two
// AuthDBSnapshots.
//
// AuthDBSnapshot fields ignored in this comparison include:
//   - Kind;
//   - ShardIDs;
//   - CreatedTS;
//   - AuthDBSha256 (expected differences due to timestamps and
//     AuthCodeVersion);
//   - the AuthCodeVersion within the compressed serialized
//     ReplicationPushRequest blob (this is expected to be "1.x.x" for
//     the Python version and "2.x.x" for Go).
//
// The compared values are the snapshot ID, the revision's AuthDbRev and
// PrimaryId, and the AuthDB proto (via diffAuthDBs, each entry prefixed
// with "AuthDb.").
//
// Returns an annotated error if either snapshot's blob cannot be
// decoded; in that case the differences collected so far are returned
// alongside it.
func diffSnapshots(v1Snapshot, v2Snapshot *AuthDBSnapshot) ([]string, error) {
	diffs := []string{}
	diffTemplate := "field '%s': '%v' vs '%v'"

	if v1Snapshot.ID != v2Snapshot.ID {
		diffs = append(diffs, fmt.Sprintf(diffTemplate, "ID",
			v1Snapshot.ID, v2Snapshot.ID))
	}

	// Get the ReplicationPushRequests from the snapshots for comparison.
	v1Req, err := processSnapshot(v1Snapshot)
	if err != nil {
		return diffs, errors.Annotate(err, "error processing v1 snapshot").Err()
	}
	v2Req, err := processSnapshot(v2Snapshot)
	if err != nil {
		return diffs, errors.Annotate(err, "error processing v2 snapshot").Err()
	}

	// Use the generated getters rather than direct field access: a blob
	// serialized without a revision unmarshals to a nil Revision, and the
	// getters return zero values in that case instead of panicking.
	v1Rev, v2Rev := v1Req.GetRevision(), v2Req.GetRevision()
	if v1Rev.GetAuthDbRev() != v2Rev.GetAuthDbRev() {
		diffs = append(diffs, fmt.Sprintf(diffTemplate, "Revision.AuthDBRev",
			v1Rev.GetAuthDbRev(), v2Rev.GetAuthDbRev()))
	}
	if v1Rev.GetPrimaryId() != v2Rev.GetPrimaryId() {
		diffs = append(diffs, fmt.Sprintf(diffTemplate, "Revision.PrimaryId",
			v1Rev.GetPrimaryId(), v2Rev.GetPrimaryId()))
	}

	// Compare AuthDB protos.
	for _, authDBDiff := range diffAuthDBs(v1Req.GetAuthDb(), v2Req.GetAuthDb()) {
		diffs = append(diffs, fmt.Sprintf("AuthDb.%s", authDBDiff))
	}

	return diffs, nil
}
// diffAuthDBs returns the differences between the given AuthDB protos.
//
// Scalar, bytes, message and non-group repeated fields that differ are
// reported by a label only. Groups are compared pairwise by index; each
// differing pair is reported with a textual diff. When the group counts
// differ, both counts are reported and the name of each trailing group in
// the longer list is listed as missing.
//
// All fields are read through the generated getters, so a nil AuthDB (for
// example, a request whose auth_db was never set) is treated as an empty
// message instead of causing a nil pointer dereference. The result is
// empty (but non-nil) when there are no differences.
func diffAuthDBs(a, b *protocol.AuthDB) []string {
	diffs := []string{}

	// Record the fields that are different.
	if a.GetOauthClientId() != b.GetOauthClientId() {
		diffs = append(diffs, "OauthClientId")
	}
	if a.GetOauthClientSecret() != b.GetOauthClientSecret() {
		diffs = append(diffs, "OauthClientSecret")
	}
	if a.GetTokenServerUrl() != b.GetTokenServerUrl() {
		diffs = append(diffs, "TokenServerUrl")
	}
	if !bytes.Equal(a.GetSecurityConfig(), b.GetSecurityConfig()) {
		diffs = append(diffs, "SecurityConfig")
	}
	if !proto.Equal(a.GetRealms(), b.GetRealms()) {
		diffs = append(diffs, "Realms")
	}

	// Record the slice fields that are different.
	if !slices.Equal(a.GetOauthAdditionalClientIds(), b.GetOauthAdditionalClientIds()) {
		diffs = append(diffs, "OauthAdditionalClientIds")
	}

	ipAllowlistsEqual := slices.EqualFunc(a.GetIpWhitelists(), b.GetIpWhitelists(),
		func(listA, listB *protocol.AuthIPWhitelist) bool {
			return proto.Equal(listA, listB)
		})
	if !ipAllowlistsEqual {
		diffs = append(diffs, "IpAllowlists")
	}

	ipAllowlistAssignmentsEqual := slices.EqualFunc(a.GetIpWhitelistAssignments(), b.GetIpWhitelistAssignments(),
		func(assignA, assignB *protocol.AuthIPWhitelistAssignment) bool {
			return proto.Equal(assignA, assignB)
		})
	if !ipAllowlistAssignmentsEqual {
		diffs = append(diffs, "IpAllowlistsAssignments")
	}

	// Record group differences if present.
	groupsA, groupsB := a.GetGroups(), b.GetGroups()
	groupCountA, groupCountB := len(groupsA), len(groupsB)

	// Ignore AuthGroup unexported fields.
	ignoredFields := cmpopts.IgnoreFields(protocol.AuthGroup{},
		"state", "sizeCache", "unknownFields")
	for i := 0; i < groupCountA && i < groupCountB; i++ {
		if !proto.Equal(groupsA[i], groupsB[i]) {
			diff := cmp.Diff(groupsA[i], groupsB[i], ignoredFields)
			diffs = append(diffs, fmt.Sprintf("Groups - index %d: %s", i, diff))
		}
	}

	// Record the names of missing groups.
	if groupCountA != groupCountB {
		diffs = append(diffs, fmt.Sprintf("Groups - total count: %d vs %d", groupCountA, groupCountB))

		// The groups beyond the shared prefix only exist in the longer list.
		extraGroups := groupsB[groupCountA:]
		if groupCountA > groupCountB {
			extraGroups = groupsA[groupCountB:]
		}
		for _, group := range extraGroups {
			diffs = append(diffs, fmt.Sprintf("Groups - missing '%s'", group.GetName()))
		}
	}

	return diffs
}