blob: c2b9ad0b8d164ddef0e170a0cbb30d52c8b8d7a7 [file] [edit]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cel
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"os"
"reflect"
"runtime"
"strings"
"sync"
"testing"
"time"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protodesc"
"google.golang.org/protobuf/reflect/protoreflect"
"github.com/google/cel-go/checker"
celast "github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/env"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/interpreter"
"github.com/google/cel-go/parser"
"github.com/google/cel-go/test"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
descpb "google.golang.org/protobuf/types/descriptorpb"
dynamicpb "google.golang.org/protobuf/types/dynamicpb"
durationpb "google.golang.org/protobuf/types/known/durationpb"
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
wrapperspb "google.golang.org/protobuf/types/known/wrapperspb"
proto2pb "github.com/google/cel-go/test/proto2pb"
proto3pb "github.com/google/cel-go/test/proto3pb"
)
// Test_ExampleWithBuiltins walks through the basic compile -> program -> eval
// flow, concatenating two declared string variables with literal text.
func Test_ExampleWithBuiltins(t *testing.T) {
	// Declare the two string variables referenced by the expression.
	e, err := NewEnv(
		Variable("i", StringType),
		Variable("you", StringType),
	)
	if err != nil {
		t.Fatalf("environment creation error: %s\n", err)
	}

	// Parse and type-check the source in one step.
	checked, iss := e.Compile(`"Hello " + you + "! I'm " + i + "."`)
	if iss.Err() != nil {
		t.Fatal(iss.Err())
	}

	// Plan the checked AST into an evaluable program.
	program, err := e.Program(checked)
	if err != nil {
		t.Fatalf("program creation error: %s\n", err)
	}

	// The second Eval result (evaluation details) is only populated when state
	// tracking (cel.EvalOptions(OptTrackState)) is requested, so it is ignored.
	activation := map[string]any{
		"i":   "CEL",
		"you": "world",
	}
	result, _, err := program.Eval(activation)
	if err != nil {
		t.Fatalf("runtime error: %s\n", err)
	}

	const greeting = "Hello world! I'm CEL."
	if result.Equal(types.String(greeting)) != types.True {
		t.Errorf(`got '%v', wanted "Hello world! I'm CEL."`, result.Value())
	}
}
// TestEval compiles a simple list-size expression and evaluates the resulting
// program from 100 parallel subtests. Each subtest calls both Eval and
// ContextEval on the same shared Program, which was built with cost tracking
// and interrupt checks enabled. Running under -race presumably surfaces any
// unsafe sharing of per-program state (TODO confirm that is the intent).
func TestEval(t *testing.T) {
	// Named e rather than env to avoid shadowing the imported env package.
	e, err := NewEnv(
		Variable("input", ListType(IntType)),
		CostEstimatorOptions(
			checker.OverloadCostEstimate(overloads.TimestampToYear, estimateTimestampToYear),
		),
	)
	if err != nil {
		t.Fatalf("NewEnv() failed: %v", err)
	}
	tests := []struct {
		expr string
		in   any
	}{
		{
			expr: `input.size() != 0`,
			in:   map[string]any{"input": []int{1, 2, 3}},
		},
	}
	for i, tst := range tests {
		tc := tst
		t.Run(fmt.Sprintf("[%d]", i), func(t *testing.T) {
			ast, iss := e.Compile(tc.expr)
			if iss.Err() != nil {
				t.Fatalf("env.Compile(%s) failed: %v", tc.expr, iss.Err())
			}
			ctx := context.Background()
			prgOpts := []ProgramOption{
				CostTracking(testRuntimeCostEstimator{}),
				CostTrackerOptions(
					interpreter.OverloadCostTracker(overloads.TimestampToYear, trackTimestampToYear),
				),
				EvalOptions(OptOptimize, OptTrackCost),
				InterruptCheckFrequency(100),
			}
			prg, err := e.Program(ast, prgOpts...)
			if err != nil {
				t.Fatalf("env.Program() failed: %v", err)
			}
			for k := 0; k < 100; k++ {
				t.Run(fmt.Sprintf("[%d]", k), func(t *testing.T) {
					t.Parallel()
					// The plain Eval path must succeed too; its error is checked
					// rather than discarded.
					if _, _, err := prg.Eval(tc.in); err != nil {
						t.Fatalf("prg.Eval() failed: %v", err)
					}
					evalCtx, cancel := context.WithTimeout(ctx, time.Minute)
					defer cancel()
					if _, _, err := prg.ContextEval(evalCtx, tc.in); err != nil {
						t.Fatalf("prg.ContextEval() failed: %v", err)
					}
				})
			}
		})
	}
}
// TestAbbrevsCompiled checks that an abbreviation is resolved during type
// checking (compile time), so `name.first` binds to the fully qualified
// variable "qualified.identifier.name.first".
func TestAbbrevsCompiled(t *testing.T) {
	e := testEnv(t,
		Abbrevs("qualified.identifier.name"),
		Variable("qualified.identifier.name.first", StringType),
	)
	// compile() type-checks the source, which is where the abbreviation applies.
	prg := compile(t, e, `"hello "+ name.first`)
	vars := map[string]any{"qualified.identifier.name.first": "Jim"}
	out, _, err := prg.Eval(vars)
	if err != nil {
		t.Fatal(err)
	}
	if got := out.Value(); got != "hello Jim" {
		t.Errorf("got %v, wanted 'hello Jim'", out)
	}
}
// TestAbbrevsParsed checks that an abbreviation is honored for a parsed but
// unchecked AST. With no type-check step, `name` is resolved when the program
// is planned, against an activation keyed by the fully qualified name.
func TestAbbrevsParsed(t *testing.T) {
	e := testEnv(t, Abbrevs("qualified.identifier.name"))
	parsed, iss := e.Parse(`"hello " + name.first`)
	if iss.Err() != nil {
		t.Fatal(iss.Err())
	}
	// The abbreviation is resolved here rather than at compile time.
	prg, err := e.Program(parsed)
	if err != nil {
		t.Fatal(err)
	}
	vars := map[string]any{
		"qualified.identifier.name": map[string]string{"first": "Jim"},
	}
	out, _, err := prg.Eval(vars)
	if err != nil {
		t.Fatal(err)
	}
	if got := out.Value(); got != "hello Jim" {
		t.Errorf("got %v, wanted 'hello Jim'", out)
	}
}
// TestAbbrevsDisambiguation checks that a bare `Expr` resolves to the
// abbreviated variable 'external.Expr', while the fully qualified name
// google.api.expr.v1alpha1.Expr still refers to the protobuf message type.
func TestAbbrevsDisambiguation(t *testing.T) {
	e := testEnv(t,
		Abbrevs("external.Expr"),
		Container("google.api.expr.v1alpha1"),
		Types(&exprpb.Expr{}),
		Variable("test", BoolType),
		Variable("external.Expr", StringType),
	)
	// Depending on 'test', the expression yields either the string bound to
	// 'external.Expr' or a protobuf Expr message. The protobuf type is spelled
	// with its fully qualified name so that it is not confused with the
	// abbreviation for 'external.Expr'.
	const src = `test ? dyn(Expr) : google.api.expr.v1alpha1.Expr{id: 1}`

	strOut, err := interpret(t, e, src, map[string]any{
		"test":          true,
		"external.Expr": "string expr",
	})
	if err != nil {
		t.Fatal(err)
	}
	if strOut.Value() != "string expr" {
		t.Errorf("got %v, wanted 'string expr'", strOut)
	}

	msgOut, err := interpret(t, e, src, map[string]any{
		"test":          false,
		"external.Expr": "wrong expr",
	})
	if err != nil {
		t.Fatal(err)
	}
	want := &exprpb.Expr{Id: 1}
	native, err := msgOut.ConvertToNative(reflect.TypeOf(want))
	if err != nil {
		t.Fatal(err)
	}
	if !proto.Equal(native.(*exprpb.Expr), want) {
		t.Errorf("got %v, wanted '%v'", msgOut, want)
	}
}
// TestConvertToNativeJSONStructure evaluates a nested map/list literal,
// converts the result to types.JSONValueType, and checks that json.Marshal
// produces the expected compact JSON text.
func TestConvertToNativeJSONStructure(t *testing.T) {
	e, err := NewEnv()
	if err != nil {
		t.Fatalf("NewEnv() failed: %v", err)
	}
	ast, issues := e.Compile(`{
"parts": [{"kind": "text"}]
}`)
	if issues != nil && issues.Err() != nil {
		t.Fatal(issues.Err())
	}
	prg, err := e.Program(ast)
	if err != nil {
		t.Fatal(err)
	}
	result, _, err := prg.Eval(map[string]any{})
	if err != nil {
		t.Fatal(err)
	}
	native, err := result.ConvertToNative(types.JSONValueType)
	if err != nil {
		t.Fatal(err)
	}
	jsonBytes, err := json.Marshal(native)
	if err != nil {
		t.Fatalf("json.Marshal failed: %v", err)
	}
	want := `{"parts":[{"kind":"text"}]}`
	if got := string(jsonBytes); got != want {
		// Print both values; the expected JSON was previously left out of the message.
		t.Errorf("json.Marshal() got %s, wanted %s", got, want)
	}
}
// TestCustomEnvError registers StdLib() twice in a custom environment. The
// environment itself is created, but compiling an expression must then report
// an error.
func TestCustomEnvError(t *testing.T) {
	e, err := NewCustomEnv(StdLib(), StdLib())
	if err != nil {
		t.Fatalf("NewCustomEnv() failed: %v", err)
	}
	if _, iss := compileOrError(t, e, "a.b.c == true"); iss == nil {
		t.Error("got successful compile, expected error for duplicate function declarations.")
	}
}
// TestCustomEnv builds a custom environment that declares only one variable.
// Using the '_==_' operator must fail to compile, while a bare reference to
// the variable still evaluates.
func TestCustomEnv(t *testing.T) {
	e, err := NewCustomEnv(Variable("a.b.c", BoolType))
	if err != nil {
		t.Fatalf("NewCustomEnv(a.b.c:bool) failed: %v", err)
	}
	t.Run("err", func(t *testing.T) {
		if _, iss := compileOrError(t, e, "a.b.c == true"); iss == nil {
			t.Error("got successful compile, expected error for missing operator '_==_'")
		}
	})
	t.Run("ok", func(t *testing.T) {
		vars := map[string]any{"a.b.c": true}
		out, err := interpret(t, e, "a.b.c", vars)
		if err != nil {
			t.Fatal(err)
		}
		if out != types.True {
			t.Errorf("got '%v', wanted 'true'", out.Value())
		}
	})
}
// TestCrossTypeNumericComparisons checks that statically typed mixed
// double/int comparisons compile only when CrossTypeNumericComparisons(true)
// is set. Comparisons through dyn() succeed either way.
func TestCrossTypeNumericComparisons(t *testing.T) {
	tests := []struct {
		name string
		expr string
		iss  string
		opt  EnvOption
		out  ref.Val
	}{
		// Statically typed expressions need to opt in to cross-type numeric comparisons.
		{
			name: "double_less_than_int_err",
			expr: `1.0 < 2`,
			opt:  CrossTypeNumericComparisons(false),
			iss: `
ERROR: <input>:1:5: found no matching overload for '_<_' applied to '(double, int)'
| 1.0 < 2
| ....^`,
		},
		{
			name: "double_less_than_int_success",
			expr: `1.0 < 2`,
			opt:  CrossTypeNumericComparisons(true),
			out:  types.True,
		},
		// Dynamic data already benefits from cross-type numeric comparisons.
		{
			name: "dyn_less_than_int_success",
			expr: `dyn(1.0) < 2`,
			opt:  CrossTypeNumericComparisons(false),
			out:  types.True,
		},
		{
			// Distinct name: this case previously duplicated the one above,
			// which made failures ambiguous and `-run` filtering impossible.
			name: "dyn_less_than_int_cross_type_success",
			expr: `dyn(1.0) < 2`,
			opt:  CrossTypeNumericComparisons(true),
			out:  types.True,
		},
	}
	for _, tst := range tests {
		tc := tst
		t.Run(tc.name, func(t *testing.T) {
			e := testEnv(t, tc.opt)
			ast, iss := e.Compile(tc.expr)
			if tc.iss != "" {
				if iss.Err() == nil {
					t.Fatalf("e.Compile(%v) returned ast, expected error: %v", tc.expr, tc.iss)
				}
				if !test.Compare(iss.Err().Error(), tc.iss) {
					t.Fatalf("e.Compile(%v) returned %v, expected error: %v", tc.expr, iss.Err(), tc.iss)
				}
				return
			}
			if iss.Err() != nil {
				t.Fatalf("e.Compile(%v) failed: %v", tc.expr, iss.Err())
			}
			prg, err := e.Program(ast)
			if err != nil {
				t.Fatalf("e.Program() failed: %v", err)
			}
			out, _, err := prg.Eval(NoVars())
			if err != nil {
				t.Fatalf("prg.Eval() errored: %v", err)
			}
			if out != tc.out {
				t.Errorf("program eval got %v, wanted %v", out, tc.out)
			}
		})
	}
}
func TestExtendStdlibFunction(t *testing.T) {
env := testEnv(t,
Function(overloads.Contains,
MemberOverload("bytes_contains_bytes", []*Type{BytesType, BytesType}, BoolType,
BinaryBinding(func(bstr, bsub ref.Val) ref.Val {
return types.Bool(bytes.Contains([]byte(bstr.(types.Bytes)), []byte(bsub.(types.Bytes))))
}))),
)
prg := compile(t, env, `b'string'.contains(b'tri') && 'string'.contains('tri')`)
out, _, err := prg.Eval(NoVars())
if err != nil {
t.Fatalf("contains check errored: %v", err)
}
if out != types.True {
t.Errorf("contains check got %v, wanted true", out)
}
}
// TestSubsetStdLib builds a custom environment from a subset of the standard
// library: only the has() macro, the logical and equality operators, and the
// list_size overload of size(). Expressions inside the subset must compile and
// evaluate; anything outside it must fail to compile.
func TestSubsetStdLib(t *testing.T) {
	// Named e rather than env so the imported env package stays visible below.
	e, err := NewCustomEnv(
		StdLib(StdLibSubset(
			&env.LibrarySubset{
				IncludeMacros: []string{"has"},
				IncludeFunctions: []*env.Function{
					{Name: operators.Equals},
					{Name: operators.NotEquals},
					{Name: operators.LogicalAnd},
					{Name: operators.LogicalOr},
					{Name: operators.LogicalNot},
					{Name: overloads.Size, Overloads: []*env.Overload{{ID: "list_size"}}},
				},
			},
		)))
	if err != nil {
		t.Fatalf("StdLib() subsetting failed: %v", err)
	}
	tests := []struct {
		name     string
		expr     string
		compiles bool
		want     ref.Val
	}{
		{
			name:     "has macro",
			expr:     "!has({}.a)",
			compiles: true,
			want:     types.True,
		},
		{
			name:     "not equals",
			expr:     "has({}.a) != true",
			compiles: true,
			want:     types.True,
		},
		{
			name:     "logical operators",
			expr:     "has({}.a) != true && has({'b': 1}.b) == true",
			compiles: true,
			want:     types.True,
		},
		{
			name:     "list size - allowed",
			expr:     "[1, 2, 3].size()",
			compiles: true,
			want:     types.Int(3),
		},
		{
			name:     "excluded macro",
			expr:     "[1, 2, 3].exists(i, i != 0)",
			compiles: false,
		},
		{
			name:     "string size - not allowed",
			expr:     "'hello'.size()",
			compiles: false,
		},
	}
	for _, tst := range tests {
		tc := tst
		t.Run(tc.name, func(t *testing.T) {
			ast, iss := e.Compile(tc.expr)
			if !tc.compiles {
				// An expression outside the subset must be rejected. Previously an
				// unexpected success fell through to Eval and a nil 'want'.
				if iss.Err() == nil {
					t.Fatalf("env.Compile(%q) succeeded, wanted a compile error", tc.expr)
				}
				return
			}
			if iss.Err() != nil {
				t.Fatalf("env.Compile(%q) failed: %v", tc.expr, iss.Err())
			}
			prg, err := e.Program(ast)
			if err != nil {
				t.Fatalf("env.Program() failed: %v", err)
			}
			out, _, err := prg.Eval(NoVars())
			if err != nil {
				t.Fatalf("prg.Eval() failed: %s", err)
			}
			if out.Equal(tc.want) != types.True {
				t.Errorf("prg.Eval() got %v, wanted %v", out, tc.want)
			}
		})
	}
}
func TestSubsetStdLibError(t *testing.T) {
_, err := NewCustomEnv(
StdLib(StdLibSubset(
env.NewLibrarySubset().AddIncludedMacros("has").AddExcludedMacros("exists")),
))
if err == nil || !strings.Contains(err.Error(), "invalid subset") {
t.Errorf("StdLib() subsetting got %v, wanted error 'invalid subset'", err)
}
}
func TestSubsetStdLibMerge(t *testing.T) {
_, err := NewCustomEnv(
Function("size", MemberOverload("string_size", []*Type{StringType}, IntType)),
StdLib(StdLibSubset(
env.NewLibrarySubset().AddIncludedFunctions([]*env.Function{
{Name: overloads.Size, Overloads: []*env.Overload{{ID: "string_size"}}},
}...),
)))
if err != nil {
t.Errorf("StdLib() subsetting failed to merge: %v", err)
}
}
func TestSubsetStdLibMergeError(t *testing.T) {
_, err := NewCustomEnv(
Function("size", MemberOverload("string_size", []*Type{StringType}, UintType)),
StdLib(StdLibSubset(
env.NewLibrarySubset().AddIncludedFunctions([]*env.Function{
{Name: overloads.Size, Overloads: []*env.Overload{{ID: "string_size"}}},
}...),
)))
if err == nil || !strings.Contains(err.Error(), "merge failed") {
t.Errorf("StdLib() subsetting got %v, wanted error 'merge failed'", err)
}
}
// TestCustomTypes installs an empty registry as both type adapter and
// provider, registers the protobuf Expr message, and compares a
// CEL-constructed Expr against an equivalent Go-constructed one.
func TestCustomTypes(t *testing.T) {
	reg := types.NewEmptyRegistry()
	e := testEnv(t,
		CustomTypeAdapter(reg),
		CustomTypeProvider(reg),
		Container("google.api.expr.v1alpha1"),
		Types(
			&exprpb.Expr{},
			types.BoolType,
			types.IntType,
			types.StringType),
		Variable("expr", ObjectType("google.api.expr.v1alpha1.Expr")),
	)
	ast, iss := e.Compile(`
		expr == Expr{id: 2,
			call_expr: Expr.Call{
				function: "_==_",
				args: [
					Expr{id: 1, ident_expr: Expr.Ident{ name: "a" }},
					Expr{id: 3, ident_expr: Expr.Ident{ name: "b" }}]
			}}`)
	if iss.Err() != nil {
		t.Fatalf("Compile() failed: %v", iss.Err())
	}
	if ast.OutputType() != BoolType {
		t.Fatalf("got %v, wanted type bool", ast.OutputType())
	}
	// Program and Eval errors were previously discarded, which could cause a
	// nil dereference on out.Value() below instead of a clear failure.
	prg, err := e.Program(ast)
	if err != nil {
		t.Fatalf("Program() failed: %v", err)
	}
	vars := map[string]any{"expr": &exprpb.Expr{
		Id: 2,
		ExprKind: &exprpb.Expr_CallExpr{
			CallExpr: &exprpb.Expr_Call{
				Function: "_==_",
				Args: []*exprpb.Expr{
					{
						Id: 1,
						ExprKind: &exprpb.Expr_IdentExpr{
							IdentExpr: &exprpb.Expr_Ident{Name: "a"},
						},
					},
					{
						Id: 3,
						ExprKind: &exprpb.Expr_IdentExpr{
							IdentExpr: &exprpb.Expr_Ident{Name: "b"},
						},
					},
				},
			},
		},
	}}
	out, _, err := prg.Eval(vars)
	if err != nil {
		t.Fatalf("Eval() failed: %v", err)
	}
	if out != types.True {
		t.Errorf("got '%v', wanted 'true'", out.Value())
	}
}
// TestTypeIsolation checks that message types registered through TypeDescs on
// one environment are not visible to a separately created environment.
func TestTypeIsolation(t *testing.T) {
	raw, err := os.ReadFile("testdata/team.fds")
	if err != nil {
		t.Fatal("can't read fds file: ", err)
	}
	var fds descpb.FileDescriptorSet
	if err := proto.Unmarshal(raw, &fds); err != nil {
		t.Fatal("can't unmarshal descriptor data: ", err)
	}
	const src = "myteam.members[0].name == 'Cyclops'"

	withTypes := testEnv(t,
		TypeDescs(&fds),
		Variable("myteam", ObjectType("cel.testdata.Team")),
	)
	compile(t, withTypes, src)

	// Without the descriptors, the Team message must be unknown.
	withoutTypes := testEnv(t, Variable("myteam", ObjectType("cel.testdata.Team")))
	if _, iss := compileOrError(t, withoutTypes, src); iss == nil {
		t.Errorf("wanted compile failure for unknown message.")
	}
}
// TestDynamicProto registers message types from testdata/team.fds twice: once
// as individual FileDescriptorProto values and once as the result of
// protodesc.NewFiles. It then builds a cel.testdata.Team message in CEL and
// checks its 'name' field.
func TestDynamicProto(t *testing.T) {
	b, err := os.ReadFile("testdata/team.fds")
	if err != nil {
		t.Fatalf("os.ReadFile() failed: %v", err)
	}
	var fds descpb.FileDescriptorSet
	if err = proto.Unmarshal(b, &fds); err != nil {
		t.Fatalf("proto.Unmarshal() failed: %v", err)
	}
	files := fds.GetFile()
	fileCopy := make([]any, len(files))
	for i, f := range files {
		fileCopy[i] = f
	}
	pbFiles, err := protodesc.NewFiles(&fds)
	if err != nil {
		t.Fatalf("protodesc.NewFiles() failed: %v", err)
	}
	e := testEnv(t,
		Container("cel"),
		// The following is identical to registering the FileDescriptorSet;
		// however, it tests a different code path which aggregates individual
		// FileDescriptorProto values together.
		TypeDescs(fileCopy...),
		// Additionally, demonstrate that double registration of files doesn't
		// cause any problems.
		TypeDescs(pbFiles),
	)
	src := `testdata.Team{name: 'X-Men', members: [
testdata.Mutant{name: 'Jean Grey', level: 20},
testdata.Mutant{name: 'Cyclops', level: 7},
testdata.Mutant{name: 'Storm', level: 7},
testdata.Mutant{name: 'Wolverine', level: 11}
]}`
	ast, iss := e.Compile(src)
	if iss.Err() != nil {
		t.Fatalf("env.Compile(%s) failed: %v", src, iss.Err())
	}
	prg, err := e.Program(ast, EvalOptions(OptOptimize))
	if err != nil {
		t.Fatalf("env.Program() failed: %v", err)
	}
	out, _, err := prg.Eval(NoVars())
	if err != nil {
		t.Fatalf("program.Eval() failed: %v", err)
	}
	obj, ok := out.(traits.Indexer)
	if !ok {
		t.Fatalf("unable to convert output to object: %v", out)
	}
	name := obj.Get(types.String("name"))
	// Require True rather than rejecting only False: if Get or Equal produces
	// an error value, that value is neither, and the old check let it pass.
	if name.Equal(types.String("X-Men")) != types.True {
		t.Fatalf("got field 'name' %v, wanted X-Men", name)
	}
}
// TestDynamicProtoFileDescriptors evaluates an expression against a dynamicpb
// message. The message type (cel.testdata.Mutant) is known to the environment
// only through FileDescriptorProtos read from testdata/team.fds.
func TestDynamicProtoFileDescriptors(t *testing.T) {
	raw, err := os.ReadFile("testdata/team.fds")
	if err != nil {
		t.Fatalf("os.ReadFile() failed: %v", err)
	}
	var fds descpb.FileDescriptorSet
	if err := proto.Unmarshal(raw, &fds); err != nil {
		t.Fatalf("proto.Unmarshal() failed: %v", err)
	}
	protos := fds.GetFile()
	descs := make([]any, 0, len(protos))
	for _, fd := range protos {
		descs = append(descs, fd)
	}
	registry, err := protodesc.NewFiles(&fds)
	if err != nil {
		t.Fatalf("protodesc.NewFiles() failed: %v", err)
	}

	// Build a Mutant message dynamically with only its 'name' field set.
	found, err := registry.FindDescriptorByName("cel.testdata.Mutant")
	if err != nil {
		t.Fatalf("pbFiles.FindDescriptorByName() could not find Mutant: %v", err)
	}
	msgDesc, ok := found.(protoreflect.MessageDescriptor)
	if !ok {
		t.Fatalf("desc not convertible to MessageDescriptor: %T", found)
	}
	wolverine := dynamicpb.NewMessage(msgDesc)
	nameField := msgDesc.Fields().ByName("name")
	wolverine.ProtoReflect().Set(nameField, protoreflect.ValueOfString("Wolverine"))

	// Registering the individual FileDescriptorProtos is equivalent to
	// registering the whole set, but exercises the aggregation code path.
	e := testEnv(t,
		TypeDescs(descs...),
		Variable("mutant", ObjectType("cel.testdata.Mutant")),
	)
	src := `has(mutant.name) && mutant.name == 'Wolverine'`
	checked, iss := e.Compile(src)
	if iss.Err() != nil {
		t.Fatalf("env.Compile(%s) failed: %v", src, iss.Err())
	}
	prg, err := e.Program(checked, EvalOptions(OptOptimize))
	if err != nil {
		t.Fatalf("env.Program() failed: %v", err)
	}
	out, _, err := prg.Eval(map[string]any{"mutant": wolverine})
	if err != nil {
		t.Fatalf("program.Eval() failed: %v", err)
	}
	if verdict, isBool := out.(types.Bool); !isBool {
		t.Fatalf("unable to convert output to object: %v", out)
	} else if verdict != types.True {
		t.Errorf("got %v, wanted true", out)
	}
}
// TestGlobalVars checks how Globals() program options interact with Eval
// inputs:
//   - a later Globals option shadows an earlier one;
//   - an empty Globals map does not clobber values set before it;
//   - a value passed to Eval overrides the global value.
//
// It also checks that a function binding rejecting its operand surfaces an
// error from Eval.
func TestGlobalVars(t *testing.T) {
	e := testEnv(t,
		Variable("attrs", MapType(StringType, DynType)),
		Variable("default", DynType),
		Function("get",
			MemberOverload("get_map", []*Type{MapType(StringType, DynType), StringType, DynType}, DynType,
				FunctionBinding(func(args ...ref.Val) ref.Val {
					// attrs.get(key, def): attrs[key] if present, otherwise def.
					attrs, ok := args[0].(traits.Mapper)
					if !ok {
						return types.NewErr(
							"invalid operand of type '%v' to obj.get(key, def)",
							args[0].Type())
					}
					key := args[1]
					defVal := args[2]
					if attrs.Contains(key) == types.True {
						return attrs.Get(key)
					}
					return defVal
				}),
			),
		),
	)
	ast, iss := e.Compile(`attrs.get("first", attrs.get("second", default))`)
	if iss.Err() != nil {
		t.Fatalf("e.Compile() failed: %v", iss.Err())
	}
	// Global variables can be configured as a ProgramOption and optionally overridden on Eval.
	// Add a previous globals map to confirm the order of shadowing and a final empty global
	// map to show that globals are not clobbered.
	prg, err := e.Program(ast,
		Globals(map[string]any{
			"default": "shadow me",
		}),
		Globals(map[string]any{
			"default": "third",
		}),
		Globals(map[string]any{}),
	)
	if err != nil {
		t.Fatalf("e.Program() failed: %v", err)
	}
	t.Run("bad_attrs", func(t *testing.T) {
		out, _, err := prg.Eval(map[string]any{
			"attrs": []string{"one", "two"},
		})
		if err == nil {
			t.Errorf("prg.Eval() of incorrect arg type invoked function, wanted error, got %v", out)
		}
	})
	t.Run("global_default", func(t *testing.T) {
		vars := map[string]any{
			"attrs": map[string]any{},
		}
		out, _, err := prg.Eval(vars)
		if err != nil {
			t.Fatalf("prg.Eval() failed: %v", err)
		}
		if out.Equal(types.String("third")) != types.True {
			t.Errorf("got '%v', expected 'third'.", out.Value())
		}
	})
	t.Run("attrs_alt", func(t *testing.T) {
		vars := map[string]any{
			"attrs": map[string]any{"second": "yep"}}
		out, _, err := prg.Eval(vars)
		if err != nil {
			t.Fatalf("prg.Eval(vars) failed: %v", err)
		}
		if out.Equal(types.String("yep")) != types.True {
			t.Errorf("got '%v', expected 'yep'.", out.Value())
		}
	})
	t.Run("local_default", func(t *testing.T) {
		vars := map[string]any{
			"attrs":   map[string]any{},
			"default": "fourth"}
		// Check the error rather than discarding it, so a failure does not
		// turn into a nil dereference on out below.
		out, _, err := prg.Eval(vars)
		if err != nil {
			t.Fatalf("prg.Eval(vars) failed: %v", err)
		}
		if out.Equal(types.String("fourth")) != types.True {
			t.Errorf("got '%v', expected 'fourth'.", out.Value())
		}
	})
}
// TestMacroSubset clears the default macros and re-enables only has(). has()
// must still expand, while an expression using the all() macro must fail.
func TestMacroSubset(t *testing.T) {
	e := testEnv(t,
		ClearMacros(), Macros(HasMacro),
		Variable("name", MapType(StringType, StringType)),
	)
	vars := map[string]any{
		"name": map[string]string{"first": "Jim"},
	}
	out, err := interpret(t, e, `has(name.first)`, vars)
	if err != nil {
		t.Fatal(err)
	}
	if out != types.True {
		t.Errorf("got %v, wanted true", out)
	}
	if out, err = interpret(t, e, `[1, 2].all(i, i > 0)`, NoVars()); err == nil {
		t.Errorf("got %v, wanted err", out)
	}
}
// TestCustomMacro defines a receiver macro `list.join(delim)` that expands into
// a fold concatenating the list's elements, separated by delim. It checks the
// result of joining ['hello', 'cel', 'friend'] with ','.
func TestCustomMacro(t *testing.T) {
// The expander builds AST nodes through meh. The call order is left as-is,
// since each helper call presumably assigns the next expression id.
joinMacro := NewReceiverMacro("join", 1,
func(meh MacroExprHelper, iterRange *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
delim := args[0]
// References to the loop variable and the accumulator, used in the step.
iterIdent := meh.Ident("__iter__")
accuIdent := meh.AccuIdent()
// The accumulator starts as the empty string.
accuInit := meh.LiteralString("")
// The loop condition is the literal true, so every element is folded in.
condition := meh.LiteralBool(true)
step := meh.GlobalCall(
// __result__.size() > 0 ? __result__ + delim + __iter__ : __iter__
operators.Conditional,
meh.GlobalCall(operators.Greater, meh.ReceiverCall("size", accuIdent), meh.LiteralInt(0)),
meh.GlobalCall(operators.Add, meh.GlobalCall(operators.Add, accuIdent, delim), iterIdent),
iterIdent)
// Fold over the receiver list (iterRange) and return the final accumulator.
return meh.Fold(
"__iter__",
iterRange,
accuIdent.GetIdentExpr().GetName(),
accuInit,
condition,
step,
accuIdent), nil
})
env := testEnv(t, Macros(joinMacro))
ast, iss := env.Compile(`['hello', 'cel', 'friend'].join(',')`)
if iss.Err() != nil {
t.Fatal(iss.Err())
}
prg, err := env.Program(ast, EvalOptions(OptExhaustiveEval))
if err != nil {
t.Fatalf("program creation error: %s\n", err)
}
out, _, err := prg.Eval(NoVars())
if err != nil {
t.Fatal(err)
}
if out.Equal(types.String("hello,cel,friend")) != types.True {
t.Errorf("got %v, wanted 'hello,cel,friend'", out)
}
}
// TestMacroInterop registers receiver and global macros written against the
// MacroExprHelper / *exprpb.Expr API and checks that they compose and evaluate.
func TestMacroInterop(t *testing.T) {
	// Expose the stock expanders under new receiver-style names.
	existsOne := NewReceiverMacro("exists_one", 2,
		func(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
			return ExistsOneMacroExpander(meh, target, args)
		})
	transform := NewReceiverMacro("transform", 2,
		func(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
			return MapMacroExpander(meh, target, args)
		})
	filter := NewReceiverMacro("filter", 2,
		func(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
			return FilterMacroExpander(meh, target, args)
		})
	// pair(k, v) builds the single-entry map {k: v}.
	pair := NewGlobalMacro("pair", 2,
		func(meh MacroExprHelper, _ *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
			entry := meh.NewMapEntry(args[0], args[1], false)
			return meh.NewMap(entry), nil
		})
	// m.get(field, fallback) expands to: has(m.field) ? m.field : fallback.
	getOrDefault := NewReceiverMacro("get", 2,
		func(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
			field := args[0].GetIdentExpr().GetName()
			return meh.GlobalCall(
				operators.Conditional,
				meh.PresenceTest(meh.Copy(target), field),
				meh.Select(meh.Copy(target), field),
				meh.Copy(args[1]),
			), nil
		})
	e := testEnv(t, Macros(existsOne, transform, filter, pair, getOrDefault))

	cases := []struct {
		expr string
		out  ref.Val
	}{
		{
			expr: `['tr', 's', 'fri'].filter(i, i.size() > 1).transform(i, i + 'end').exists_one(i, i == 'friend')`,
			out:  types.True,
		},
		{
			expr: `pair('a', 'b')`,
			out:  types.DefaultTypeAdapter.NativeToValue(map[string]string{"a": "b"}),
		},
		{
			expr: `{}.get(a, 'default')`,
			out:  types.String("default"),
		},
		{
			expr: `{'a': 'b'}.get(a, 'default')`,
			out:  types.String("b"),
		},
	}
	for _, tc := range cases {
		checked, iss := e.Compile(tc.expr)
		if iss.Err() != nil {
			t.Fatal(iss.Err())
		}
		program, err := e.Program(checked, EvalOptions(OptExhaustiveEval))
		if err != nil {
			t.Fatalf("program creation error: %s\n", err)
		}
		got, _, err := program.Eval(NoVars())
		if err != nil {
			t.Fatal(err)
		}
		if got.Equal(tc.out) != types.True {
			t.Errorf("got %v, wanted %v", got, tc.out)
		}
	}
}
// TestMacroModern runs the same scenarios as TestMacroInterop, but builds the
// macros with the celast.Expr-based MacroExprFactory API and the parser's
// stock Make* expanders.
func TestMacroModern(t *testing.T) {
	existsOne := ReceiverMacro("exists_one", 2, parser.MakeExistsOne)
	transform := ReceiverMacro("transform", 2, parser.MakeMap)
	filter := ReceiverMacro("filter", 2, parser.MakeFilter)
	// pair(k, v) builds the single-entry map {k: v}.
	pair := GlobalMacro("pair", 2,
		func(mef MacroExprFactory, _ celast.Expr, args []celast.Expr) (celast.Expr, *Error) {
			entry := mef.NewMapEntry(args[0], args[1], false)
			return mef.NewMap(entry), nil
		})
	// m.get(field, fallback) expands to: has(m.field) ? m.field : fallback.
	getOrDefault := ReceiverMacro("get", 2,
		func(mef MacroExprFactory, target celast.Expr, args []celast.Expr) (celast.Expr, *Error) {
			field := args[0].AsIdent()
			return mef.NewCall(
				operators.Conditional,
				mef.NewPresenceTest(mef.Copy(target), field),
				mef.NewSelect(mef.Copy(target), field),
				mef.Copy(args[1]),
			), nil
		})
	e := testEnv(t, Macros(existsOne, transform, filter, pair, getOrDefault))

	cases := []struct {
		expr string
		out  ref.Val
	}{
		{
			expr: `['tr', 's', 'fri'].filter(i, i.size() > 1).transform(i, i + 'end').exists_one(i, i == 'friend')`,
			out:  types.True,
		},
		{
			expr: `pair('a', 'b')`,
			out:  types.DefaultTypeAdapter.NativeToValue(map[string]string{"a": "b"}),
		},
		{
			expr: `{}.get(a, 'default')`,
			out:  types.String("default"),
		},
		{
			expr: `{'a': 'b'}.get(a, 'default')`,
			out:  types.String("b"),
		},
	}
	for _, tc := range cases {
		checked, iss := e.Compile(tc.expr)
		if iss.Err() != nil {
			t.Fatal(iss.Err())
		}
		program, err := e.Program(checked, EvalOptions(OptExhaustiveEval))
		if err != nil {
			t.Fatalf("program creation error: %s\n", err)
		}
		got, _, err := program.Eval(NoVars())
		if err != nil {
			t.Fatal(err)
		}
		if got.Equal(tc.out) != types.True {
			t.Errorf("got %v, wanted %v", got, tc.out)
		}
	}
}
// TestCustomExistsMacro builds two three-valued ("Kleene") macros on top of the
// stock exists() and has() expanders:
//   - kleeneOr(args...): 1 if any arg equals 1; else 0 if any arg equals 0; else -1.
//   - kleeneEq(attr, value): 0 if attr is absent; 1 if attr == value; else -1.
//
// With attr.value == false, kleeneEq yields -1 and kleeneOr(0, 1, 1) yields 1,
// so the outer kleeneOr is 1 and the whole expression is true.
func TestCustomExistsMacro(t *testing.T) {
env := testEnv(t,
Variable("attr", MapType(StringType, BoolType)),
Macros(
NewGlobalVarArgMacro("kleeneOr",
func(meh MacroExprHelper, unused *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
// Collect the variadic arguments into a list literal to iterate over.
inputs := meh.NewList(args...)
// exists(__iter__, __iter__ == 1) over the inputs.
eqOne, err := ExistsMacroExpander(meh, inputs, []*exprpb.Expr{
meh.Ident("__iter__"),
meh.GlobalCall(operators.Equals, meh.Ident("__iter__"), meh.LiteralInt(1)),
})
if err != nil {
return nil, err
}
// exists(__iter__, __iter__ == 0) over a copy, since the list node is
// already used by the first expansion.
eqZero, err := ExistsMacroExpander(meh, meh.Copy(inputs), []*exprpb.Expr{
meh.Ident("__iter__"),
meh.GlobalCall(operators.Equals, meh.Ident("__iter__"), meh.LiteralInt(0)),
})
if err != nil {
return nil, err
}
// eqOne ? 1 : (eqZero ? 0 : -1)
return meh.GlobalCall(
operators.Conditional,
eqOne,
meh.LiteralInt(1),
meh.GlobalCall(
operators.Conditional,
eqZero,
meh.LiteralInt(0),
meh.LiteralInt(-1),
),
), nil
},
),
NewGlobalMacro("kleeneEq", 2,
func(meh MacroExprHelper, unused *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
attr := args[0]
value := args[1]
// has(attr): expanded from a copy because attr is reused in the equality below.
hasAttr, err := HasMacroExpander(meh, nil, []*exprpb.Expr{meh.Copy(attr)})
if err != nil {
return nil, err
}
// !has(attr) ? 0 : (attr == value ? 1 : -1)
return meh.GlobalCall(
operators.Conditional,
meh.GlobalCall(operators.LogicalNot, hasAttr),
meh.LiteralInt(0),
meh.GlobalCall(
operators.Conditional,
meh.GlobalCall(operators.Equals, attr, value),
meh.LiteralInt(1),
meh.LiteralInt(-1),
),
), nil
},
),
),
)
prg := compile(t, env, "kleeneOr(kleeneEq(attr.value, true), kleeneOr(0, 1, 1)) == 1")
out, _, err := prg.Eval(map[string]any{"attr": map[string]bool{"value": false}})
if err != nil {
t.Errorf("prg.Eval() got %v, wanted non-error", err)
}
if out != types.True {
t.Errorf("prg.Eval() got %v, wanted true", out)
}
}
// TestAstIsChecked checks that a compiled AST reports IsChecked(), and that it
// keeps that flag and its expression after a round trip through the
// CheckedExpr representation.
func TestAstIsChecked(t *testing.T) {
	e := testEnv(t)
	checked, iss := e.Compile("true")
	if iss.Err() != nil {
		t.Fatalf("e.Compile('true') failed: %v", iss.Err())
	}
	if !checked.IsChecked() {
		t.Error("got ast.IsChecked() 'false', wanted 'true'.")
	}

	ce, err := AstToCheckedExpr(checked)
	if err != nil {
		t.Fatalf("AstToCheckedExpr(%v) failed: %v", checked, err)
	}
	roundTripped := CheckedExprToAst(ce)
	if !roundTripped.IsChecked() {
		t.Error("got ast2.IsChecked() 'false', wanted 'true'")
	}
	if !proto.Equal(checked.Expr(), roundTripped.Expr()) {
		t.Errorf("AST exprs did not roundtrip properly: ast1: %v, ast2: %v", checked, roundTripped)
	}
}
// TestExhaustiveEval checks that, with OptExhaustiveEval, both operands of
// '||' have recorded values in the evaluation state, even though the left
// operand alone decides the result.
func TestExhaustiveEval(t *testing.T) {
	e := testEnv(t,
		Variable("k", StringType),
		Variable("v", BoolType),
	)
	checked, iss := e.Compile(`{k: true}[k] || v != false`)
	if iss.Err() != nil {
		t.Fatalf("env.Compile() failed: %v", iss.Err())
	}
	prg, err := e.Program(checked, EvalOptions(OptExhaustiveEval))
	if err != nil {
		t.Fatalf("env.Program() failed: %s\n", err)
	}
	out, details, err := prg.Eval(map[string]any{"k": "key", "v": true})
	if err != nil {
		t.Fatalf("runtime error: %s\n", err)
	}
	if out != types.True {
		t.Errorf("got '%v', expected 'true'", out.Value())
	}

	// With short-circuiting, the right operand ('v != false') would normally
	// have no value recorded.
	state := details.State()
	operands := checked.Expr().GetCallExpr().GetArgs()
	sides := []struct {
		id       int64
		notFound string
	}{
		{operands[0].Id, "got not found, wanted evaluation of left hand side expression."},
		{operands[1].Id, "got not found, wanted evaluation of right hand side expression."},
	}
	for _, side := range sides {
		val, found := state.Value(side.id)
		if !found {
			t.Error(side.notFound)
			return
		}
		if val != types.True {
			t.Errorf("got '%v', expected 'true'", val)
		}
	}
}
// TestContextEval evaluates a comprehension-heavy expression with interrupt
// checks every 100 iterations. It runs once with an unbounded context, which
// must succeed, and once with a 1µs deadline, which must abort with a
// deadline-exceeded "operation interrupted" error.
func TestContextEval(t *testing.T) {
	e := testEnv(t, Variable("items", ListType(IntType)))
	checked, iss := e.Compile("items.map(i, i * 2).filter(i, i >= 50).size()")
	if iss.Err() != nil {
		t.Fatalf("env.Compile(expr) failed: %v", iss.Err())
	}
	prg, err := e.Program(checked, EvalOptions(OptOptimize|OptTrackState), InterruptCheckFrequency(100))
	if err != nil {
		t.Fatalf("env.Program() failed: %v", err)
	}

	ctx := context.TODO()
	items := make([]int64, 2000)
	for i := range items {
		items[i] = int64(i)
	}
	out, _, err := prg.ContextEval(ctx, map[string]any{"items": items})
	if err != nil {
		t.Fatalf("prg.ContextEval() failed: %v", err)
	}
	if out != types.Int(1975) {
		t.Errorf("prg.ContextEval() got %v, wanted 1975", out)
	}

	deadlineCtx, cancel := context.WithTimeout(ctx, time.Microsecond)
	defer cancel()
	out, _, err = prg.ContextEval(deadlineCtx, map[string]any{"items": items})
	if err == nil {
		t.Errorf("Got result %v, wanted timeout error", out)
		return
	}
	if !errors.Is(err, context.DeadlineExceeded) {
		t.Errorf("Got %v, wanted context deadline exceeded", err)
	}
	if !strings.Contains(err.Error(), "operation interrupted") {
		t.Errorf("Got %v, wanted error containing 'operation interrupted'", err)
	}
}
// TestContextEvalUnknowns evaluates with partial variables, where 'id' is
// marked unknown by an attribute pattern. Eval and ContextEval must produce
// deeply equal results.
func TestContextEvalUnknowns(t *testing.T) {
	e, err := NewEnv(
		Variable("groups", ListType(IntType)),
		Variable("id", IntType),
	)
	if err != nil {
		t.Fatalf("NewEnv() failed: %v", err)
	}
	known := map[string]any{"groups": []int{1, 2, 3}}
	pvars, err := PartialVars(known, AttributePattern("id"))
	if err != nil {
		t.Fatalf("PartialVars() failed: %v", err)
	}
	checked, iss := e.Compile(`groups.exists(t, t == id)`)
	if iss.Err() != nil {
		t.Fatalf("env.Compile() failed: %v", iss.Err())
	}
	prg, err := e.Program(checked, EvalOptions(OptTrackState, OptPartialEval), InterruptCheckFrequency(100))
	if err != nil {
		t.Fatalf("env.Program() failed: %v", err)
	}
	evalOut, _, err := prg.Eval(pvars)
	if err != nil {
		t.Fatalf("prg.Eval() failed: %v", err)
	}
	ctxOut, _, err := prg.ContextEval(context.Background(), pvars)
	if err != nil {
		t.Fatalf("prg.ContextEval() failed: %v", err)
	}
	if !reflect.DeepEqual(evalOut, ctxOut) {
		t.Errorf("got %v, wanted %v", evalOut, ctxOut)
	}
}
// BenchmarkContextEval measures ContextEval of a map/filter/size pipeline over
// a 100-element list, with interrupt checks every 200 iterations.
func BenchmarkContextEval(b *testing.B) {
	e := testEnv(b, Variable("items", ListType(IntType)))
	checked, iss := e.Compile("items.map(i, i * 2).filter(i, i >= 50).size()")
	if iss.Err() != nil {
		b.Fatalf("env.Compile(expr) failed: %v", iss.Err())
	}
	prg, err := e.Program(checked, EvalOptions(OptOptimize), InterruptCheckFrequency(200))
	if err != nil {
		b.Fatalf("env.Program() failed: %v", err)
	}
	ctx := context.TODO()
	items := make([]int64, 100)
	for i := range items {
		items[i] = int64(i)
	}
	for n := 0; n < b.N; n++ {
		out, _, err := prg.ContextEval(ctx, map[string]any{"items": items})
		if err != nil {
			b.Fatalf("prg.ContextEval() failed: %v", err)
		}
		if out != types.Int(75) {
			b.Errorf("prg.ContextEval() got %v, wanted 75", out)
		}
	}
}
// TestEvalRecover checks that a panic raised inside a custom function binding
// is surfaced by Eval as the error "internal error: watch me recover". This is
// checked both for a default program and for one built with OptTrackState.
func TestEvalRecover(t *testing.T) {
	e := testEnv(t,
		Function("panic",
			Overload("global_panic", []*Type{}, BoolType,
				FunctionBinding(func(args ...ref.Val) ref.Val {
					panic("watch me recover")
				}),
			),
		),
	)
	const want = "internal error: watch me recover"
	// Test standard evaluation.
	pAst, iss := e.Parse("panic()")
	if iss.Err() != nil {
		t.Fatalf("e.Parse('panic()') failed: %v", iss.Err())
	}
	prgm, err := e.Program(pAst)
	if err != nil {
		t.Fatalf("e.Program(Ast) failed: %v", err)
	}
	_, _, err = prgm.Eval(map[string]any{})
	// Guard against a nil error so a missing recovery fails instead of panicking.
	if err == nil || err.Error() != want {
		t.Errorf("got '%v', wanted '%s'", err, want)
	}
	// Test the factory-based evaluation.
	prgm, err = e.Program(pAst, EvalOptions(OptTrackState))
	if err != nil {
		t.Fatalf("e.Program(Ast, EvalOptions(OptTrackState)) failed: %v", err)
	}
	_, _, err = prgm.Eval(map[string]any{})
	if err == nil || err.Error() != want {
		t.Errorf("got '%v', wanted '%s'", err, want)
	}
}
// TestResidualAst checks that partially evaluating
// `x < 10 && (y == 0 || 'hello' != 'goodbye')` with every declared variable
// unknown prunes the always-true disjunction. The residual AST must be `x < 10`.
func TestResidualAst(t *testing.T) {
	env := testEnv(t,
		Variable("x", IntType),
		Variable("y", IntType),
	)
	unkVars := env.UnknownVars()
	ast, iss := env.Parse(`x < 10 && (y == 0 || 'hello' != 'goodbye')`)
	if iss.Err() != nil {
		t.Fatalf("env.Parse() failed: %v", iss.Err())
	}
	prg, err := env.Program(ast,
		EvalOptions(OptTrackState, OptPartialEval),
	)
	if err != nil {
		t.Fatalf("env.Program() failed: %v", err)
	}
	out, det, err := prg.Eval(unkVars)
	// Check the error first so an evaluation failure is reported as such,
	// not as a confusing "expected unknown" mismatch.
	if err != nil {
		t.Fatal(err)
	}
	if !types.IsUnknown(out) {
		t.Fatalf("got %v, expected unknown", out)
	}
	residual, err := env.ResidualAst(ast, det)
	if err != nil {
		t.Fatal(err)
	}
	expr, err := AstToString(residual)
	if err != nil {
		t.Fatal(err)
	}
	if expr != "x < 10" {
		t.Errorf("got expr: %s, wanted x < 10", expr)
	}
}
// TestResidualAstComplex checks partial evaluation where only
// request.auth.claims.email is marked unknown. The known conjuncts must fold
// away, leaving just the email comparison in the residual AST.
func TestResidualAstComplex(t *testing.T) {
	env := testEnv(t,
		Variable("resource.name", StringType),
		Variable("request.time", TimestampType),
		Variable("request.auth.claims", MapType(StringType, StringType)),
	)
	unkVars, err := PartialVars(
		map[string]any{
			"resource.name": "bucket/my-bucket/objects/private",
			"request.auth.claims": map[string]string{
				"email_verified": "true",
			},
		},
		AttributePattern("request.auth.claims").QualString("email"),
	)
	if err != nil {
		t.Fatalf("PartialVars() failed: %v", err)
	}
	ast, iss := env.Compile(
		`resource.name.startsWith("bucket/my-bucket") &&
bool(request.auth.claims.email_verified) == true &&
request.auth.claims.email == "wiley@acme.co"`)
	if iss.Err() != nil {
		t.Fatalf("env.Compile() failed: %v", iss.Err())
	}
	prg, err := env.Program(ast,
		EvalOptions(OptTrackState, OptPartialEval),
	)
	if err != nil {
		t.Fatalf("env.Program() failed: %v", err)
	}
	out, det, err := prg.Eval(unkVars)
	// Check the error before inspecting the result.
	if err != nil {
		t.Fatal(err)
	}
	if !types.IsUnknown(out) {
		t.Fatalf("got %v, expected unknown", out)
	}
	residual, err := env.ResidualAst(ast, det)
	if err != nil {
		t.Fatal(err)
	}
	expr, err := AstToString(residual)
	if err != nil {
		t.Fatal(err)
	}
	if expr != `request.auth.claims.email == "wiley@acme.co"` {
		t.Errorf("got expr: %s, wanted request.auth.claims.email == \"wiley@acme.co\"", expr)
	}
}
// TestResidualAstMacros checks residual AST generation for expressions that
// contain macros. Macro call tracking is enabled in each case's environment.
// Each case supplies known inputs, unknown attribute patterns, and the exact
// residual expression expected after partial evaluation.
func TestResidualAstMacros(t *testing.T) {
	tests := []struct {
		env      *Env
		in       map[string]any
		unks     []*interpreter.AttributePattern
		expr     string
		residual string
	}{
		{
			env: testEnv(t,
				Variable("x", ListType(IntType)),
				Variable("y", IntType),
				EnableMacroCallTracking()),
			in:       map[string]any{"y": 11},
			unks:     []*interpreter.AttributePattern{AttributePattern("x")},
			expr:     `x.exists(i, i < 10) && [11, 12, 13].all(i, i in [y, 12, 13])`,
			residual: `x.exists(i, i < 10)`,
		},
		{
			env: testEnv(t,
				Variable("bar", MapType(StringType, DynType)),
				Variable("foo", MapType(StringType, DynType)),
				EnableMacroCallTracking()),
			in: map[string]any{"foo": map[string]any{"a": "b"}},
			unks: []*interpreter.AttributePattern{
				AttributePattern("bar").QualString("baz").Wildcard(),
			},
			expr:     `foo.exists(t, t == bar.baz.x)`,
			residual: `{"a": "b"}.exists(t, t == bar.baz.x)`,
		},
	}
	for i, tst := range tests {
		tc := tst
		t.Run(fmt.Sprintf("[%d]", i), func(t *testing.T) {
			env := tc.env
			unkVars, err := PartialVars(tc.in, tc.unks...)
			if err != nil {
				t.Fatalf("PartialVars() failed: %v", err)
			}
			ast, iss := env.Compile(tc.expr)
			if iss.Err() != nil {
				t.Fatalf("env.Compile() failed: %v", iss.Err())
			}
			prg, err := env.Program(ast, EvalOptions(OptTrackState, OptPartialEval))
			if err != nil {
				t.Fatalf("env.Program() failed: %v", err)
			}
			out, det, err := prg.Eval(unkVars)
			// Check the error before inspecting the result.
			if err != nil {
				t.Fatalf("prg.Eval() failed: %v", err)
			}
			if !types.IsUnknown(out) {
				t.Fatalf("got %v, expected unknown", out)
			}
			residual, err := env.ResidualAst(ast, det)
			if err != nil {
				t.Fatalf("env.ResidualAst() failed: %v", err)
			}
			expr, err := AstToString(residual)
			if err != nil {
				t.Fatalf("AstToString() failed: %v", err)
			}
			if expr != tc.residual {
				t.Errorf("got expr: %s, wanted %s", expr, tc.residual)
			}
		})
	}
}
// TestResidualAstNil checks that ResidualAst called with a nil AST and nil
// details returns an error mentioning "unsupported expr".
func TestResidualAstNil(t *testing.T) {
	residual, err := testEnv(t).ResidualAst(nil, nil)
	if err != nil && strings.Contains(err.Error(), "unsupported expr") {
		return
	}
	t.Errorf("env.ResidualAst() got (%v, %v) wanted unsupported expr error", residual, err)
}
// BenchmarkEvalOptions compares Eval throughput of the same expression planned
// with each individual EvalOption. Each option runs as its own sub-benchmark.
func BenchmarkEvalOptions(b *testing.B) {
	env := testEnv(b,
		Variable("ai", IntType),
		Variable("ar", MapType(StringType, StringType)),
	)
	ast, iss := env.Compile("ai == 20 || ar['foo'] == 'bar'")
	if iss.Err() != nil {
		b.Fatalf("env.Compile() failed: %v", iss.Err())
	}
	vars := map[string]any{
		"ai": 2,
		"ar": map[string]string{
			"foo": "bar",
		},
	}
	opts := map[string]EvalOption{
		"track-state":     OptTrackState,
		"exhaustive-eval": OptExhaustiveEval,
		"optimize":        OptOptimize,
	}
	for k, opt := range opts {
		b.Run(k, func(bb *testing.B) {
			// Timer control, alloc reporting and failures must target the
			// sub-benchmark (bb), not the parent benchmark (b).
			prg, err := env.Program(ast, EvalOptions(opt))
			if err != nil {
				bb.Fatalf("env.Program() failed: %v", err)
			}
			bb.ResetTimer()
			bb.ReportAllocs()
			for i := 0; i < bb.N; i++ {
				if _, _, err := prg.Eval(vars); err != nil {
					bb.Fatal(err)
				}
			}
		})
	}
}
// TestEnvExtension checks the isolation guarantees of Env.Extend.
//
// Extending with new types and a custom type adapter must yield a distinct Env.
// Its adapter and provider must be separate instances from the parent's.
// A further Extend that adds only OptionalTypes() must keep the same adapter
// reference while still producing its own type provider.
func TestEnvExtension(t *testing.T) {
	parent := testEnv(t,
		Container("google.api.expr.v1alpha1"),
		Types(&exprpb.Expr{}),
		Variable("expr", ObjectType("google.api.expr.v1alpha1.Expr")),
		Variable("m", MapType(TypeParamType("K"), TypeParamType("V"))),
		OptionalTypes(),
	)
	// OptionalTypes() is repeated on purpose. Presumably this checks that
	// re-applying the option during Extend is harmless.
	child, err := parent.Extend(
		CustomTypeAdapter(types.DefaultTypeAdapter),
		Types(&proto3pb.TestAllTypes{}),
		OptionalTypes(),
		OptionalTypes(),
		OptionalTypes(),
	)
	if err != nil {
		t.Fatalf("env.Extend() failed: %v", err)
	}
	childChecks := []struct {
		shared bool
		msg    string
	}{
		{parent == child, "got object equality, wanted separate objects"},
		{parent.TypeAdapter() == child.TypeAdapter(), "got the same type adapter, wanted isolated instances."},
		{parent.TypeProvider() == child.TypeProvider(), "got the same type provider, wanted isolated instances."},
	}
	for _, c := range childChecks {
		if c.shared {
			t.Error(c.msg)
		}
	}
	grandchild, err := child.Extend(OptionalTypes())
	if err != nil {
		t.Fatalf("env.Extend() failed: %v", err)
	}
	if child.TypeAdapter() != grandchild.TypeAdapter() {
		t.Error("got different type adapters, wanted immutable adapter reference")
	}
	if child.TypeProvider() == grandchild.TypeProvider() {
		t.Error("got the same type provider, wanted isolated instances.")
	}
}
// TestEnvExtensionIsolation checks that sibling environments extended from a
// common base do not see each other's variables or registered proto types.
func TestEnvExtensionIsolation(t *testing.T) {
	base := testEnv(t,
		Container("google.expr"),
		Variable("age", IntType),
		Variable("gender", StringType),
		Variable("country", StringType),
	)
	withProto2, err := base.Extend(
		Types(&proto2pb.TestAllTypes{}),
		Variable("name", StringType),
	)
	if err != nil {
		t.Fatal(err)
	}
	withProto3, err := base.Extend(
		Types(&proto3pb.TestAllTypes{}),
		Variable("group", StringType),
	)
	if err != nil {
		t.Fatal(err)
	}
	// Each entry compiles expr in env. An empty leakMsg means compilation must
	// succeed. Otherwise compilation must fail, and leakMsg is reported if it
	// unexpectedly succeeds.
	cases := []struct {
		env     *Env
		expr    string
		leakMsg string
	}{
		{withProto3, `size(group) > 10
&& !has(proto3.test.TestAllTypes{}.single_int32)`, ""},
		{withProto3, `size(name) > 10`, "env2 contains 'name', but should not"},
		{withProto3, `!has(proto2.test.TestAllTypes{}.single_int32)`, "env2 contains 'proto2.test.TestAllTypes', but should not"},
		{withProto2, `size(name) > 10
&& !has(proto2.test.TestAllTypes{}.single_int32)`, ""},
		{withProto2, "size(group) > 10", "env1 contains 'group', but should not"},
		{withProto2, `!has(proto3.test.TestAllTypes{}.single_int32)`, "env1 contains 'proto3.test.TestAllTypes', but should not"},
	}
	for _, c := range cases {
		_, issues := c.env.Compile(c.expr)
		switch {
		case c.leakMsg == "" && issues.Err() != nil:
			t.Fatal(issues.Err())
		case c.leakMsg != "" && issues.Err() == nil:
			t.Fatal(c.leakMsg)
		}
	}
}
// TestVariadicLogicalOperators compiles and evaluates chained || and &&
// expressions in an environment configured with variadicLogicalOperatorASTs().
// The whole expression must evaluate to false.
func TestVariadicLogicalOperators(t *testing.T) {
	env := testEnv(t, variadicLogicalOperatorASTs())
	const src = `(false || false || false || false || true) &&
(true && true && true && true && false)`
	checked, iss := env.Compile(src)
	if iss.Err() != nil {
		t.Fatalf("Compile() failed: %v", iss.Err())
	}
	program, err := env.Program(checked)
	if err != nil {
		t.Fatalf("Program(ast) failed: %v", err)
	}
	result, _, err := program.Eval(NoVars())
	if err != nil {
		t.Errorf("Eval() got error %v, wanted false", err)
	}
	if result != types.False {
		t.Errorf("Eval() got %v, wanted false", result)
	}
}
// TestParseError checks that Parse reports issues for the malformed
// expression "invalid & logical_and".
func TestParseError(t *testing.T) {
	if _, issues := testEnv(t).Parse("invalid & logical_and"); issues.Err() == nil {
		t.Fatal("e.Parse('invalid & logical_and') did not error")
	}
}
// TestParseWithMacroTracking checks that, with EnableMacroCallTracking, the
// parsed expression's source info records exactly the 'has' and 'exists'
// macro calls.
func TestParseWithMacroTracking(t *testing.T) {
	env := testEnv(t, EnableMacroCallTracking())
	parsed, iss := env.Parse("has(a.b) && a.b.exists(c, c < 10)")
	if iss.Err() != nil {
		t.Fatalf("e.Parse() failed: %v", iss.Err())
	}
	pe, err := AstToParsedExpr(parsed)
	if err != nil {
		t.Fatalf("AstToParsedExpr(%v) failed: %v", parsed, err)
	}
	calls := pe.GetSourceInfo().GetMacroCalls()
	if got := len(calls); got != 2 {
		t.Errorf("got %d macro calls, wanted 2", got)
	}
	seen := map[string]bool{"has": false, "exists": false}
	for _, call := range calls {
		fn := call.GetCallExpr().GetFunction()
		if _, known := seen[fn]; !known {
			t.Errorf("Unexpected macro call: %v", call)
		}
		seen[fn] = true
	}
	want := map[string]bool{"has": true, "exists": true}
	if !reflect.DeepEqual(seen, want) {
		t.Errorf("Tracked calls %v, but wanted %v", seen, want)
	}
}
// TestParseAndCheckConcurrently compiles distinct expressions against a shared
// Env from several goroutines at once. Every compilation must succeed.
func TestParseAndCheckConcurrently(t *testing.T) {
	env := testEnv(t,
		Container("google.api.expr.v1alpha1"),
		Types(&exprpb.Expr{}),
		Variable("expr", ObjectType("google.api.expr.v1alpha1.Expr")),
	)
	// parseAndCheck runs on worker goroutines, so it reports with t.Errorf.
	// t.Fatalf (FailNow) must only be called from the goroutine running the test.
	parseAndCheck := func(expr string) {
		_, iss := env.Compile(expr)
		if iss.Err() != nil {
			t.Errorf("e.Compile('%s') failed: %v", expr, iss.Err())
		}
	}
	const concurrency = 10
	var wg sync.WaitGroup
	wg.Add(concurrency)
	for i := 0; i < concurrency; i++ {
		go func(i int) {
			defer wg.Done()
			parseAndCheck(fmt.Sprintf("expr.id + %d", i))
		}(i)
	}
	wg.Wait()
}
// TestCustomInterpreterDecorator checks that a CustomDecorator observes the
// planned instructions and can replace arithmetic calls on constant operands
// with a folded constant.
//
// The last instruction the decorator sees is expected to be the root `==`
// call. Its left argument must be the attribute `foo` and its right argument
// the folded constant 1.
func TestCustomInterpreterDecorator(t *testing.T) {
	var lastInstruction interpreter.Interpretable
	optimizeArith := func(i interpreter.Interpretable) (interpreter.Interpretable, error) {
		lastInstruction = i
		// Only optimize the instruction if it is a call.
		call, ok := i.(interpreter.InterpretableCall)
		if !ok {
			return i, nil
		}
		// Only optimize the math functions when they have constant arguments.
		switch call.Function() {
		case operators.Add,
			operators.Subtract,
			operators.Multiply,
			operators.Divide:
			// These are all binary operators so they should have two arguments.
			args := call.Args()
			if len(args) != 2 {
				return i, nil
			}
			_, lhsIsConst := args[0].(interpreter.InterpretableConst)
			_, rhsIsConst := args[1].(interpreter.InterpretableConst)
			// When the values are constant then the call can be evaluated with
			// an empty activation and the value returns as a constant.
			if !lhsIsConst || !rhsIsConst {
				return i, nil
			}
			val := call.Eval(interpreter.EmptyActivation())
			if types.IsError(val) {
				return nil, val.(*types.Err)
			}
			return interpreter.NewConstValue(call.ID(), val), nil
		default:
			return i, nil
		}
	}
	env := testEnv(t, Variable("foo", IntType))
	ast, iss := env.Compile(`foo == -1 + 2 * 3 / 3`)
	if iss.Err() != nil {
		t.Fatalf("env.Compile() failed: %v", iss.Err())
	}
	_, err := env.Program(ast,
		EvalOptions(OptPartialEval),
		CustomDecorator(optimizeArith))
	if err != nil {
		t.Fatalf("env.Program() failed: %v", err)
	}
	// Each check below is fatal because later lines use the asserted value;
	// continuing after a failed assertion would panic on a nil interface.
	call, ok := lastInstruction.(interpreter.InterpretableCall)
	if !ok {
		t.Fatalf("got %v, expected call", lastInstruction)
	}
	args := call.Args()
	if len(args) != 2 {
		t.Fatalf("got %d call arguments, wanted 2", len(args))
	}
	lhs := args[0]
	lastAttr, ok := lhs.(interpreter.InterpretableAttribute)
	if !ok {
		t.Fatalf("got %v, wanted attribute", lhs)
	}
	absAttr, ok := lastAttr.Attr().(interpreter.NamespacedAttribute)
	if !ok {
		t.Fatalf("got %v, wanted namespaced attribute", lastAttr.Attr())
	}
	varNames := absAttr.CandidateVariableNames()
	if len(varNames) != 1 || varNames[0] != "foo" {
		t.Errorf("got variables %v, wanted foo", varNames)
	}
	rhs := args[1]
	lastConst, ok := rhs.(interpreter.InterpretableConst)
	if !ok {
		t.Fatalf("got %v, wanted constant", rhs)
	}
	// This is the last number produced by the optimization: -1 + 2 * 3 / 3 == 1.
	// Require an explicit True so that an error value from Equal also fails.
	if lastConst.Value().Equal(types.IntOne) != types.True {
		t.Errorf("got %v as the last observed constant, wanted 1", lastConst)
	}
}
// TestEstimateCostAndRuntimeCost sanity checks that the cost systems are usable from the program API.
//
// For each case it checks two things. First, the static estimate from
// Env.EstimateCost must match the expected [Min, Max] interval. Second, the
// runtime cost tracked during Eval must fall inside that estimate.
func TestEstimateCostAndRuntimeCost(t *testing.T) {
	intList := ListType(IntType)
	zeroCost := checker.CostEstimate{}
	cases := []struct {
		name  string
		expr  string
		decls []EnvOption
		hints map[string]uint64
		want  checker.CostEstimate
		in    any
	}{
		{
			name: "const",
			expr: `"Hello World!"`,
			want: zeroCost,
			in:   map[string]any{},
		},
		{
			name:  "identity",
			expr:  `input`,
			decls: []EnvOption{Variable("input", intList)},
			want:  checker.CostEstimate{Min: 1, Max: 1},
			in:    map[string]any{"input": []int{1, 2}},
		},
		{
			name: "str concat",
			expr: `"abcdefg".contains(str1 + str2)`,
			decls: []EnvOption{
				Variable("str1", StringType),
				Variable("str2", StringType),
			},
			hints: map[string]uint64{"str1": 10, "str2": 10},
			want:  checker.CostEstimate{Min: 2, Max: 6},
			in:    map[string]any{"str1": "val1111111", "str2": "val2222222"},
		},
	}
	for _, tst := range cases {
		tc := tst
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			if tc.hints == nil {
				tc.hints = map[string]uint64{}
			}
			envOpts := []EnvOption{
				CostEstimatorOptions(
					checker.OverloadCostEstimate(overloads.TimestampToYear, estimateTimestampToYear),
				),
			}
			envOpts = append(envOpts, tc.decls...)
			env := testEnv(t, envOpts...)
			ast, iss := env.Compile(tc.expr)
			if iss.Err() != nil {
				t.Fatalf("env.Compile(%v) failed: %v", tc.expr, iss.Err())
			}
			est, err := env.EstimateCost(ast, testCostEstimator{hints: tc.hints})
			if err != nil {
				t.Fatalf("Env.EstimateCost(ast *Ast, estimator checker.CostEstimator) failed to estimate cost: %s\n", err)
			}
			if est.Min != tc.want.Min || est.Max != tc.want.Max {
				t.Fatalf("Env.EstimateCost(ast *Ast, estimator checker.CostEstimator) failed to return the right cost interval. Got [%v, %v], wanted [%v, %v]",
					est.Min, est.Max, tc.want.Min, tc.want.Max)
			}
			checkedAst, iss := env.Check(ast)
			if iss.Err() != nil {
				t.Fatalf(`Env.Check(ast *Ast) failed to check expression: %v`, iss.Err())
			}
			// Evaluate expression.
			program, err := env.Program(checkedAst,
				CostTracking(testRuntimeCostEstimator{}),
				CostTrackerOptions(
					interpreter.OverloadCostTracker(overloads.TimestampToYear, trackTimestampToYear),
				),
			)
			if err != nil {
				t.Fatalf(`Env.Program(ast *Ast, opts ...ProgramOption) failed to construct program: %v`, err)
			}
			_, details, err := program.Eval(tc.in)
			if err != nil {
				t.Fatalf(`Program.Eval(vars any) failed to evaluate expression: %v`, err)
			}
			actualCost := details.ActualCost()
			if actualCost == nil {
				// Fatal rather than Error: the range check below dereferences actualCost.
				t.Fatalf(`EvalDetails.ActualCost() got nil for "%s" cost, wanted a tracked cost`, tc.expr)
			}
			if est.Min > *actualCost || est.Max < *actualCost {
				t.Errorf("EvalDetails.ActualCost() failed to return a runtime cost %d in the range of estimate cost [%d, %d]", *actualCost,
					est.Min, est.Max)
			}
		})
	}
}
// TestCostLimit checks that CostLimit is enforced at evaluation time.
//
// With a limit of 10, `val1 > val2` must evaluate and its tracked runtime cost
// must fall within the static estimate. With a limit of 0, evaluation must
// fail with an error containing "actual cost limit exceeded".
func TestCostLimit(t *testing.T) {
	cases := []struct {
		name      string
		expr      string
		decls     []EnvOption
		costLimit uint64
		in        any
		err       error
	}{
		{
			name: "greater",
			expr: `val1 > val2`,
			decls: []EnvOption{
				Variable("val1", IntType),
				Variable("val2", IntType),
			},
			in:        map[string]any{"val1": 1, "val2": 2},
			costLimit: 10,
		},
		{
			name: "greater - error",
			expr: `val1 > val2`,
			decls: []EnvOption{
				Variable("val1", IntType),
				Variable("val2", IntType),
			},
			in:        map[string]any{"val1": 1, "val2": 2},
			costLimit: 0,
			err:       errors.New("actual cost limit exceeded"),
		},
	}
	for _, tst := range cases {
		tc := tst
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			envOpts := []EnvOption{
				CostEstimatorOptions(
					checker.OverloadCostEstimate(overloads.TimestampToYear, estimateTimestampToYear),
				),
			}
			envOpts = append(envOpts, tc.decls...)
			env := testEnv(t, envOpts...)
			ast, iss := env.Compile(tc.expr)
			if iss.Err() != nil {
				t.Fatalf("env.Compile(%v) failed: %v", tc.expr, iss.Err())
			}
			est, err := env.EstimateCost(ast, testCostEstimator{hints: map[string]uint64{}})
			if err != nil {
				t.Fatalf("Env.EstimateCost(ast *Ast, estimator checker.CostEstimator) failed to estimate cost: %s\n", err)
			}
			checkedAst, iss := env.Check(ast)
			if iss.Err() != nil {
				t.Fatalf(`Env.Check(ast *Ast) failed to check expression: %v`, iss.Err())
			}
			// Evaluate expression.
			program, err := env.Program(checkedAst,
				CostTracking(testRuntimeCostEstimator{}),
				CostTrackerOptions(
					interpreter.OverloadCostTracker(overloads.TimestampToYear, trackTimestampToYear),
				),
				CostLimit(tc.costLimit),
			)
			if err != nil {
				t.Fatalf(`Env.Program(ast *Ast, opts ...ProgramOption) failed to construct program: %v`, err)
			}
			_, details, err := program.Eval(tc.in)
			if tc.err != nil {
				// An error case must actually fail. Otherwise it would fall through
				// to the estimate check below and pass vacuously.
				if err == nil {
					t.Fatalf("program.Eval() got nil error, wanted error containing %v", tc.err)
				}
				if !strings.Contains(err.Error(), tc.err.Error()) {
					t.Fatalf("program.Eval() got error %v, wanted error containing %v", err, tc.err)
				}
			} else if err != nil {
				t.Fatalf(`Program.Eval(vars any) failed to evaluate expression: %v`, err)
			}
			actualCost := details.ActualCost()
			if actualCost == nil {
				t.Fatalf(`EvalDetails.ActualCost() got nil for "%s" cost, wanted a tracked cost`, tc.expr)
			}
			// Only successful evaluations are compared against the static estimate.
			if err == nil && (est.Min > *actualCost || est.Max < *actualCost) {
				t.Errorf("EvalDetails.ActualCost() failed to return a runtime cost %d in the range of estimate cost [%d, %d]", *actualCost,
					est.Min, est.Max)
			}
		})
	}
}
// TestPartialVars evaluates `x == string(y)` against partial activations built
// in two ways. The first uses explicitly supplied unknown attribute patterns
// (PartialVars). The second uses patterns inferred from the environment's
// declared variables (env.PartialVars). When partialOut is set, it overrides
// the expected result for the inferred-pattern evaluation.
func TestPartialVars(t *testing.T) {
	env := testEnv(t,
		Variable("x", StringType),
		Variable("y", IntType),
	)
	checked, iss := env.Compile("x == string(y)")
	if iss.Err() != nil {
		t.Fatalf("env.Compile() failed: %v", iss.Err())
	}
	prg, err := env.Program(checked, EvalOptions(OptPartialEval))
	if err != nil {
		t.Fatalf("env.Program() failed: %v", err)
	}
	tests := []struct {
		in         map[string]any
		unk        []*interpreter.AttributePattern
		out        ref.Val
		partialOut ref.Val
	}{
		{
			in: map[string]any{},
			unk: []*interpreter.AttributePattern{
				interpreter.NewAttributePattern("x"),
				interpreter.NewAttributePattern("y"),
			},
			out: types.NewUnknown(1, types.NewAttributeTrail("x")),
		},
		{
			in: map[string]any{"x": "10"},
			unk: []*interpreter.AttributePattern{
				interpreter.NewAttributePattern("y"),
			},
			out: types.NewUnknown(4, types.NewAttributeTrail("y")),
		},
		{
			in: map[string]any{"y": 10},
			unk: []*interpreter.AttributePattern{
				interpreter.NewAttributePattern("x"),
			},
			out: types.NewUnknown(1, types.NewAttributeTrail("x")),
		},
		{
			in:  map[string]any{"x": "10", "y": 10},
			unk: []*interpreter.AttributePattern{},
			out: types.True,
		},
		{
			in:  map[string]any{"x": "10", "y": 9},
			unk: []*interpreter.AttributePattern{},
			out: types.False,
		},
		{
			in:         map[string]any{"y": 10},
			unk:        []*interpreter.AttributePattern{},
			out:        types.NewErr("no such attribute: x"),
			partialOut: types.NewUnknown(1, types.NewAttributeTrail("x")),
		},
		{
			in:         map[string]any{"x": "10"},
			unk:        []*interpreter.AttributePattern{},
			out:        types.NewErr("no such attribute: y"),
			partialOut: types.NewUnknown(4, types.NewAttributeTrail("y")),
		},
		{
			in:         map[string]any{},
			unk:        []*interpreter.AttributePattern{},
			out:        types.NewErr("no such attribute: x"),
			partialOut: types.NewUnknown(1, types.NewAttributeTrail("x")),
		},
	}
	for idx, tst := range tests {
		tc := tst
		t.Run(fmt.Sprintf("[%d]", idx), func(t *testing.T) {
			// Unknown patterns supplied explicitly by the test case.
			explicit, err := PartialVars(tc.in, tc.unk...)
			if err != nil {
				t.Fatalf("PartialVars() failed: %v", err)
			}
			got, _, err := prg.Eval(explicit)
			switch {
			case err != nil:
				if types.IsError(got) && !got.(*types.Err).Is(err) {
					t.Errorf("Eval() got %v, wanted error %v", err, got)
				}
			case types.IsUnknown(got):
				if !reflect.DeepEqual(got, tc.out) {
					t.Errorf("Eval() got unknown %v, wanted %v", got, tc.out)
				}
			case got.Equal(tc.out) != types.True:
				t.Errorf("Eval() got %v, wanted %v", got, tc.out)
			}

			// Unknown patterns inferred from the environment's declarations.
			inferred, err := env.PartialVars(tc.in)
			if err != nil {
				t.Fatalf("env.PartialVars() failed: %v", err)
			}
			gotInferred, _, err := prg.Eval(inferred)
			if err != nil {
				t.Fatalf("prg.Eval() with inferred unknowns failed: %v", err)
			}
			want := tc.partialOut
			if want == nil {
				want = tc.out
			}
			switch {
			case types.IsUnknown(gotInferred):
				if !reflect.DeepEqual(gotInferred, want) {
					t.Errorf("Eval() got unknown %v, wanted %v", gotInferred, want)
				}
			case gotInferred.Equal(want) != types.True:
				t.Errorf("Eval() got %v, wanted %v", gotInferred, want)
			}
		})
	}
}
// TestResidualAstAttributeQualifiers checks that residual AST generation
// inlines the resolved values of attribute accesses qualified in several ways.
// The forms covered are field selection, string index, dynamic index, list
// index, and conditional-operand selection. Only the unknown `u` remains in
// each comparison.
func TestResidualAstAttributeQualifiers(t *testing.T) {
	env := testEnv(t,
		Variable("x", MapType(StringType, DynType)),
		Variable("y", ListType(IntType)),
		Variable("u", IntType),
	)
	ast, iss := env.Parse(`x.abc == u && x["abc"] == u && x[x.string] == u && y[0] == u && y[x.zero] == u && (true ? x : y).abc == u && (false ? y : x).abc == u`)
	if iss.Err() != nil {
		t.Fatalf("env.Parse() failed: %v", iss.Err())
	}
	prg, err := env.Program(ast,
		EvalOptions(OptTrackState, OptPartialEval),
	)
	if err != nil {
		t.Fatalf("env.Program() failed: %v", err)
	}
	vars, err := PartialVars(map[string]any{
		"x": map[string]any{
			"zero":   0,
			"abc":    123,
			"string": "abc",
		},
		"y": []int{123},
	}, AttributePattern("u"))
	if err != nil {
		t.Fatalf("PartialVars() failed: %v", err)
	}
	out, det, err := prg.ContextEval(context.TODO(), vars)
	// Check the error before inspecting the result.
	if err != nil {
		t.Fatal(err)
	}
	if !types.IsUnknown(out) {
		t.Fatalf("got %v, expected unknown", out)
	}
	residual, err := env.ResidualAst(ast, det)
	if err != nil {
		t.Fatal(err)
	}
	expr, err := AstToString(residual)
	if err != nil {
		t.Fatal(err)
	}
	const want = "123 == u && 123 == u && 123 == u && 123 == u && 123 == u && 123 == u && 123 == u"
	if expr != want {
		t.Errorf("got expr: %s, wanted %s", expr, want)
	}
}
// TestPartialVarsEnv checks that env.PartialVars with every declared variable
// bound evaluates `x == y` to a concrete true result.
func TestPartialVarsEnv(t *testing.T) {
	env := testEnv(t,
		Variable("x", IntType),
		Variable("y", IntType),
	)
	// Use env to make sure internals are all initialized.
	checked, issues := env.Compile("x == y")
	if issues.Err() != nil {
		t.Fatalf("env.Compile() failed: %v", issues.Err())
	}
	program, err := env.Program(checked, EvalOptions(OptPartialEval))
	if err != nil {
		t.Fatalf("env.Program() failed: %v", err)
	}
	activation, err := env.PartialVars(map[string]any{"x": 1, "y": 1})
	if err != nil {
		t.Fatalf("env.PartialVars failed: %v", err)
	}
	got, _, err := program.Eval(activation)
	if err != nil {
		t.Fatalf("Eval failed: %v", err)
	}
	if got != types.True {
		t.Fatalf("want: %v, got: %v", types.True, got)
	}
}
// TestPartialVarsExtendedEnv checks that env.PartialVars on an extended
// environment infers unknowns for every declared variable, including those
// inherited from the parent. With y and z bound, `x == y && y == z` must
// yield Unknown(x) at expression id 1.
func TestPartialVarsExtendedEnv(t *testing.T) {
	env := testEnv(t,
		Variable("x", IntType),
		Variable("y", IntType),
	)
	// Compile against the parent first so Extend operates on an environment
	// that has already been used. The compile itself must succeed.
	if _, iss := env.Compile("x == y"); iss.Err() != nil {
		t.Fatalf("env.Compile() failed: %v", iss.Err())
	}
	// Now test that a sub environment is correctly copied.
	env2, err := env.Extend(Variable("z", IntType))
	if err != nil {
		t.Fatalf("env.Extend failed: %v", err)
	}
	ast, iss := env2.Compile("x == y && y == z")
	if iss.Err() != nil {
		t.Fatalf("env.Compile() failed: %v", iss.Err())
	}
	prg, err := env2.Program(ast, EvalOptions(OptPartialEval))
	if err != nil {
		t.Fatalf("env.Program() failed: %v", err)
	}
	act, err := env2.PartialVars(map[string]any{"z": 1, "y": 1})
	if err != nil {
		t.Fatalf("env.PartialVars failed: %v", err)
	}
	val, _, err := prg.Eval(act)
	if err != nil {
		t.Fatalf("Eval failed: %v", err)
	}
	if !types.IsUnknown(val) {
		t.Fatalf("Wanted unknown, got %v", val)
	}
	if !reflect.DeepEqual(val, types.NewUnknown(1, types.NewAttributeTrail("x"))) {
		t.Fatalf("Wanted Unknown(x (1)), got: %v", val)
	}
}
// TestResidualAstModified checks that ResidualAst builds a fresh AST for each
// evaluation and leaves the original parsed AST unmodified. The residual
// reflects each run's bound value of x, and the parsed AST still prints as
// `x == y`.
func TestResidualAstModified(t *testing.T) {
	env := testEnv(t,
		Variable("x", MapType(StringType, IntType)),
		Variable("y", IntType),
	)
	ast, iss := env.Parse("x == y")
	if iss.Err() != nil {
		t.Fatalf("env.Parse() failed: %v", iss.Err())
	}
	prg, err := env.Program(ast,
		EvalOptions(OptTrackState, OptPartialEval),
	)
	if err != nil {
		t.Fatalf("env.Program() failed: %v", err)
	}
	for _, x := range []int{123, 456} {
		vars, err := PartialVars(map[string]any{
			"x": x,
		}, AttributePattern("y"))
		if err != nil {
			t.Fatalf("PartialVars() failed: %v", err)
		}
		out, det, err := prg.Eval(vars)
		// Check the error before inspecting the result.
		if err != nil {
			t.Fatal(err)
		}
		if !types.IsUnknown(out) {
			t.Fatalf("got %v, expected unknown", out)
		}
		residual, err := env.ResidualAst(ast, det)
		if err != nil {
			t.Fatal(err)
		}
		orig, err := AstToString(ast)
		if err != nil {
			t.Fatal(err)
		}
		if orig != "x == y" {
			t.Errorf("parsed ast: got expr: %s, wanted x == y", orig)
		}
		expr, err := AstToString(residual)
		if err != nil {
			t.Fatal(err)
		}
		want := fmt.Sprintf("%d == y", x)
		if expr != want {
			t.Errorf("residual ast: got expr: %s, wanted %s", expr, want)
		}
	}
}
// TestContextProto checks that DeclareContextProto declares the top-level
// fields of proto3 TestAllTypes as CEL variables. It also checks that
// ContextProtoVars binds those variables from a message instance so that the
// conjunction of field comparisons evaluates to true.
func TestContextProto(t *testing.T) {
	desc := (&proto3pb.TestAllTypes{}).ProtoReflect().Descriptor()
	env := testEnv(t, DeclareContextProto(desc))
	const expression = `
single_int64 == 1
&& single_double == 1.0
&& single_bool == true
&& single_string == ''
&& single_nested_message == google.expr.proto3.test.TestAllTypes.NestedMessage{}
&& standalone_enum == google.expr.proto3.test.TestAllTypes.NestedEnum.FOO
&& single_duration == duration('5s')
&& single_timestamp == timestamp(63154820)
&& single_any == null
&& single_uint32_wrapper == null
&& single_uint64_wrapper == 0u
&& repeated_int32 == [1,2]
&& map_string_string == {'': ''}
&& map_int64_nested_type == {0 : google.expr.proto3.test.NestedTestAllTypes{}}`
	checked, issues := env.Compile(expression)
	if issues.Err() != nil {
		t.Fatalf("env.Compile(%s) failed: %s", expression, issues.Err())
	}
	program, err := env.Program(checked)
	if err != nil {
		t.Fatalf("env.Program() failed: %v", err)
	}
	msg := &proto3pb.TestAllTypes{
		SingleInt64:  1,
		SingleDouble: 1.0,
		SingleBool:   true,
		NestedType: &proto3pb.TestAllTypes_SingleNestedMessage{
			SingleNestedMessage: &proto3pb.TestAllTypes_NestedMessage{},
		},
		StandaloneEnum:      proto3pb.TestAllTypes_FOO,
		SingleDuration:      &durationpb.Duration{Seconds: 5},
		SingleTimestamp:     &timestamppb.Timestamp{Seconds: 63154820},
		SingleUint64Wrapper: wrapperspb.UInt64(0),
		RepeatedInt32:       []int32{1, 2},
		MapStringString:     map[string]string{"": ""},
		MapInt64NestedType:  map[int64]*proto3pb.NestedTestAllTypes{0: {}},
	}
	activation, err := ContextProtoVars(msg)
	if err != nil {
		t.Fatalf("ContextProtoVars(%v) failed: %v", msg, err)
	}
	result, _, err := program.Eval(activation)
	if err != nil {
		t.Fatalf("prg.Eval() failed: %v", err)
	}
	if result.Equal(types.True) != types.True {
		t.Errorf("prg.Eval() got %v, wanted true", result)
	}
}
// TestContextProtoJSONFieldNames mirrors the context-proto check with
// JSONFieldNames(true) enabled. Top-level fields of proto3 TestAllTypes are
// declared and bound under their JSON (lowerCamelCase) names.
func TestContextProtoJSONFieldNames(t *testing.T) {
	desc := (&proto3pb.TestAllTypes{}).ProtoReflect().Descriptor()
	env := testEnv(t, JSONFieldNames(true), DeclareContextProto(desc))
	const expression = `
singleInt64 == 1
&& singleDouble == 1.0
&& singleBool == true
&& singleString == ''
&& singleNestedMessage == google.expr.proto3.test.TestAllTypes.NestedMessage{}
&& standaloneEnum == google.expr.proto3.test.TestAllTypes.NestedEnum.FOO
&& singleDuration == duration('5s')
&& singleTimestamp == timestamp(63154820)
&& singleAny == null
&& singleUint32Wrapper == null
&& singleUint64Wrapper == 0u
&& repeatedInt32 == [1,2]
&& mapStringString == {'': ''}
&& mapInt64NestedType == {0 : google.expr.proto3.test.NestedTestAllTypes{}}`
	checked, issues := env.Compile(expression)
	if issues.Err() != nil {
		t.Fatalf("env.Compile(%s) failed: %s", expression, issues.Err())
	}
	program, err := env.Program(checked)
	if err != nil {
		t.Fatalf("env.Program() failed: %v", err)
	}
	msg := &proto3pb.TestAllTypes{
		SingleInt64:  1,
		SingleDouble: 1.0,
		SingleBool:   true,
		NestedType: &proto3pb.TestAllTypes_SingleNestedMessage{
			SingleNestedMessage: &proto3pb.TestAllTypes_NestedMessage{},
		},
		StandaloneEnum:      proto3pb.TestAllTypes_FOO,
		SingleDuration:      &durationpb.Duration{Seconds: 5},
		SingleTimestamp:     &timestamppb.Timestamp{Seconds: 63154820},
		SingleUint64Wrapper: wrapperspb.UInt64(0),
		RepeatedInt32:       []int32{1, 2},
		MapStringString:     map[string]string{"": ""},
		MapInt64NestedType:  map[int64]*proto3pb.NestedTestAllTypes{0: {}},
	}
	activation, err := ContextProtoVars(msg, types.JSONFieldNames(true))
	if err != nil {
		t.Fatalf("ContextProtoVars(%v) failed: %v", msg, err)
	}
	result, _, err := program.Eval(activation)
	if err != nil {
		t.Fatalf("prg.Eval() failed: %v", err)
	}
	if result.Equal(types.True) != types.True {
		t.Errorf("prg.Eval() got %v, wanted true", result)
	}
}
// TestRegexOptimizer checks `matches` with and without OptOptimize, for both
// parsed and checked ASTs.
//
// With optimization, an invalid constant regex must be rejected when the
// program is created (progErr). Without it, the same regex must fail only at
// evaluation time (err). Valid cases must evaluate to true.
func TestRegexOptimizer(t *testing.T) {
	stringTests := []struct {
		expr          string
		optimizeRegex bool
		progErr       string
		err           string
		parseOnly     bool
	}{
		{expr: `"123 abc 456".matches('[0-9]*')`},
		{expr: `"123 abc 456".matches('[0-9]' + '*')`},
		{expr: `"123 abc 456".matches('[0-9]*')`, optimizeRegex: true},
		{expr: `"123 abc 456".matches('[0-9]' + '*')`, optimizeRegex: true},
		{
			// Verify that a regex compilation error for an optimized regex is
			// reported at program creation time.
			expr: `"123 abc 456".matches(')[0-9]*')`, optimizeRegex: true,
			progErr: "error parsing regexp: unexpected ): `)[0-9]*`",
		},
		{
			expr: `"123 abc 456".matches(')[0-9]*')`,
			err:  "error parsing regexp: unexpected ): `)[0-9]*`",
		},
	}
	env := testEnv(t)
	for i, tst := range stringTests {
		tc := tst
		t.Run(fmt.Sprintf("[%d]", i), func(tt *testing.T) {
			parsed, iss := env.Parse(tc.expr)
			if iss.Err() != nil {
				tt.Fatal(iss.Err())
			}
			candidates := []*Ast{parsed}
			if !tc.parseOnly {
				checked, iss := env.Check(parsed)
				if iss.Err() != nil {
					tt.Fatal(iss.Err())
				}
				candidates = append(candidates, checked)
			}
			for _, candidate := range candidates {
				var opts []ProgramOption
				if tc.optimizeRegex {
					opts = append(opts, EvalOptions(OptOptimize))
				}
				prg, progErr := env.Program(candidate, opts...)
				if tc.progErr != "" {
					switch {
					case progErr == nil:
						tt.Fatalf("wanted error %s for expr: %s", tc.progErr, tc.expr)
					case progErr.Error() != tc.progErr:
						tt.Errorf("got error %v, wanted error %s for expr: %s", progErr, tc.progErr, tc.expr)
					}
					continue
				}
				if progErr != nil {
					tt.Fatal(progErr)
				}
				out, _, err := prg.Eval(NoVars())
				switch {
				case tc.err != "":
					if err == nil {
						tt.Fatalf("got value %v, wanted error %s for expr: %s",
							out.Value(), tc.err, tc.expr)
					}
					if err.Error() != tc.err {
						tt.Errorf("got error %v, wanted error %s for expr: %s", err, tc.err, tc.expr)
					}
				case err != nil:
					tt.Fatal(err)
				case out.Value() != true:
					tt.Errorf("got %v, wanted true for expr: %s", out.Value(), tc.expr)
				}
			}
		})
	}
}
// TestDefaultUTCTimeZoneDisabled evaluates timestamp accessors against a time
// fixed in a UTC-8 zone under three environments: default,
// DefaultUTCTimeZone(true), and DefaultUTCTimeZone(false).
//
// envOut gives the expected result per environment name. Zone-less accessors
// agree with UTC for "default" and "enabled", and with the time's own zone for
// "disabled". Accessors with an explicit zone argument give the same result in
// all three.
func TestDefaultUTCTimeZoneDisabled(t *testing.T) {
	testEnvs := []struct {
		name string
		env  *Env
	}{
		{"default", testEnv(t, Variable("x", TimestampType))},
		{"enabled", testEnv(t, Variable("x", TimestampType), DefaultUTCTimeZone(true))},
		{"disabled", testEnv(t, Variable("x", TimestampType), DefaultUTCTimeZone(false))},
	}
	exprs := []struct {
		name   string
		value  string
		envOut map[string]ref.Val
	}{
		{
			name: "default-timezone",
			value: `
x.getFullYear() == 1970
&& x.getMonth() == 0
&& x.getDayOfYear() == 0
&& x.getDayOfMonth() == 0
&& x.getDate() == 1
&& x.getDayOfWeek() == 4
&& x.getHours() == 2
&& x.getMinutes() == 5
&& x.getSeconds() == 6
&& x.getMilliseconds() == 1`,
			envOut: map[string]ref.Val{
				"default":  types.True,
				"enabled":  types.True,
				"disabled": types.False,
			},
		},
		{
			name:  "default-local-year",
			value: `x.getFullYear()`,
			envOut: map[string]ref.Val{
				"default":  types.Int(1970),
				"enabled":  types.Int(1970),
				"disabled": types.Int(1969),
			},
		},
		{
			name:  "default-local-day-of-year",
			value: `x.getDayOfYear()`,
			envOut: map[string]ref.Val{
				"default":  types.Int(0),
				"enabled":  types.Int(0),
				"disabled": types.Int(364),
			},
		},
		{
			name:  "default-local-month",
			value: `x.getMonth()`,
			envOut: map[string]ref.Val{
				"default":  types.Int(0),
				"enabled":  types.Int(0),
				"disabled": types.Int(11),
			},
		},
		{
			name: "default-local-day-of-month",
			value: `
x.getDayOfMonth() == 30
&& x.getDate() == 31`,
			envOut: map[string]ref.Val{
				"default":  types.False,
				"enabled":  types.False,
				"disabled": types.True,
			},
		},
		{
			name:  "default-local-dates",
			value: `x.getDayOfWeek()`,
			envOut: map[string]ref.Val{
				"default":  types.Int(4),
				"enabled":  types.Int(4),
				"disabled": types.Int(3),
			},
		},
		{
			name: "default-local-times",
			value: `
x.getHours() == 18
&& x.getMinutes() == 5
&& x.getSeconds() == 6
&& x.getMilliseconds() == 1`,
			envOut: map[string]ref.Val{
				"default":  types.False,
				"enabled":  types.False,
				"disabled": types.True,
			},
		},
		{
			name: "explicit",
			value: `
x.getFullYear('-07:30') == 1969
&& x.getDayOfYear('-07:30') == 364
&& x.getMonth('-07:30') == 11
&& x.getDayOfMonth('-07:30') == 30
&& x.getDate('-07:30') == 31
&& x.getDayOfWeek('-07:30') == 3
&& x.getHours('-07:30') == 18
&& x.getMinutes('-07:30') == 35
&& x.getSeconds('-07:30') == 6
&& x.getMilliseconds('-07:30') == 1
&& x.getFullYear('23:15') == 1970
&& x.getDayOfYear('23:15') == 1
&& x.getMonth('23:15') == 0
&& x.getDayOfMonth('23:15') == 1
&& x.getDate('23:15') == 2
&& x.getDayOfWeek('23:15') == 5
&& x.getHours('23:15') == 1
&& x.getMinutes('23:15') == 20
&& x.getSeconds('23:15') == 6
&& x.getMilliseconds('23:15') == 1`,
			envOut: map[string]ref.Val{
				"default":  types.True,
				"enabled":  types.True,
				"disabled": types.True,
			},
		},
	}
	// 7506s + 1ms after the epoch (02:05:06.001 UTC), viewed in a fixed
	// unnamed zone 8 hours behind UTC.
	minusEight := time.FixedZone("", int((-8 * time.Hour).Seconds()))
	vars := map[string]any{
		"x": time.Unix(7506, 1000000).In(minusEight),
	}
	for _, te := range testEnvs {
		te := te
		for _, ex := range exprs {
			ex := ex
			t.Run(te.name+"/"+ex.name, func(t *testing.T) {
				want := ex.envOut[te.name]
				out, err := interpret(t, te.env, ex.value, vars)
				if err != nil {
					t.Fatal(err)
				}
				if out.Equal(want) != types.True {
					t.Errorf("interpret got %v, wanted %v", out, want)
				}
			})
		}
	}
}
// TestDefaultUTCTimeZoneExtension checks that an environment derived via
// Extend() keeps UTC as the default zone for timestamp accessors, and that the
// duration accessors report the total span in the requested unit.
func TestDefaultUTCTimeZoneExtension(t *testing.T) {
	base := testEnv(t,
		Variable("x", TimestampType),
		Variable("y", DurationType),
	)
	extended, err := base.Extend()
	if err != nil {
		t.Fatalf("env.Extend() failed: %v", err)
	}
	const expr = `
x.getFullYear() == 1970
&& y.getHours() == 2
&& y.getMinutes() == 120
&& y.getSeconds() == 7235
&& y.getMilliseconds() == 7235000`
	bindings := map[string]any{
		"x": time.Unix(7506, 1000000).Local(),
		"y": time.Duration(7235) * time.Second,
	}
	got, err := interpret(t, extended, expr, bindings)
	if err != nil {
		t.Fatalf("prg.Eval() failed: %v", err)
	}
	if got != types.True {
		t.Errorf("Eval() got %v, wanted true", got.Value())
	}
}
// TestDefaultUTCTimeZoneError checks that malformed time zone arguments passed
// to the timestamp accessor functions surface as evaluation errors.
func TestDefaultUTCTimeZoneError(t *testing.T) {
	e := testEnv(t, Variable("x", TimestampType))
	const expr = `
x.getFullYear(':xx') == 1969
|| x.getDayOfYear('xx:') == 364
|| x.getMonth('Am/Ph') == 11
|| x.getDayOfMonth('Am/Ph') == 30
|| x.getDate('Am/Ph') == 31
|| x.getDayOfWeek('Am/Ph') == 3
|| x.getHours('Am/Ph') == 19
|| x.getMinutes('Am/Ph') == 5
|| x.getSeconds('Am/Ph') == 6
|| x.getMilliseconds('Am/Ph') == 1
`
	bindings := map[string]any{
		"x": time.Unix(7506, 1000000).Local(),
	}
	got, err := interpret(t, e, expr, bindings)
	if err != nil {
		// An error is the expected outcome.
		return
	}
	t.Fatalf("prg.Eval() got %v wanted error", got)
}
// TestParserRecursionLimit verifies that ParserRecursionLimit(10) rejects
// expressions whose parse depth exceeds the limit and accepts those within it.
//
// Each case sets either errorSubstr (evaluation must fail with an error that
// contains it) or out (evaluation must succeed and produce exactly this value).
func TestParserRecursionLimit(t *testing.T) {
	testCases := []struct {
		expr        string
		errorSubstr string
		out         ref.Val
	}{
		{
			expr:        `0 + 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + 10 + 11`,
			errorSubstr: "max recursion depth exceeded",
		},
		{
			expr: `0 + 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + 10`,
			out:  types.Int(55),
		},
		{
			expr:        `0 + 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + 10 == 45`,
			errorSubstr: "max recursion depth exceeded",
		},
		{
			// Operator precedence means that '==' is the root.
			expr: `0 + 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 == 0 + 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9`,
			out:  types.True,
		},
	}
	for _, tc := range testCases {
		tc := tc
		t.Run(tc.expr, func(t *testing.T) {
			env := testEnv(t, ParserRecursionLimit(10))
			out, err := interpret(t, env,
				tc.expr, map[string]any{})
			if tc.errorSubstr != "" {
				if err == nil || !strings.Contains(err.Error(), tc.errorSubstr) {
					t.Fatalf("prg.Eval() wanted error containing '%s' got %v", tc.errorSubstr, err)
				}
				return
			}
			// No error was expected; report it directly rather than letting it
			// show up later as a confusing "got <nil>" value mismatch.
			if err != nil {
				t.Fatalf("prg.Eval() failed: %v", err)
			}
			if tc.out != nil && tc.out != out {
				t.Errorf("prg.Eval() wanted %v got %v", tc.out, out)
			}
		})
	}
}
// TestQuotedFields exercises backtick-escaped field identifiers, enabled with
// EnableIdentifierEscapeSyntax, in field selection and has() presence tests.
//
// Each case sets either errorSubstr (evaluation must fail with an error that
// contains it) or out (evaluation must succeed and produce exactly this value).
func TestQuotedFields(t *testing.T) {
	testCases := []struct {
		expr        string
		errorSubstr string
		out         ref.Val
	}{
		{
			expr: "{'key-1': 64}.`key-1`",
			out:  types.Int(64),
		},
		{
			expr:        "{'key-1': 64}.`key-2`",
			errorSubstr: "no such key: key-2",
		},
		{
			expr: "has({'key-1': 64}.`key-1`)",
			out:  types.True,
		},
		{
			expr: "has({'key-1': 64}.`key-2`)",
			out:  types.False,
		},
	}
	for _, tc := range testCases {
		tc := tc
		t.Run(tc.expr, func(t *testing.T) {
			env := testEnv(t, ParserRecursionLimit(10),
				EnableIdentifierEscapeSyntax())
			out, err := interpret(t, env,
				tc.expr, map[string]any{})
			if tc.errorSubstr != "" {
				if err == nil || !strings.Contains(err.Error(), tc.errorSubstr) {
					t.Fatalf("prg.Eval() wanted error containing '%s' got %v", tc.errorSubstr, err)
				}
				return
			}
			// No error was expected; report it directly rather than letting it
			// show up later as a confusing "got <nil>" value mismatch.
			if err != nil {
				t.Fatalf("prg.Eval() failed: %v", err)
			}
			if tc.out != nil && tc.out != out {
				t.Errorf("prg.Eval() wanted %v got %v", tc.out, out)
			}
		})
	}
}
func TestDynamicDispatch(t *testing.T) {
env := testEnv(t,
HomogeneousAggregateLiterals(),
Function("first",
MemberOverload("first_list_int", []*Type{ListType(IntType)}, IntType,
UnaryBinding(func(list ref.Val) ref.Val {
l := list.(traits.Lister)
if l.Size() == types.IntZero {
return types.IntZero
}
return l.Get(types.IntZero)
}),
),
MemberOverload("first_list_double", []*Type{ListType(DoubleType)}, DoubleType,
UnaryBinding(func(list ref.Val) ref.Val {
l := list.(traits.Lister)
if l.Size() == types.IntZero {
return types.Double(0.0)
}
return l.Get(types.IntZero)
}),
),
MemberOverload("first_list_string", []*Type{ListType(StringType)}, StringType,
UnaryBinding(func(list ref.Val) ref.Val {
l := list.(traits.Lister)
if l.Size() == types.IntZero {
return types.String("")
}
return l.Get(types.IntZero)
}),
),
MemberOverload("first_list_list_string", []*Type{ListType(ListType(StringType))}, ListType(StringType),
UnaryBinding(func(list ref.Val) ref.Val {
l := list.(traits.Lister)
if l.Size() == types.IntZero {
return types.DefaultTypeAdapter.NativeToValue([]string{})
}
return l.Get(types.IntZero)
}),
),
),
)
out, err := interpret(t, env, `
dyn([]).first() == 0
&& [1, 2].first() == 1
&& [1.0, 2.0].first() == 1.0
&& ["hello", "world"].first() == "hello"
&& [["hello"], ["world", "!"]].first().first() == "hello"
&& [[], ["empty"]].first().first() == ""
&& dyn([1, 2]).first() == 1
&& dyn([1.0, 2.0]).first() == 1.0
&& dyn(["hello", "world"]).first() == "hello"
&& dyn([["hello"], ["world", "!"]]).first().first() == "hello"
`, map[string]any{},
)
if err != nil {
t.Fatalf("prg.Eval() failed: %v", err)
}
if out != types.True {
t.Fatalf("prg.Eval() got %v wanted true", out)
}
}
// TestOptionalValuesCompile type-checks optional-value expressions and verifies
// the checked AST's reference map. Each expression node id must map to exactly
// the expected identifier name or overload ids.
//
// The comparison runs in both directions: references the checker produced but
// the case does not expect, and expected references the checker did not
// produce, are both reported.
func TestOptionalValuesCompile(t *testing.T) {
	env := testEnv(t,
		OptionalTypes(),
		// Test variables.
		Variable("m", MapType(StringType, MapType(StringType, StringType))),
		Variable("optm", OptionalType(MapType(StringType, MapType(StringType, StringType)))),
		Variable("l", ListType(StringType)),
		Variable("optl", OptionalType(ListType(StringType))),
		Variable("x", OptionalType(IntType)),
		Variable("y", IntType),
	)
	tests := []struct {
		expr       string
		references map[int64]*celast.ReferenceInfo
	}{
		{
			expr: `x.or(optional.of(y)).orValue(42)`,
			references: map[int64]*celast.ReferenceInfo{
				1: {Name: "x"},
				2: {OverloadIDs: []string{"optional_or_optional"}},
				4: {OverloadIDs: []string{"optional_of"}},
				5: {Name: "y"},
				6: {OverloadIDs: []string{"optional_orValue_value"}},
			},
		},
		{
			expr: `m.?x.hasValue()`,
			references: map[int64]*celast.ReferenceInfo{
				1: {Name: "m"},
				3: {OverloadIDs: []string{"select_optional_field"}},
				4: {OverloadIDs: []string{"optional_hasValue"}},
			},
		},
		{
			expr: `has(m.?x.y)`,
			references: map[int64]*celast.ReferenceInfo{
				2: {Name: "m"},
				4: {OverloadIDs: []string{"select_optional_field"}},
			},
		},
		{
			// Optional index selection in map.
			expr: `m.k[?'dashed-index'].orValue('default value')`,
			references: map[int64]*celast.ReferenceInfo{
				1: {Name: "m"},
				3: {OverloadIDs: []string{"map_optindex_optional_value"}},
				5: {OverloadIDs: []string{"optional_orValue_value"}},
			},
		},
		{
			// Optional index selection in list.
			expr: `l[?y]`,
			references: map[int64]*celast.ReferenceInfo{
				1: {Name: "l"},
				2: {OverloadIDs: []string{"list_optindex_optional_int"}},
				3: {Name: "y"},
			},
		},
		{
			// Index selection against a value in an optional map.
			expr: `optm.c['index'].orValue('default value')`,
			references: map[int64]*celast.ReferenceInfo{
				1: {Name: "optm"},
				3: {OverloadIDs: []string{"optional_map_index_value"}},
				5: {OverloadIDs: []string{"optional_orValue_value"}},
			},
		},
		{
			// Index selection against a value in an optional map.
			expr: `optm.c[?'index']`,
			references: map[int64]*celast.ReferenceInfo{
				1: {Name: "optm"},
				3: {OverloadIDs: []string{"optional_map_optindex_optional_value"}},
			},
		},
		{
			// Index selection against a value in an optional list.
			expr: `optl[0]`,
			references: map[int64]*celast.ReferenceInfo{
				1: {Name: "optl"},
				2: {OverloadIDs: []string{"optional_list_index_int"}},
			},
		},
		{
			// Index selection against a value in an optional list.
			expr: `optl[?0]`,
			references: map[int64]*celast.ReferenceInfo{
				1: {Name: "optl"},
				2: {OverloadIDs: []string{"optional_list_optindex_optional_int"}},
			},
		},
	}
	for i, tst := range tests {
		tc := tst
		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
			ast, iss := env.Compile(tc.expr)
			if iss.Err() != nil {
				t.Fatalf("%v failed: %v", tc.expr, iss.Err())
			}
			refMap := ast.NativeRep().ReferenceMap()
			// Every reference the checker produced must be expected and equal.
			for id, reference := range refMap {
				other, found := tc.references[id]
				if !found {
					t.Errorf("Compile(%v) got unexpected reference %d: %v", tc.expr, id, reference)
				} else if !reference.Equals(other) {
					t.Errorf("Compile(%v) got reference %d: %v, wanted %v", tc.expr, id, reference, other)
				}
			}
			// Every expected reference must have been produced by the checker.
			for id, want := range tc.references {
				if _, found := refMap[id]; !found {
					t.Errorf("Compile(%v) missing reference %d: %v", tc.expr, id, want)
				}
			}
		})
	}
}
// TestOptionalValuesEval compiles and evaluates expressions that use optional
// values. Covered forms include optional.of/none/ofNonZeroValue, `.?` and `[?]`
// selection, optional map/list/message entries, or/orValue, optMap/optFlatMap,
// first/last and unwrap. Each result is compared against an expected value.
func TestOptionalValuesEval(t *testing.T) {
env := testEnv(t,
OptionalTypes(),
// Container and test message types.
Container("google.expr.proto2.test"),
Types(&proto2pb.TestAllTypes{}),
// Test variables.
Variable("m", MapType(StringType, MapType(StringType, StringType))),
Variable("l", ListType(StringType)),
Variable("optm", OptionalType(MapType(StringType, MapType(StringType, StringType)))),
Variable("optl", OptionalType(ListType(StringType))),
Variable("x", OptionalType(IntType)),
Variable("y", OptionalType(IntType)),
Variable("z", IntType),
)
// adapter converts native Go inputs and expectations into CEL values.
adapter := env.TypeAdapter()
// out is one of two things. When evaluation succeeds, it is the expected
// result, converted with adapter.NativeToValue before comparison. When
// evaluation is expected to fail, it is the exact error message string.
tests := []struct {
expr string
in map[string]any
out any
}{
{
expr: `has({'foo': optional.none()}.foo)`,
out: types.True,
},
{
expr: `has({'foo': optional.none()}.foo.value)`,
out: types.False,
},
{
expr: `has({?'foo': optional.none()}.foo)`,
out: types.False,
},
{
expr: `has({?'foo': optional.none()}.foo.value)`,
out: "no such key: foo",
},
{
expr: `{}.?invalid`,
out: types.OptionalNone,
},
{
expr: `{'null_field': dyn(null)}.?null_field`,
out: types.OptionalOf(types.NullValue),
},
{
expr: `{'null_field': dyn(null)}.?null_field.?nested`,
out: types.OptionalNone,
},
{
expr: `{'zero_field': dyn(0)}.?zero_field.?invalid`,
out: types.OptionalNone,
},
{
expr: `{0: dyn(0)}[?0].?invalid`,
out: types.OptionalNone,
},
{
expr: `{true: dyn(0)}[?false].?invalid`,
out: types.OptionalNone,
},
{
expr: `{true: dyn(0)}[?true].?invalid`,
out: types.OptionalNone,
},
{
expr: `x.or(y).orValue(z)`,
in: map[string]any{
"x": types.OptionalNone,
"y": types.OptionalNone,
"z": 42,
},
out: 42,
},
{
expr: `x.optMap(y, y + 1)`,
in: map[string]any{
"x": types.OptionalNone,
},
out: types.OptionalNone,
},
{
expr: `m.?key.optFlatMap(k, k.?subkey)`,
in: map[string]any{
"m": map[string]any{},
},
out: types.OptionalNone,
},
{
expr: `m.?key.optFlatMap(k, k.?subkey)`,
in: map[string]any{
"m": map[string]any{
"key": map[string]string{},
},
},
out: types.OptionalNone,
},
{
expr: `m.?key.optFlatMap(k, k.?subkey)`,
in: map[string]any{
"m": map[string]any{
"key": map[string]string{
"subkey": "subvalue",
},
},
},
out: types.OptionalOf(types.String("subvalue")),
},
{
expr: `m.?key.optFlatMap(k, k.?subkey)`,
in: map[string]any{
"m": map[string]any{
"key": map[string]string{
"subkey": "",
},
},
},
out: types.OptionalOf(types.String("")),
},
{
expr: `m.?key.optFlatMap(k, optional.ofNonZeroValue(k.subkey))`,
in: map[string]any{
"m": map[string]any{
"key": map[string]string{
"subkey": "",
},
},
},
out: types.OptionalNone,
},
{
expr: `x.optMap(y, y + 1)`,
in: map[string]any{
"x": types.OptionalOf(types.Int(42)),
},
out: types.OptionalOf(types.Int(43)),
},
{
expr: `optional.ofNonZeroValue(z).or(optional.of(10)).value() == 42`,
in: map[string]any{
"z": 42,
},
out: true,
},
{
// Equivalent to m.?x.hasValue()
expr: `(has(m.x) ? optional.of(m.x) : optional.none()).hasValue()`,
in: map[string]any{
"m": map[string]map[string]string{},
},
out: false,
},
{
expr: `m.?x.hasValue()`,
in: map[string]any{
"m": map[string]any{},
},
out: false,
},
{
expr: `has(m.?x.y)`,
in: map[string]any{
"m": map[string]any{},
},
out: false,
},
{
expr: `has(m.?x.y)`,
in: map[string]any{
"m": map[string]any{
"x": map[string]string{
"y": "z",
},
},
},
out: true,
},
{
expr: `type(optional.none()) == optional_type`,
out: true,
},
{
// return the value of m.c['dashed-index'], no magic in the optional.of() call.
expr: `optional.ofNonZeroValue('').or(optional.of(m.c['dashed-index'])).orValue('default value')`,
in: map[string]any{
"m": map[string]any{
"c": map[string]string{
"dashed-index": "goodbye",
},
},
},
out: "goodbye",
},
{
// Optional index selection in map where the index is found.
expr: `m.c[?'dashed-index'].orValue('default value')`,
in: map[string]any{
"m": map[string]any{
"c": map[string]string{
"dashed-index": "goodbye",
},
},
},
out: "goodbye",
},
{
// Optional index selection in map where the index is absent.
expr: `m.c[?'missing-index'].orValue('default value')`,
in: map[string]any{
"m": map[string]any{
"c": map[string]string{},
},
},
out: "default value",
},
{
// Traditional index selection against an optional value in map where the index is found.
expr: `optm.c.index.orValue('default value')`,
in: map[string]any{
"optm": types.OptionalOf(
adapter.NativeToValue(
map[string]any{
"c": map[string]string{
"index": "goodbye",
},
},
),
),
},
out: "goodbye",
},
{
// Traditional index selection against an optional value in map where the index is absent.
expr: `optm.c.missing.or(optl[0]).orValue('default value')`,
in: map[string]any{
"optm": types.OptionalOf(
adapter.NativeToValue(
map[string]any{
"c": map[string]string{},
},
),
),
"optl": types.OptionalNone,
},
out: "default value",
},
{
// Traditional index selection against an optional value in map where the index is absent.
expr: `optm.c.missing.or(optl[0]).orValue('default value')`,
in: map[string]any{
"optm": types.OptionalOf(
adapter.NativeToValue(
map[string]any{
"c": map[string]string{},
},
),
),
"optl": types.OptionalOf(
adapter.NativeToValue([]string{"list-value"}),
),
},
out: "list-value",
},
{
// Traditional index selection against an optional value in map where the index is found.
expr: `optm.c['index'].orValue('default value')`,
in: map[string]any{
"optm": types.OptionalOf(
adapter.NativeToValue(
map[string]any{
"c": map[string]string{
"index": "goodbye",
},
},
),
),
},
out: "goodbye",
},
{
// Traditional index selection against an optional value in map where the index is absent.
expr: `optm.c['missing'].orValue('default value')`,
in: map[string]any{
"optm": types.OptionalOf(
adapter.NativeToValue(
map[string]any{
"c": map[string]string{},
},
),
),
},
out: "default value",
},
{
// Presence test using optional value where the field is absent.
expr: `has(optm.c) && !has(optm.c.missing)`,
in: map[string]any{
"optm": types.OptionalOf(
adapter.NativeToValue(
map[string]any{
"c": map[string]string{
"entry": "hello world",
},
},
),
),
},
out: true,
},
{
// ensure an error is propagated to the result.
expr: `optional.ofNonZeroValue(m.a.z).orValue(m.c['dashed-index'])`,
in: map[string]any{
"m": map[string]any{
"c": map[string]string{
"dashed-index": "goodbye",
},
},
},
out: "no such key: a",
},
{
expr: `m.?c.missing.or(m.?c['dashed-index']).orValue('').size()`,
in: map[string]any{
"m": map[string]any{
"c": map[string]string{
"dashed-index": "goodbye",
},
},
},
out: 7,
},
{
expr: `{?'nested_map': optional.ofNonZeroValue({?'map': m.?c})}`,
in: map[string]any{
"m": map[string]any{
"c": map[string]string{
"dashed-index": "goodbye",
},
},
},
out: map[string]any{
"nested_map": map[string]any{
"map": map[string]string{
"dashed-index": "goodbye",
},
},
},
},
{
expr: `{?'nested_map': optional.ofNonZeroValue({?'map': m.?c}), 'singleton': true}`,
in: map[string]any{
"m": map[string]any{},
},
out: map[string]any{
"singleton": true,
},
},
{
expr: `[?m.?c, ?x, ?y]`,
in: map[string]any{
"m": map[string]any{},
"x": types.OptionalOf(types.Int(42)),
"y": types.OptionalNone,
},
out: []any{42},
},
{
expr: `[?optional.ofNonZeroValue(m.?c.orValue({}))]`,
in: map[string]any{
"m": map[string]any{
"c": []string{},
},
},
out: []any{},
},
{
expr: `optional.ofNonZeroValue({?'nested_map': optional.ofNonZeroValue({?'map': m.?c})})`,
in: map[string]any{
"m": map[string]any{},
},
out: types.OptionalNone,
},
{
expr: `TestAllTypes{?single_double_wrapper: optional.ofNonZeroValue(0.0)}`,
out: &proto2pb.TestAllTypes{},
},
{
expr: `optional.ofNonZeroValue(TestAllTypes{?single_double_wrapper: optional.ofNonZeroValue(0.0)})`,
out: types.OptionalNone,
},
{
expr: `TestAllTypes{
?map_string_string: m[?'nested']
}`,
in: map[string]any{
"m": map[string]any{
"nested": map[string]any{},
},
},
out: &proto2pb.TestAllTypes{},
},
{
expr: `TestAllTypes{
?map_string_string: optional.ofNonZeroValue(m[?'nested'].orValue({}))
}`,
in: map[string]any{
"m": map[string]any{
"nested": map[string]any{},
},
},
out: &proto2pb.TestAllTypes{},
},
{
expr: `TestAllTypes{
?map_string_string: m[?'nested']
}`,
in: map[string]any{
"m": map[string]any{
"nested": map[string]any{
"hello": "world",
},
},
},
out: &proto2pb.TestAllTypes{
MapStringString: map[string]string{"hello": "world"},
},
},
{
expr: `TestAllTypes{
repeated_string: ['greetings', ?m.nested.?hello],
?repeated_int32: optional.ofNonZeroValue([?x, ?y]),
}`,
in: map[string]any{
"m": map[string]any{
"nested": map[string]any{
"hello": "world",
},
},
"x": types.OptionalNone,
"y": types.OptionalNone,
},
out: &proto2pb.TestAllTypes{
RepeatedString: []string{"greetings", "world"},
},
},
{expr: `[].first()`, out: types.OptionalNone},
{expr: `['a','b','c'].first()`, out: types.OptionalOf(types.String("a"))},
{expr: `[].last()`, out: types.OptionalNone},
{expr: `[1, 2, 3].last()`, out: types.OptionalOf(types.Int(3))},
{expr: `optional.unwrap([])`, out: []any{}},
{expr: `optional.unwrap([optional.none(), optional.none()])`, out: []any{}},
{expr: `optional.unwrap([optional.of(42), optional.none(), optional.of("a")])`, out: []any{types.Int(42), types.String("a")}},
{expr: `optional.unwrap([optional.of(42), optional.of("a")])`, out: []any{types.Int(42), types.String("a")}},
{expr: `[].unwrapOpt()`, out: []any{}},
{expr: `[optional.none(), optional.none()].unwrapOpt()`, out: []any{}},
{expr: `[optional.of(42), optional.none(), optional.of("a")].unwrapOpt()`, out: []any{types.Int(42), types.String("a")}},
{expr: `[optional.of(42), optional.of("a")].unwrapOpt()`, out: []any{types.Int(42), types.String("a")}},
{expr: `optional.of(optional.of(1)) != dyn(optional.of(1))`, out: types.True},
{expr: `(true ? optional.of(optional.of(1)) : dyn(optional.of(2))) != dyn(optional.of(1))`, out: types.True},
}
for i, tst := range tests {
tc := tst
t.Run(fmt.Sprintf("%s/%d", tc.expr, i), func(t *testing.T) {
ast, iss := env.Compile(tc.expr)
if iss.Err() != nil {
t.Fatalf("%v failed: %v", tc.expr, iss.Err())
}
prg, err := env.Program(ast)
if err != nil {
t.Fatalf("env.Program() failed: %v", err)
}
out, _, err := prg.Eval(tc.in)
// An evaluation error passes only if its message equals tc.out exactly.
if err != nil && err.Error() != tc.out {
t.Errorf("prg.Eval() got %v, wanted %v", err, tc.out)
}
// A successful result is compared with CEL equality against the adapted
// expected value.
want := adapter.NativeToValue(tc.out)
if err == nil && out.Equal(want) != types.True {
t.Errorf("prg.Eval() got %v, wanted %v", out, want)
}
})
}
}
// unknownsEquivalent reports whether a and b describe the same unknowns, that
// is, each one contains the other. Unknown values do not define equality, so
// tests compare them with this helper.
func unknownsEquivalent(t *testing.T, a *types.Unknown, b *types.Unknown) bool {
	t.Helper()
	if !a.Contains(b) {
		return false
	}
	return b.Contains(a)
}
// TestOptionalValuesEvalUnknowns evaluates optional-chaining expressions with
// partial evaluation enabled. Variables absent from tc.in are unknown; the
// result is either an Unknown naming the missing attribute or a concrete value
// when the known inputs are enough to decide it.
func TestOptionalValuesEvalUnknowns(t *testing.T) {
	env := testEnv(t,
		OptionalTypes(),
		// Container and test message types.
		Container("google.expr.proto2.test"),
		Types(&proto2pb.TestAllTypes{}),
		// Test variables.
		Variable("x", OptionalType(IntType)),
		Variable("y", OptionalType(IntType)),
		Variable("z", IntType),
	)
	tests := []struct {
		expr string
		in   map[string]any
		out  ref.Val
	}{
		{
			expr: `x.or(y).orValue(z)`,
			in: map[string]any{
				"y": types.OptionalNone,
				"z": 42,
			},
			out: types.NewUnknown(1, types.NewAttributeTrail("x")),
		},
		{
			expr: `x.or(y).orValue(z)`,
			in: map[string]any{
				"x": types.OptionalNone,
				"y": types.OptionalNone,
			},
			out: types.NewUnknown(5, types.NewAttributeTrail("z")),
		},
		{
			expr: `x.or(y).orValue(z)`,
			in: map[string]any{
				"x": types.OptionalOf(types.IntOne),
				"y": types.OptionalNone,
			},
			out: types.IntOne,
		},
		{
			expr: `x.or(y).orValue(z)`,
			in: map[string]any{
				"x": types.OptionalNone,
				"y": types.OptionalOf(types.IntOne),
			},
			out: types.IntOne,
		},
	}
	for i, tst := range tests {
		tc := tst
		t.Run(fmt.Sprintf("%s/%d", tc.expr, i), func(t *testing.T) {
			ast, iss := env.Compile(tc.expr)
			if iss.Err() != nil {
				t.Fatalf("%v failed: %v", tc.expr, iss.Err())
			}
			prg, err := env.Program(ast, EvalOptions(OptPartialEval))
			if err != nil {
				t.Fatalf("env.Program() failed: %v", err)
			}
			act, err := env.PartialVars(tc.in)
			if err != nil {
				t.Fatalf("env.PartialVars() returned error %v", err)
			}
			out, _, err := prg.Eval(act)
			if err != nil {
				t.Fatalf("prg.Eval() got error %v, wanted nil", err)
			}
			// unknowns don't define equality so special case.
			if wantUnk, ok := tc.out.(*types.Unknown); ok {
				gotUnk, isUnk := out.(*types.Unknown)
				if !isUnk || !unknownsEquivalent(t, gotUnk, wantUnk) {
					// Report the actual result; gotUnk is nil when out is not an Unknown.
					t.Errorf("prg.Eval() got %v, wanted %v", out, tc.out)
				}
				return
			}
			if eq, ok := out.Equal(tc.out).Value().(bool); !ok || !eq {
				t.Errorf("prg.Eval() got %v, wanted %v", out, tc.out)
			}
		})
	}
}
// TestOptionalValuesEvalErrorCases checks that or/orValue surface runtime
// errors. Two kinds are covered: applying these functions to a non-optional
// dyn value gives "no such overload", and an error inside optional.of()
// propagates out of the call.
func TestOptionalValuesEvalErrorCases(t *testing.T) {
	env := testEnv(t,
		OptionalTypes(),
		// Container and test message types.
		Container("google.expr.proto2.test"),
		Types(&proto2pb.TestAllTypes{}),
		// Test variables.
		Variable("x", OptionalType(IntType)),
		Variable("y", OptionalType(IntType)),
		Variable("z", IntType),
	)
	tests := []struct {
		expr    string
		in      map[string]any
		wantErr string
	}{
		{
			expr:    `dyn(1).or(optional.of(2))`,
			wantErr: "no such overload",
		},
		{
			expr:    `dyn(1).orValue(2)`,
			wantErr: "no such overload",
		},
		{
			expr:    `optional.of(1/0).or(optional.of(2))`,
			wantErr: "division by zero",
		},
		{
			expr:    `optional.of(1/0).orValue(2)`,
			wantErr: "division by zero",
		},
	}
	for i, tst := range tests {
		tc := tst
		t.Run(fmt.Sprintf("%s/%d", tc.expr, i), func(t *testing.T) {
			ast, iss := env.Compile(tc.expr)
			if iss.Err() != nil {
				t.Fatalf("%v failed: %v", tc.expr, iss.Err())
			}
			prg, err := env.Program(ast)
			if err != nil {
				// Stop here: evaluating a Program that failed to plan is meaningless.
				t.Fatalf("env.Program() failed: %v", err)
			}
			_, _, err = prg.Eval(tc.in)
			if err == nil {
				t.Fatalf("prg.Eval got nil error, wanted %v", tc.wantErr)
			}
			if !strings.Contains(err.Error(), tc.wantErr) {
				t.Errorf("prg.Eval() got %v, wanted %v", err, tc.wantErr)
			}
		})
	}
}
// TestEnableErrorOnBadPresenceTest checks optional field selection when
// EnableErrorOnBadPresenceTest(true) is set. Selecting `.?field` on a value
// that is not a map or message (null, int) now reports "no such key" instead
// of yielding optional.none().
//
// tc.out is either the expected result or, on expected failure, the exact
// error message.
func TestEnableErrorOnBadPresenceTest(t *testing.T) {
	env := testEnv(t,
		OptionalTypes(),
		EnableErrorOnBadPresenceTest(true),
	)
	adapter := env.TypeAdapter()
	tests := []struct {
		expr string
		in   map[string]any
		out  any
	}{
		{
			expr: `{}.?invalid`,
			out:  types.OptionalNone,
		},
		{
			expr: `{'null_field': dyn(null)}.?null_field`,
			out:  types.OptionalOf(types.NullValue),
		},
		{
			expr: `{'null_field': dyn(null)}.?null_field.?nested`,
			out:  "no such key: nested",
		},
		{
			expr: `{'zero_field': dyn(0)}.?zero_field.?invalid`,
			out:  "no such key: invalid",
		},
		{
			expr: `{0: dyn(0)}[?0].?invalid`,
			out:  "no such key: invalid",
		},
		{
			expr: `{true: dyn(0)}[?false].?invalid`,
			out:  types.OptionalNone,
		},
		{
			expr: `{true: dyn(0)}[?true].?invalid`,
			out:  "no such key: invalid",
		},
	}
	for i, tst := range tests {
		tc := tst
		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
			ast, iss := env.Compile(tc.expr)
			if iss.Err() != nil {
				t.Fatalf("%v failed: %v", tc.expr, iss.Err())
			}
			prg, err := env.Program(ast)
			if err != nil {
				// Stop here: evaluating a Program that failed to plan is meaningless.
				t.Fatalf("env.Program() failed: %v", err)
			}
			out, _, err := prg.Eval(tc.in)
			if err != nil && err.Error() != tc.out {
				t.Errorf("prg.Eval() got %v, wanted %v", err, tc.out)
			}
			want := adapter.NativeToValue(tc.out)
			if err == nil && out.Equal(want) != types.True {
				t.Errorf("prg.Eval() got %v, wanted %v", out, want)
			}
		})
	}
}
func TestOptionalMacroError(t *testing.T) {
env := testEnv(t,
OptionalTypes(),
// Test variables.
Variable("x", OptionalType(IntType)),
)
_, iss := env.Compile("x.optMap(y.z, y.z + 1)")
if iss.Err() == nil || !strings.Contains(iss.Err().Error(), "variable name must be a simple identifier") {
t.Errorf("optMap() got an unexpected result: %v", iss.Err())
}
_, iss = env.Compile("x.optFlatMap(y.z, y.z + 1)")
if iss.Err() == nil || !strings.Contains(iss.Err().Error(), "variable name must be a simple identifier") {
t.Errorf("optFlatMap() got an unexpected result: %v", iss.Err())
}
env = testEnv(t,
OptionalTypes(OptionalTypesVersion(0)),
// Test variables.
Variable("x", OptionalType(IntType)),
)
_, iss = env.Compile("x.optFlatMap(y, y.z + 1)")
if iss.Err() == nil || !strings.Contains(iss.Err().Error(), "undeclared reference to 'optFlatMap'") {
t.Errorf("optFlatMap() got an unexpected result: %v", iss.Err())
}
}
// TestParserExpressionSizeLimit checks that ParserExpressionSizeLimit(10)
// accepts a 10-code-point expression and rejects an 11-code-point one.
func TestParserExpressionSizeLimit(t *testing.T) {
	env := testEnv(t, ParserExpressionSizeLimit(10))
	_, iss := env.Parse("'greeting'")
	if iss.Err() != nil {
		t.Errorf("Parse('greeting') failed: %v", iss.Err())
	}
	_, iss = env.Parse("'greetings'")
	// Guard against a nil error so a missing limit check fails the test
	// instead of panicking on iss.Err().Error().
	if iss.Err() == nil || !strings.Contains(iss.Err().Error(), "size exceeds limit") {
		t.Errorf("Parse('greetings') got unexpected error: %v", iss.Err())
	}
}
// BenchmarkOptionalValues measures evaluation of `x.or(y).orValue(z)` when both
// optionals are empty, so the result falls through to z.
func BenchmarkOptionalValues(b *testing.B) {
	env := testEnv(b,
		OptionalTypes(),
		Variable("x", OptionalType(IntType)),
		Variable("y", OptionalType(IntType)),
		Variable("z", IntType),
	)
	ast, iss := env.Compile("x.or(y).orValue(z)")
	if iss.Err() != nil {
		b.Fatalf("env.Compile(x.or(y).orValue(z)) failed: %v", iss.Err())
	}
	prg, err := env.Program(ast, EvalOptions(OptOptimize))
	if err != nil {
		// Fatal rather than Error: the loop below would call Eval on a Program
		// that failed to plan.
		b.Fatalf("env.Program() failed: %v", err)
	}
	input := map[string]any{
		"x": types.OptionalNone,
		"y": types.OptionalNone,
		"z": 42,
	}
	// Keep environment construction and compilation out of the measurement.
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		prg.Eval(input)
	}
}
func BenchmarkDynamicDispatch(b *testing.B) {
env := testEnv(b,
HomogeneousAggregateLiterals(),
Function("first",
MemberOverload("first_list_int", []*Type{ListType(IntType)}, IntType,
UnaryBinding(func(list ref.Val) ref.Val {
l := list.(traits.Lister)
if l.Size() == types.IntZero {
return types.IntZero
}
return l.Get(types.IntZero)
}),
),
MemberOverload("first_list_double", []*Type{ListType(DoubleType)}, DoubleType,
UnaryBinding(func(list ref.Val) ref.Val {
l := list.(traits.Lister)
if l.Size() == types.IntZero {
return types.Double(0.0)
}
return l.Get(types.IntZero)
}),
),
MemberOverload("first_list_string", []*Type{ListType(StringType)}, StringType,
UnaryBinding(func(list ref.Val) ref.Val {
l := list.(traits.Lister)
if l.Size() == types.IntZero {
return types.String("")
}
return l.Get(types.IntZero)
}),
),
MemberOverload("first_list_list_string", []*Type{ListType(ListType(StringType))}, ListType(StringType),
UnaryBinding(func(list ref.Val) ref.Val {
l := list.(traits.Lister)
if l.Size() == types.IntZero {
return types.DefaultTypeAdapter.NativeToValue([]string{})
}
return l.Get(types.IntZero)
}),
),
),
)
prg := compile(b, env, `
[].first() == 0
&& [1, 2].first() == 1
&& [1.0, 2.0].first() == 1.0
&& ["hello", "world"].first() == "hello"
&& [["hello"], ["world", "!"]].first().first() == "hello"`)
prgDyn := compile(b, env, `
dyn([]).first() == 0
&& dyn([1, 2]).first() == 1
&& dyn([1.0, 2.0]).first() == 1.0
&& dyn(["hello", "world"]).first() == "hello"
&& dyn([["hello"], ["world", "!"]]).first().first() == "hello"`)
b.ResetTimer()
b.Run("DirectDispatch", func(b *testing.B) {
for i := 0; i < b.N; i++ {
prg.Eval(NoVars())
}
})
b.ResetTimer()
b.Run("DynamicDispatch", func(b *testing.B) {
for i := 0; i < b.N; i++ {
prgDyn.Eval(NoVars())
}
})
}
// TestAstProgramNilValue checks that planning a nil *Ast returns an
// "unsupported expr" error rather than succeeding or panicking.
func TestAstProgramNilValue(t *testing.T) {
	var missing *Ast
	e := testEnv(t)
	prg, err := e.Program(missing)
	if err != nil && strings.Contains(err.Error(), "unsupported expr") {
		return
	}
	t.Errorf("env.Program() got (%v,%v) wanted unsupported expr error", prg, err)
}
// TestJSONFieldNames checks field access under JSONFieldNames(enabled). Each
// expression must evaluate to true in an environment configured per the case.
// Then the checked AST is re-planned under the opposite setting:
//   - with jsonFieldNames=true, planning without JSON field names must fail;
//   - with jsonFieldNames=false, planning with JSON field names must succeed and
//     still evaluate to true.
func TestJSONFieldNames(t *testing.T) {
	tests := []struct {
		name           string
		expr           string
		jsonFieldNames bool
	}{
		{
			name: "proto simple field",
			expr: `msg.single_int32 == 1`,
		},
		{
			name: "proto map field",
			expr: `msg.map_string_string['key'] == 'value'`,
		},
		{
			name:           "json simple field",
			expr:           `msg.singleInt32 == 1`,
			jsonFieldNames: true,
		},
		{
			// mapStringString is a map field, not a repeated one.
			name:           "json map field",
			expr:           `msg.mapStringString['key'] == 'value'`,
			jsonFieldNames: true,
		},
		{
			name:           "message with json field",
			expr:           `TestAllTypes{singleInt32: 1} != msg`,
			jsonFieldNames: true,
		},
		{
			name:           "message with json field and proto fallback",
			expr:           `dyn(TestAllTypes{singleInt32: 2}).single_int32 == 2`,
			jsonFieldNames: true,
		},
		{
			name:           "json with proto fallback",
			expr:           `dyn(msg).single_int32 == dyn(msg).singleInt32`,
			jsonFieldNames: true,
		},
		{
			name:           "proto with extensions",
			expr:           `google.expr.proto2.test.ExampleType{fooBar: 'value'}.fooBar == 'value'`,
			jsonFieldNames: true,
		},
		{
			name: "json opt fields",
			expr: "jsonOptMsg.int32_snake_case_json_name == 1 && " +
				"jsonOptMsg.int64CamelCaseJsonName == 2 && " +
				"jsonOptMsg.uint32DefaultJsonName == 3u && " +
				"jsonOptMsg.`uint64-custom-json-name` == 4u && " +
				"jsonOptMsg.single_string == 'shadows' && " +
				"jsonOptMsg.singleString == 'shadowed'",
			jsonFieldNames: true,
		},
		{
			name: "json opt fields fallback",
			expr: "dyn(jsonOptMsg).int32_snake_case_json_name == 1 && " +
				"dyn(jsonOptMsg).`uint64-custom-json-name` == 4u && " +
				"dyn(jsonOptMsg).single_string == 'shadows' && " +
				"dyn(jsonOptMsg).string_json_name_shadows == 'shadows' && " +
				"dyn(jsonOptMsg).singleString == 'shadowed'",
			jsonFieldNames: true,
		},
	}
	msg := &proto3pb.TestAllTypes{
		SingleInt32: 1,
		MapStringString: map[string]string{
			"key": "value",
		},
	}
	jsonOptMsg := &proto3pb.TestJsonNames{
		Int32SnakeCaseJsonName: 1,
		Int64CamelCaseJsonName: 2,
		Uint32DefaultJsonName:  3,
		Uint64CustomJsonName:   4,
		StringJsonNameShadows:  "shadows",
		SingleString:           "shadowed",
	}
	// Both declared variables are bound for every evaluation, including the
	// re-planned one below, so the two runs see identical inputs.
	vars := map[string]any{"msg": msg, "jsonOptMsg": jsonOptMsg}
	for _, tst := range tests {
		tc := tst
		t.Run(tc.name, func(t *testing.T) {
			env, err := NewEnv(
				EnableIdentifierEscapeSyntax(),
				JSONFieldNames(tc.jsonFieldNames),
				Types(msg, &proto2pb.ExternalMessageType{}, jsonOptMsg),
				Container(string(msg.ProtoReflect().Descriptor().ParentFile().Package())),
				Variable("msg", ObjectType(string(msg.ProtoReflect().Descriptor().FullName()))),
				Variable("jsonOptMsg", ObjectType(string(jsonOptMsg.ProtoReflect().Descriptor().FullName()))),
			)
			if err != nil {
				t.Fatalf("NewEnv() failed: %v", err)
			}
			ast, iss := env.Compile(tc.expr)
			if iss.Err() != nil {
				t.Fatalf("env.Compile() failed: %v", iss.Err())
			}
			prg, err := env.Program(ast)
			if err != nil {
				t.Fatalf("env.Program() failed: %v", err)
			}
			out, _, err := prg.Eval(vars)
			if err != nil {
				t.Fatalf("prg.Eval() failed: %v", err)
			}
			if out != types.True {
				t.Errorf("prg.Eval() got %v, wanted 'true'", out)
			}
			if tc.jsonFieldNames {
				// An AST checked with JSON names must not plan without them.
				noJSONEnv, err := env.Extend(JSONFieldNames(false))
				if err != nil {
					t.Fatalf("env.Extend() failed: %v", err)
				}
				if _, err := noJSONEnv.Program(ast); err == nil {
					t.Fatal("env with json disabled allowed program with json extension to be planned")
				}
				return
			}
			// An AST checked with proto names must still plan and evaluate once
			// JSON names are enabled.
			jsonEnv, err := env.Extend(JSONFieldNames(true))
			if err != nil {
				t.Fatalf("env.Extend() failed: %v", err)
			}
			prg, err = jsonEnv.Program(ast)
			if err != nil {
				t.Fatalf("env.Program() failed: %v", err)
			}
			out, _, err = prg.Eval(vars)
			if err != nil {
				t.Fatalf("prg.Eval() failed: %v", err)
			}
			if out != types.True {
				t.Errorf("prg.Eval() got %v, wanted 'true'", out)
			}
		})
	}
}
// TestJSONFieldNamesInvalidProvider checks that NewEnv rejects JSONFieldNames(true)
// when the custom type provider is not a *types.Registry. In this test the
// provider is a struct that merely embeds one.
func TestJSONFieldNamesInvalidProvider(t *testing.T) {
	type wrapperRegistry struct {
		*types.Registry
	}
	registry, err := types.NewProtoRegistry(types.JSONFieldNames(true))
	if err != nil {
		t.Fatalf("types.NewProtoRegistry() failed: %v", err)
	}
	provider := wrapperRegistry{Registry: registry}
	if _, err := NewEnv(CustomTypeProvider(provider), CustomTypeAdapter(registry), JSONFieldNames(true)); err != nil {
		return
	}
	t.Error("NewEnv() created a CEL environment successfully despite incompatible configs")
}
// testCostEstimator is a checker cost estimator for tests. Size estimates come
// from hints keyed by the dot-joined Path() of an AST node. It provides no
// call-cost estimates.
//
// TODO: ideally testCostEstimator and testRuntimeCostEstimator would be shared in a test fixtures package
type testCostEstimator struct {
	hints map[string]uint64
}

// EstimateSize returns the size range [0, hint] for a node whose path has a
// hint, and nil (no estimate) otherwise.
func (est testCostEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate {
	key := strings.Join(element.Path(), ".")
	upper, found := est.hints[key]
	if !found {
		return nil
	}
	return &checker.SizeEstimate{Min: 0, Max: upper}
}

// EstimateCallCost always returns nil; this estimator supplies no call-cost
// estimates.
func (est testCostEstimator) EstimateCallCost(function, overloadID string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
	return nil
}
// estimateTimestampToYear is a check-time call estimator that reports a fixed
// cost of exactly 7 and ignores its arguments.
func estimateTimestampToYear(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
	const yearCost = 7
	fixed := checker.CostEstimate{Min: yearCost, Max: yearCost}
	return &checker.CallEstimate{CostEstimate: fixed}
}
// testRuntimeCostEstimator is a runtime cost estimator for tests. Its CallCost
// always returns nil, meaning it never overrides a call's cost.
type testRuntimeCostEstimator struct{}

// timeToYearCost is the fixed runtime cost returned by trackTimestampToYear.
var timeToYearCost uint64 = 7

// CallCost computes a size for each argument but always returns nil.
//
// NOTE(review): the per-argument sizes below are never used. Either wire them
// into a returned cost or remove the loop (and check whether `reflect` is still
// needed elsewhere in the file).
func (e testRuntimeCostEstimator) CallCost(function, overloadID string, args []ref.Val, result ref.Val) *uint64 {
	sizes := make([]uint64, len(args))
	for idx := range args {
		native := reflect.ValueOf(args[idx].Value())
		// CEL bytes are Go byte slices, so they are measured by the Slice case.
		var size uint64 = 1
		switch native.Kind() {
		case reflect.String, reflect.Array, reflect.Slice, reflect.Map:
			size = uint64(native.Len())
		}
		sizes[idx] = size
	}
	return nil
}
// trackTimestampToYear is a runtime cost tracker that ignores its inputs and
// returns a pointer to the shared package-level timeToYearCost.
func trackTimestampToYear(args []ref.Val, result ref.Val) *uint64 {
	cost := &timeToYearCost
	return cost
}
// testEnv builds a CEL environment from opts and fails the test immediately if
// construction returns an error.
func testEnv(t testing.TB, opts ...EnvOption) *Env {
	t.Helper()
	created, err := NewEnv(opts...)
	if err == nil {
		return created
	}
	t.Fatalf("NewEnv() failed: %v", err)
	return nil // unreachable: Fatalf stops the test goroutine
}
// compile checks and plans expr in env (see compileOrError) and fails the test
// immediately if either step reports an error.
func compile(t testing.TB, env *Env, expr string) Program {
	t.Helper()
	prg, err := compileOrError(t, env, expr)
	if err == nil {
		return prg
	}
	t.Fatal(err)
	return nil // unreachable: Fatal stops the test goroutine
}
// compileOrError compiles expr in env and plans it with OptOptimize.
// It returns an error describing which step failed; the underlying cause is
// wrapped with %w so callers can inspect it with errors.Is/errors.As.
func compileOrError(t testing.TB, env *Env, expr string) (Program, error) {
	t.Helper()
	ast, iss := env.Compile(expr)
	if err := iss.Err(); err != nil {
		return nil, fmt.Errorf("env.Compile(%s) failed: %w", expr, err)
	}
	prg, err := env.Program(ast, EvalOptions(OptOptimize))
	if err != nil {
		return nil, fmt.Errorf("env.Program() failed: %w", err)
	}
	return prg, nil
}
// interpret compiles expr in env (see compileOrError) and evaluates it against
// vars. Evaluation errors are wrapped with %w, so the original cause remains
// inspectable with errors.Is/errors.As.
func interpret(t testing.TB, env *Env, expr string, vars any) (ref.Val, error) {
	t.Helper()
	prg, err := compileOrError(t, env, expr)
	if err != nil {
		return nil, err
	}
	out, _, err := prg.Eval(vars)
	if err != nil {
		return nil, fmt.Errorf("prg.Eval(%v) failed: %w", vars, err)
	}
	return out, nil
}
// TestExpressionSizeLimitEarlyEnforcement feeds a 10M-character expression to
// Compile and Parse in an environment limited to 1000 code points. It checks
// that the size-limit error is reported and that the rejection allocates less
// than 5 MiB. The allocation bound shows the limit is enforced before the
// input is processed in full.
func TestExpressionSizeLimitEarlyEnforcement(t *testing.T) {
	env, err := NewEnv(ParserExpressionSizeLimit(1000))
	if err != nil {
		t.Fatalf("NewEnv() failed: %v", err)
	}
	const (
		payloadLen    = 10_000_000
		maxAllocBytes = 5 << 20 // 5 MiB
		wantErrSubstr = "expression code point size exceeds limit"
	)
	tests := []struct {
		name string
		mode string
		run  func(src string) *Issues
	}{
		{
			name: "compile_rejects_oversized",
			mode: "compile",
			run: func(src string) *Issues {
				_, iss := env.Compile(src)
				return iss
			},
		},
		{
			name: "parse_rejects_oversized",
			mode: "parse",
			run: func(src string) *Issues {
				_, iss := env.Parse(src)
				return iss
			},
		},
	}
	for _, tc := range tests {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			payload := strings.Repeat("a", payloadLen)
			var before, after runtime.MemStats
			runtime.GC()
			runtime.ReadMemStats(&before)
			iss := tc.run(payload)
			// Sample before building error strings so only the rejection is measured.
			runtime.ReadMemStats(&after)
			if iss == nil || iss.Err() == nil {
				t.Fatal("expected size limit error, got nil")
			}
			if !strings.Contains(iss.Err().Error(), wantErrSubstr) {
				t.Fatalf("unexpected error: %v", iss.Err())
			}
			// Compare in bytes: integer MiB division would let up to ~5.99 MiB pass.
			delta := after.TotalAlloc - before.TotalAlloc
			deltaMiB := float64(delta) / (1 << 20)
			if delta >= maxAllocBytes {
				t.Errorf("excessive memory allocation: %.2fMiB during %s (expected <5MiB with early enforcement)",
					deltaMiB, tc.mode)
			}
			t.Logf("[%s] memory delta: %.2fMiB", tc.mode, deltaMiB)
		})
	}
}