blob: 1396a74acdfc26fcbb34e7f5084a3b97e4bb6885 [file] [log] [blame]
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package interpreter
import (
"encoding/json"
"errors"
"reflect"
"testing"
"github.com/golang/protobuf/proto"
"github.com/google/cel-go/checker"
"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/packages"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/interpreter/functions"
"github.com/google/cel-go/parser"
structpb "github.com/golang/protobuf/ptypes/struct"
tpb "github.com/golang/protobuf/ptypes/timestamp"
wrapperspb "github.com/golang/protobuf/ptypes/wrappers"
proto2pb "github.com/google/cel-go/test/proto2pb"
proto3pb "github.com/google/cel-go/test/proto3pb"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
type testCase struct {
name string
expr string
pkg string
env []*exprpb.Decl
types []proto.Message
funcs []*functions.Overload
attrs AttributeFactory
unchecked bool
in map[string]interface{}
out interface{}
err string
}
var (
testData = []testCase{
{
name: "and_false_1st",
expr: `false && true`,
out: types.False,
},
{
name: "and_false_2nd",
expr: `true && false`,
out: types.False,
},
{
name: "and_error_1st_false",
expr: `1/0 != 0 && false`,
out: types.False,
},
{
name: "and_error_2nd_false",
expr: `false && 1/0 != 0`,
out: types.False,
},
{
name: "and_error_1st_error",
expr: `1/0 != 0 && true`,
err: "divide by zero",
},
{
name: "and_error_2nd_error",
expr: `true && 1/0 != 0`,
err: "divide by zero",
},
{
name: "call_no_args",
expr: `zero()`,
unchecked: true,
funcs: []*functions.Overload{
{
Operator: "zero",
Function: func(args ...ref.Val) ref.Val {
return types.IntZero
},
},
},
out: types.IntZero,
},
{
name: "call_one_arg",
expr: `neg(1)`,
unchecked: true,
funcs: []*functions.Overload{
{
Operator: "neg",
OperandTrait: traits.NegatorType,
Unary: func(arg ref.Val) ref.Val {
return arg.(traits.Negater).Negate()
},
},
},
out: types.IntNegOne,
},
{
name: "call_two_arg",
expr: `b'abc'.concat(b'def')`,
unchecked: true,
funcs: []*functions.Overload{
{
Operator: "concat",
OperandTrait: traits.AdderType,
Binary: func(lhs, rhs ref.Val) ref.Val {
return lhs.(traits.Adder).Add(rhs)
},
},
},
out: []byte{'a', 'b', 'c', 'd', 'e', 'f'},
},
{
name: "call_varargs",
expr: `addall(a, b, c, d) == 10`,
unchecked: true,
funcs: []*functions.Overload{
{
Operator: "addall",
OperandTrait: traits.AdderType,
Function: func(args ...ref.Val) ref.Val {
val := types.Int(0)
for _, arg := range args {
val += arg.(types.Int)
}
return val
},
},
},
in: map[string]interface{}{
"a": 1, "b": 2, "c": 3, "d": 4,
},
},
{
name: "complex",
expr: `
!(headers.ip in ["10.0.1.4", "10.0.1.5"]) &&
((headers.path.startsWith("v1") && headers.token in ["v1", "v2", "admin"]) ||
(headers.path.startsWith("v2") && headers.token in ["v2", "admin"]) ||
(headers.path.startsWith("/admin") && headers.token == "admin" && headers.ip in ["10.0.1.2", "10.0.1.2", "10.0.1.2"]))
`,
env: []*exprpb.Decl{
decls.NewIdent("headers", decls.NewMapType(decls.String, decls.String), nil),
},
in: map[string]interface{}{
"headers": map[string]interface{}{
"ip": "10.0.1.2",
"path": "/admin/edit",
"token": "admin",
},
},
},
{
name: "complex_qual_vars",
expr: `
!(headers.ip in ["10.0.1.4", "10.0.1.5"]) &&
((headers.path.startsWith("v1") && headers.token in ["v1", "v2", "admin"]) ||
(headers.path.startsWith("v2") && headers.token in ["v2", "admin"]) ||
(headers.path.startsWith("/admin") && headers.token == "admin" && headers.ip in ["10.0.1.2", "10.0.1.2", "10.0.1.2"]))
`,
env: []*exprpb.Decl{
decls.NewIdent("headers.ip", decls.String, nil),
decls.NewIdent("headers.path", decls.String, nil),
decls.NewIdent("headers.token", decls.String, nil),
},
in: map[string]interface{}{
"headers.ip": "10.0.1.2",
"headers.path": "/admin/edit",
"headers.token": "admin",
},
},
{
name: "cond",
expr: `a ? b < 1.2 : c == ['hello']`,
env: []*exprpb.Decl{
decls.NewIdent("a", decls.Bool, nil),
decls.NewIdent("b", decls.Double, nil),
decls.NewIdent("c", decls.NewListType(decls.String), nil),
},
in: map[string]interface{}{
"a": true,
"b": 2.0,
"c": []string{"hello"},
},
out: types.False,
},
{
name: "in_list",
expr: `6 in [2, 12, 6]`,
},
{
name: "in_map",
expr: `'other-key' in {'key': null, 'other-key': 42}`,
},
{
name: "index",
expr: `m['key'][1] == 42u && m['null'] == null && m[string(0)] == 10`,
env: []*exprpb.Decl{
decls.NewIdent("m", decls.NewMapType(decls.String, decls.Dyn), nil),
},
in: map[string]interface{}{
"m": map[string]interface{}{
"key": []uint{21, 42},
"null": nil,
"0": 10,
},
},
},
{
name: "index_relative",
expr: `([[[1]], [[2]], [[3]]][0][0] + [2, 3, {'four': {'five': 'six'}}])[3].four.five == 'six'`,
},
{
name: "literal_bool_false",
expr: `false`,
out: types.False,
},
{
name: "literal_bool_true",
expr: `true`,
},
{
name: "literal_null",
expr: `null`,
out: types.NullValue,
},
{
name: "literal_list",
expr: `[1, 2, 3]`,
out: []int64{1, 2, 3},
},
{
name: "literal_map",
expr: `{'hi': 21, 'world': 42u}`,
out: map[string]interface{}{
"hi": 21,
"world": uint(42),
},
},
{
name: "literal_equiv_string_bytes",
expr: `string(bytes("\303\277")) == '''\303\277'''`,
},
{
name: "literal_not_equiv_string_bytes",
expr: `string(b"\303\277") != '''\303\277'''`,
},
{
name: "literal_equiv_bytes_string",
expr: `string(b"\303\277") == 'ÿ'`,
},
{
name: "literal_bytes_string",
expr: `string(b'aaa"bbb')`,
out: `aaa"bbb`,
},
{
name: "literal_bytes_string2",
expr: `string(b"""Kim\t""")`,
out: `Kim `,
},
{
name: "literal_pb3_msg",
pkg: "google.api.expr",
types: []proto.Message{&exprpb.Expr{}},
expr: `v1alpha1.Expr{
id: 1,
const_expr: v1alpha1.Constant{
string_value: "oneof_test"
}
}`,
out: &exprpb.Expr{Id: 1,
ExprKind: &exprpb.Expr_ConstExpr{
ConstExpr: &exprpb.Constant{
ConstantKind: &exprpb.Constant_StringValue{
StringValue: "oneof_test"}}}},
},
{
name: "literal_pb_enum",
pkg: "google.expr.proto3.test",
types: []proto.Message{&proto3pb.TestAllTypes{}},
expr: `TestAllTypes{
repeated_nested_enum: [
0,
TestAllTypes.NestedEnum.BAZ,
TestAllTypes.NestedEnum.BAR],
repeated_int32: [
TestAllTypes.NestedEnum.FOO,
TestAllTypes.NestedEnum.BAZ]}`,
out: &proto3pb.TestAllTypes{
RepeatedNestedEnum: []proto3pb.TestAllTypes_NestedEnum{
proto3pb.TestAllTypes_FOO,
proto3pb.TestAllTypes_BAZ,
proto3pb.TestAllTypes_BAR,
},
RepeatedInt32: []int32{0, 2},
},
},
{
name: "string_to_timestamp",
expr: `timestamp('1986-04-26T01:23:40Z')`,
out: &tpb.Timestamp{Seconds: 514862620},
},
{
name: "macro_all_non_strict",
expr: `![0, 2, 4].all(x, 4/x != 2 && 4/(4-x) != 2)`,
},
{
name: "macro_all_non_strict_var",
expr: `code == "111" && ["a", "b"].all(x, x in tags)
|| code == "222" && ["a", "b"].all(x, x in tags)`,
env: []*exprpb.Decl{
decls.NewIdent("code", decls.String, nil),
decls.NewIdent("tags", decls.NewListType(decls.String), nil),
},
in: map[string]interface{}{
"code": "222",
"tags": []string{"a", "b"},
},
},
{
name: "macro_exists_lit",
expr: `[1, 2, 3, 4, 5u, 1.0].exists(e, type(e) == uint)`,
},
{
name: "macro_exists_nonstrict",
expr: `[0, 2, 4].exists(x, 4/x == 2 && 4/(4-x) == 2)`,
},
{
name: "macro_exists_var",
expr: `elems.exists(e, type(e) == uint)`,
env: []*exprpb.Decl{
decls.NewIdent("elems", decls.NewListType(decls.Dyn), nil),
},
in: map[string]interface{}{
"elems": []interface{}{0, 1, 2, 3, 4, uint(5), 6},
},
},
{
name: "macro_exists_one",
expr: `[1, 2, 3].exists_one(x, (x % 2) == 0)`,
},
{
name: "macro_filter",
expr: `[1, 2, 3].filter(x, x > 2) == [3]`,
},
{
name: "macro_has_map_key",
expr: `has({'a':1}.a) && !has({}.a)`,
},
{
name: "macro_has_pb2_field",
types: []proto.Message{&proto2pb.TestAllTypes{}},
env: []*exprpb.Decl{
decls.NewIdent("pb2", decls.NewObjectType("google.expr.proto2.test.TestAllTypes"), nil),
},
in: map[string]interface{}{
"pb2": &proto2pb.TestAllTypes{
RepeatedBool: []bool{false},
MapInt64NestedType: map[int64]*proto2pb.NestedTestAllTypes{
1: &proto2pb.NestedTestAllTypes{},
},
MapStringString: map[string]string{},
},
},
expr: `!has(pb2.single_int64)
&& has(pb2.repeated_bool)
&& !has(pb2.repeated_int32)
&& has(pb2.map_int64_nested_type)
&& !has(pb2.map_string_string)`,
},
{
name: "macro_has_pb3_field",
types: []proto.Message{&exprpb.ParsedExpr{}},
pkg: "google.api.expr.v1alpha1",
expr: `has(v1alpha1.ParsedExpr{expr:Expr{id: 1}}.expr)
&& !has(ParsedExpr{expr:Expr{}}.expr.id)
&& has(SourceInfo{positions:{1:1}}.positions)
&& !has(SourceInfo{positions:{}}.positions)
&& !has(expr.v1alpha1.ParsedExpr{expr:Expr{id: 1}}.source_info)`,
},
{
name: "macro_map",
expr: `[1, 2, 3].map(x, x * 2) == [2, 4, 6]`,
},
{
name: "matches",
expr: `input.matches('k.*')
&& !'foo'.matches('k.*')
&& !'bar'.matches('k.*')
&& 'kilimanjaro'.matches('.*ro')`,
env: []*exprpb.Decl{
decls.NewIdent("input", decls.String, nil),
},
in: map[string]interface{}{
"input": "kathmandu",
},
},
{
name: "or_true_1st",
expr: `ai == 20 || ar["foo"] == "bar"`,
env: []*exprpb.Decl{
decls.NewIdent("ai", decls.Int, nil),
decls.NewIdent("ar", decls.NewMapType(decls.String, decls.String), nil),
},
in: map[string]interface{}{
"ai": 20,
"ar": map[string]string{
"foo": "bar",
},
},
},
{
name: "or_true_2nd",
expr: `ai == 20 || ar["foo"] == "bar"`,
env: []*exprpb.Decl{
decls.NewIdent("ai", decls.Int, nil),
decls.NewIdent("ar", decls.NewMapType(decls.String, decls.String), nil),
},
in: map[string]interface{}{
"ai": 2,
"ar": map[string]string{
"foo": "bar",
},
},
},
{
name: "or_false",
expr: `ai == 20 || ar["foo"] == "bar"`,
env: []*exprpb.Decl{
decls.NewIdent("ai", decls.Int, nil),
decls.NewIdent("ar", decls.NewMapType(decls.String, decls.String), nil),
},
in: map[string]interface{}{
"ai": 2,
"ar": map[string]string{
"foo": "baz",
},
},
out: types.False,
},
{
name: "or_error_1st_error",
expr: `1/0 != 0 || false`,
err: "divide by zero",
},
{
name: "or_error_2nd_error",
expr: `false || 1/0 != 0`,
err: "divide by zero",
},
{
name: "or_error_1st_true",
expr: `1/0 != 0 || true`,
out: types.True,
},
{
name: "or_error_2nd_true",
expr: `true || 1/0 != 0`,
out: types.True,
},
{
name: "pkg_qualified_id",
expr: `b.c.d != 10`,
pkg: "a.b",
env: []*exprpb.Decl{
decls.NewIdent("a.b.c.d", decls.Int, nil),
},
in: map[string]interface{}{
"a.b.c.d": 9,
},
},
{
name: "pkg_qualified_id_unchecked",
expr: `c.d != 10`,
unchecked: true,
pkg: "a.b",
in: map[string]interface{}{
"a.c.d": 9,
},
},
{
name: "pkg_qualified_index_unchecked",
expr: `b.c['d'] == 10`,
unchecked: true,
pkg: "a.b",
in: map[string]interface{}{
"a.b.c": map[string]int{
"d": 10,
},
},
},
{
name: "select_key",
expr: `m.strMap['val'] == 'string'
&& m.floatMap['val'] == 1.5
&& m.doubleMap['val'] == -2.0
&& m.intMap['val'] == -3
&& m.int32Map['val'] == 4
&& m.int64Map['val'] == -5
&& m.uintMap['val'] == 6u
&& m.uint32Map['val'] == 7u
&& m.uint64Map['val'] == 8u
&& m.boolMap['val'] == true
&& m.boolMap['val'] != false`,
env: []*exprpb.Decl{
decls.NewIdent("m", decls.NewMapType(decls.String, decls.Dyn), nil),
},
in: map[string]interface{}{
"m": map[string]interface{}{
"strMap": map[string]string{"val": "string"},
"floatMap": map[string]float32{"val": 1.5},
"doubleMap": map[string]float64{"val": -2.0},
"intMap": map[string]int{"val": -3},
"int32Map": map[string]int32{"val": 4},
"int64Map": map[string]int64{"val": -5},
"uintMap": map[string]uint{"val": 6},
"uint32Map": map[string]uint32{"val": 7},
"uint64Map": map[string]uint64{"val": 8},
"boolMap": map[string]bool{"val": true},
},
},
},
{
name: "select_bool_key",
expr: `m.boolStr[true] == 'string'
&& m.boolFloat32[true] == 1.5
&& m.boolFloat64[false] == -2.1
&& m.boolInt[false] == -3
&& m.boolInt32[false] == 0
&& m.boolInt64[true] == 4
&& m.boolUint[true] == 5u
&& m.boolUint32[true] == 6u
&& m.boolUint64[false] == 7u
&& m.boolBool[true]
&& m.boolIface[false] == true`,
env: []*exprpb.Decl{
decls.NewIdent("m", decls.NewMapType(decls.String, decls.Dyn), nil),
},
in: map[string]interface{}{
"m": map[string]interface{}{
"boolStr": map[bool]string{true: "string"},
"boolFloat32": map[bool]float32{true: 1.5},
"boolFloat64": map[bool]float64{false: -2.1},
"boolInt": map[bool]int{false: -3},
"boolInt32": map[bool]int32{false: 0},
"boolInt64": map[bool]int64{true: 4},
"boolUint": map[bool]uint{true: 5},
"boolUint32": map[bool]uint32{true: 6},
"boolUint64": map[bool]uint64{false: 7},
"boolBool": map[bool]bool{true: true},
"boolIface": map[bool]interface{}{false: true},
},
},
},
{
name: "select_uint_key",
expr: `m.uintIface[1u] == 'string'
&& m.uint32Iface[2u] == 1.5
&& m.uint64Iface[3u] == -2.1
&& m.uint64String[4u] == 'three'`,
env: []*exprpb.Decl{
decls.NewIdent("m", decls.NewMapType(decls.String, decls.Dyn), nil),
},
in: map[string]interface{}{
"m": map[string]interface{}{
"uintIface": map[uint]interface{}{1: "string"},
"uint32Iface": map[uint32]interface{}{2: 1.5},
"uint64Iface": map[uint64]interface{}{3: -2.1},
"uint64String": map[uint64]string{4: "three"},
},
},
},
{
name: "select_index",
expr: `m.strList[0] == 'string'
&& m.floatList[0] == 1.5
&& m.doubleList[0] == -2.0
&& m.intList[0] == -3
&& m.int32List[0] == 4
&& m.int64List[0] == -5
&& m.uintList[0] == 6u
&& m.uint32List[0] == 7u
&& m.uint64List[0] == 8u
&& m.boolList[0] == true
&& m.boolList[1] != true
&& m.ifaceList[0] == {}`,
env: []*exprpb.Decl{
decls.NewIdent("m", decls.NewMapType(decls.String, decls.Dyn), nil),
},
in: map[string]interface{}{
"m": map[string]interface{}{
"strList": []string{"string"},
"floatList": []float32{1.5},
"doubleList": []float64{-2.0},
"intList": []int{-3},
"int32List": []int32{4},
"int64List": []int64{-5},
"uintList": []uint{6},
"uint32List": []uint32{7},
"uint64List": []uint64{8},
"boolList": []bool{true, false},
"ifaceList": []interface{}{map[string]string{}},
},
},
},
{
name: "select_field",
expr: `a.b.c
&& pb3.repeated_nested_enum[0] == test.TestAllTypes.NestedEnum.BAR
&& json.list[0] == 'world'`,
pkg: "google.expr.proto3",
types: []proto.Message{&proto3pb.TestAllTypes{}},
env: []*exprpb.Decl{
decls.NewIdent("a.b", decls.NewMapType(decls.String, decls.Bool), nil),
decls.NewIdent("pb3", decls.NewObjectType("google.expr.proto3.test.TestAllTypes"), nil),
decls.NewIdent("json", decls.NewMapType(decls.String, decls.Dyn), nil),
},
in: map[string]interface{}{
"a.b": map[string]bool{
"c": true,
},
"pb3": &proto3pb.TestAllTypes{
RepeatedNestedEnum: []proto3pb.TestAllTypes_NestedEnum{proto3pb.TestAllTypes_BAR},
},
"json": &structpb.Value{
Kind: &structpb.Value_StructValue{
StructValue: &structpb.Struct{
Fields: map[string]*structpb.Value{
"list": {Kind: &structpb.Value_ListValue{
ListValue: &structpb.ListValue{
Values: []*structpb.Value{
{Kind: &structpb.Value_StringValue{
StringValue: "world",
}},
},
},
}},
},
},
},
},
},
},
// pb2 primitive fields may have default values set.
{
name: "select_pb2_primitive_fields",
expr: `!has(a.single_int32)
&& a.single_int32 == -32
&& a.single_int64 == -64
&& a.single_uint32 == 32u
&& a.single_uint64 == 64u
&& a.single_float == 3.0
&& a.single_double == 6.4
&& a.single_bool
&& "empty" == a.single_string`,
types: []proto.Message{&proto2pb.TestAllTypes{}},
in: map[string]interface{}{
"a": &proto2pb.TestAllTypes{},
},
env: []*exprpb.Decl{
decls.NewIdent("a", decls.NewObjectType("google.expr.proto2.test.TestAllTypes"), nil),
},
},
// Wrapper type nil or value test.
{
name: "select_pb3_wrapper_fields",
expr: `!has(a.single_int32_wrapper) && a.single_int32_wrapper == null
&& has(a.single_int64_wrapper) && a.single_int64_wrapper == 0
&& has(a.single_string_wrapper) && a.single_string_wrapper == "hello"
&& a.single_int64_wrapper == google.protobuf.Int32Value{value: 0}`,
types: []proto.Message{&proto3pb.TestAllTypes{}},
env: []*exprpb.Decl{
decls.NewIdent("a", decls.NewObjectType("google.expr.proto3.test.TestAllTypes"), nil),
},
in: map[string]interface{}{
"a": &proto3pb.TestAllTypes{
SingleInt64Wrapper: &wrapperspb.Int64Value{},
SingleStringWrapper: &wrapperspb.StringValue{Value: "hello"},
},
},
},
{
name: "select_pb3_compare",
expr: `a.single_uint64 > 3u`,
pkg: "google.expr.proto3.test",
types: []proto.Message{&proto3pb.TestAllTypes{}},
env: []*exprpb.Decl{
decls.NewIdent("a", decls.NewObjectType("google.expr.proto3.test.TestAllTypes"), nil),
},
in: map[string]interface{}{
"a": &proto3pb.TestAllTypes{
SingleUint64: 10,
},
},
out: types.True,
},
{
name: "select_custom_pb3_compare",
expr: `a.bb > 100`,
pkg: "google.expr.proto3.test",
types: []proto.Message{&proto3pb.TestAllTypes_NestedMessage{}},
env: []*exprpb.Decl{
decls.NewIdent("a",
decls.NewObjectType("google.expr.proto3.test.TestAllTypes.NestedMessage"), nil),
},
attrs: &custAttrFactory{
AttributeFactory: NewAttributeFactory(
packages.NewPackage("google.expr.proto3.test"),
types.NewRegistry(),
types.NewRegistry(),
),
},
in: map[string]interface{}{
"a": &proto3pb.TestAllTypes_NestedMessage{
Bb: 101,
},
},
out: types.True,
},
{
name: "select_relative",
expr: `json('{"hi":"world"}').hi == 'world'`,
env: []*exprpb.Decl{
decls.NewFunction("json",
decls.NewOverload("string_to_json",
[]*exprpb.Type{decls.String}, decls.Dyn)),
},
funcs: []*functions.Overload{
{
Operator: "json",
Unary: func(val ref.Val) ref.Val {
str, ok := val.(types.String)
if !ok {
return types.ValOrErr(val, "no such overload")
}
m := make(map[string]interface{})
err := json.Unmarshal([]byte(str), &m)
if err != nil {
return types.NewErr("invalid json: %v", err)
}
return types.DefaultTypeAdapter.NativeToValue(m)
},
},
},
},
{
name: "select_subsumed_field",
expr: `a.b.c`,
env: []*exprpb.Decl{
decls.NewIdent("a.b.c", decls.Int, nil),
decls.NewIdent("a.b", decls.NewMapType(decls.String, decls.String), nil),
},
in: map[string]interface{}{
"a.b.c": 10,
"a.b": map[string]string{
"c": "ten",
},
},
out: types.Int(10),
},
{
name: "select_empty_repeated_nested",
expr: `TestAllTypes{}.repeated_nested_message.size() == 0`,
types: []proto.Message{&proto3pb.TestAllTypes{}},
pkg: "google.expr.proto3.test",
out: types.True,
},
}
)
func BenchmarkInterpreter(b *testing.B) {
for _, tst := range testData {
prg, vars, err := program(&tst, Optimize())
if err != nil {
b.Fatal(err)
}
// Benchmark the eval.
b.Run(tst.name, func(bb *testing.B) {
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < bb.N; i++ {
prg.Eval(vars)
}
})
}
}
func BenchmarkInterpreter_Parallel(b *testing.B) {
for _, tst := range testData {
prg, vars, err := program(&tst, Optimize())
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
b.Run(tst.name,
func(bb *testing.B) {
bb.RunParallel(func(pb *testing.PB) {
for pb.Next() {
prg.Eval(vars)
}
})
})
}
}
func TestInterpreter(t *testing.T) {
for _, tst := range testData {
tc := tst
prg, vars, err := program(&tc)
if err != nil {
t.Fatalf("%s: %v", tc.name, err)
}
t.Run(tc.name, func(tt *testing.T) {
tt.Parallel()
var want ref.Val = types.True
if tc.out != nil {
want = tc.out.(ref.Val)
}
got := prg.Eval(vars)
_, expectUnk := want.(types.Unknown)
if expectUnk {
if !reflect.DeepEqual(got, want) {
tt.Fatalf("Got %v, wanted %v", got, want)
}
} else if tc.err != "" {
if !types.IsError(got) || got.(*types.Err).String() != tc.err {
tt.Fatalf("Got %v (%T), wanted error: %s", got, got, tc.err)
}
} else if got.Equal(want) != types.True {
tt.Fatalf("Got %v, wanted %v", got, want)
}
state := NewEvalState()
opts := map[string]InterpretableDecorator{
"optimize": Optimize(),
"exhaustive": ExhaustiveEval(state),
"track": TrackState(state),
}
for mode, opt := range opts {
prg, vars, err = program(&tc, opt)
if err != nil {
tt.Fatal(err)
}
tt.Run(mode, func(ttt *testing.T) {
got := prg.Eval(vars)
_, expectUnk := want.(types.Unknown)
if expectUnk {
if !reflect.DeepEqual(got, want) {
ttt.Errorf("Got %v, wanted %v", got, want)
}
} else if tc.err != "" {
if !types.IsError(got) || got.(*types.Err).String() != tc.err {
ttt.Errorf("Got %v (%T), wanted error: %s", got, got, tc.err)
}
} else if got.Equal(want) != types.True {
ttt.Errorf("Got %v, wanted %v", got, want)
}
state.Reset()
})
}
})
}
}
func TestInterpreter_LogicalAndMissingType(t *testing.T) {
src := common.NewTextSource(`a && TestProto{c: true}.c`)
parsed, errors := parser.Parse(src)
if len(errors.GetErrors()) != 0 {
t.Errorf(errors.ToDisplayString())
}
reg := types.NewRegistry()
pkg := packages.DefaultPackage
attrs := NewAttributeFactory(pkg, reg, reg)
intr := NewStandardInterpreter(pkg, reg, reg, attrs)
i, err := intr.NewUncheckedInterpretable(parsed.GetExpr())
if err == nil {
t.Errorf("Got '%v', wanted error", i)
}
}
func TestInterpreter_ExhaustiveConditionalExpr(t *testing.T) {
src := common.NewTextSource(`a ? b < 1.0 : c == ['hello']`)
parsed, errors := parser.Parse(src)
if len(errors.GetErrors()) != 0 {
t.Errorf(errors.ToDisplayString())
}
state := NewEvalState()
pkg := packages.DefaultPackage
reg := types.NewRegistry(&exprpb.ParsedExpr{})
attrs := NewAttributeFactory(pkg, reg, reg)
intr := NewStandardInterpreter(pkg, reg, reg, attrs)
interpretable, _ := intr.NewUncheckedInterpretable(
parsed.GetExpr(),
ExhaustiveEval(state))
vars, _ := NewActivation(map[string]interface{}{
"a": types.True,
"b": types.Double(0.999),
"c": types.NewStringList(reg, []string{"hello"})})
result := interpretable.Eval(vars)
// Operator "_==_" is at Expr 7, should be evaluated in exhaustive mode
// even though "a" is true
ev, _ := state.Value(7)
// "==" should be evaluated in exhaustive mode though unnecessary
if ev != types.True {
t.Errorf("Else expression expected to be true, got: %v", ev)
}
if result != types.True {
t.Errorf("Expected true, got: %v", result)
}
}
func TestInterpreter_ExhaustiveLogicalOrEquals(t *testing.T) {
// a || b == "b"
// Operator "==" is at Expr 4, should be evaluated though "a" is true
src := common.NewTextSource(`a || b == "b"`)
parsed, errors := parser.Parse(src)
if len(errors.GetErrors()) != 0 {
t.Errorf(errors.ToDisplayString())
}
state := NewEvalState()
reg := types.NewRegistry(&exprpb.Expr{})
pkg := packages.NewPackage("test")
attrs := NewAttributeFactory(pkg, reg, reg)
interp := NewStandardInterpreter(pkg, reg, reg, attrs)
i, _ := interp.NewUncheckedInterpretable(
parsed.GetExpr(),
ExhaustiveEval(state))
vars, _ := NewActivation(map[string]interface{}{
"a": true,
"b": "b",
})
result := i.Eval(vars)
rhv, _ := state.Value(3)
// "==" should be evaluated in exhaustive mode though unnecessary
if rhv != types.True {
t.Errorf("Right hand side expression expected to be true, got: %v", rhv)
}
if result != types.True {
t.Errorf("Expected true, got: %v", result)
}
}
func TestInterpreter_SetProto2PrimitiveFields(t *testing.T) {
// Test the use of proto2 primitives within object construction.
src := common.NewTextSource(
`input == TestAllTypes{
single_int32: 1,
single_int64: 2,
single_uint32: 3u,
single_uint64: 4u,
single_float: -3.3,
single_double: -2.2,
single_string: "hello world",
single_bool: true
}`)
parsed, errors := parser.Parse(src)
if len(errors.GetErrors()) != 0 {
t.Errorf(errors.ToDisplayString())
}
pkg := packages.NewPackage("google.expr.proto2.test")
reg := types.NewRegistry(&proto2pb.TestAllTypes{})
env := checker.NewStandardEnv(pkg, reg)
env.Add(decls.NewIdent("input", decls.NewObjectType("google.expr.proto2.test.TestAllTypes"), nil))
checked, errors := checker.Check(parsed, src, env)
if len(errors.GetErrors()) != 0 {
t.Errorf(errors.ToDisplayString())
}
attrs := NewAttributeFactory(pkg, reg, reg)
i := NewStandardInterpreter(pkg, reg, reg, attrs)
eval, _ := i.NewInterpretable(checked)
one := int32(1)
two := int64(2)
three := uint32(3)
four := uint64(4)
five := float32(-3.3)
six := float64(-2.2)
str := "hello world"
truth := true
input := &proto2pb.TestAllTypes{
SingleInt32: &one,
SingleInt64: &two,
SingleUint32: &three,
SingleUint64: &four,
SingleFloat: &five,
SingleDouble: &six,
SingleString: &str,
SingleBool: &truth,
}
vars, _ := NewActivation(map[string]interface{}{
"input": reg.NativeToValue(input),
})
result := eval.Eval(vars)
got, ok := result.(ref.Val).Value().(bool)
if !ok {
t.Fatalf("Got '%v', wanted 'true'.", result)
}
expected := true
if !reflect.DeepEqual(got, expected) {
t.Errorf("Could not build object properly. Got '%v', wanted '%v'",
result.(ref.Val).Value(),
expected)
}
}
func TestInterpreter_MissingIdentInSelect(t *testing.T) {
src := common.NewTextSource(`a.b.c`)
parsed, errors := parser.Parse(src)
if len(errors.GetErrors()) != 0 {
t.Fatalf(errors.ToDisplayString())
}
pkg := packages.NewPackage("test")
reg := types.NewRegistry()
env := checker.NewStandardEnv(pkg, reg)
env.Add(decls.NewIdent("a.b", decls.Dyn, nil))
checked, errors := checker.Check(parsed, src, env)
if len(errors.GetErrors()) != 0 {
t.Fatalf(errors.ToDisplayString())
}
attrs := NewPartialAttributeFactory(pkg, reg, reg)
interp := NewStandardInterpreter(pkg, reg, reg, attrs)
i, _ := interp.NewInterpretable(checked)
vars, _ := NewPartialActivation(
map[string]interface{}{
"a.b": map[string]interface{}{
"d": "hello",
},
},
NewAttributePattern("a.b").QualString("c"))
result := i.Eval(vars)
if !types.IsUnknown(result) {
t.Errorf("Got %v, wanted unknown", result)
}
result = i.Eval(EmptyActivation())
if !types.IsError(result) {
t.Errorf("Got %v, wanted error", result)
}
}
func program(tst *testCase, opts ...InterpretableDecorator) (Interpretable, Activation, error) {
// Configure the package.
pkg := packages.DefaultPackage
if tst.pkg != "" {
pkg = packages.NewPackage(tst.pkg)
}
reg := types.NewRegistry()
if tst.types != nil {
reg = types.NewRegistry(tst.types...)
}
attrs := NewAttributeFactory(pkg, reg, reg)
if tst.attrs != nil {
attrs = tst.attrs
}
// Configure the environment.
env := checker.NewStandardEnv(pkg, reg)
if tst.env != nil {
env.Add(tst.env...)
}
// Configure the program input.
vars := EmptyActivation()
if tst.in != nil {
vars, _ = NewActivation(tst.in)
}
// Adapt the test output, if needed.
if tst.out != nil {
tst.out = reg.NativeToValue(tst.out)
}
disp := NewDispatcher()
disp.Add(functions.StandardOverloads()...)
if tst.funcs != nil {
disp.Add(tst.funcs...)
}
interp := NewInterpreter(disp, pkg, reg, reg, attrs)
// Parse the expression.
s := common.NewTextSource(tst.expr)
parsed, errs := parser.Parse(s)
if len(errs.GetErrors()) != 0 {
return nil, nil, errors.New(errs.ToDisplayString())
}
if tst.unchecked {
// Build the program plan.
prg, err := interp.NewUncheckedInterpretable(
parsed.GetExpr(), opts...)
if err != nil {
return nil, nil, err
}
return prg, vars, nil
}
// Check the expression.
checked, errs := checker.Check(parsed, s, env)
if len(errs.GetErrors()) != 0 {
return nil, nil, errors.New(errs.ToDisplayString())
}
// Build the program plan.
prg, err := interp.NewInterpretable(checked, opts...)
if err != nil {
return nil, nil, err
}
return prg, vars, nil
}