| // Copyright 2009 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
// Package reflect is a fork of the deep-equality code in Go's standard library
// reflect package. It extends deep equality with user-defined equality functions.
| package reflect |
| |
| import ( |
| "fmt" |
| "reflect" |
| "strings" |
| ) |
| |
// Equalities is a registry of custom equality functions. Each key is a
// type and each value is the reflected form of a function that compares
// two values of that type.
type Equalities map[reflect.Type]reflect.Value
| |
| // EqualitiesOrDie adds the given funcs and panics on any error. |
| func EqualitiesOrDie(funcs ...interface{}) Equalities { |
| e := Equalities{} |
| if err := e.AddFuncs(funcs...); err != nil { |
| panic(err) |
| } |
| return e |
| } |
| |
| // AddFuncs is a shortcut for multiple calls to AddFunc. |
| func (e Equalities) AddFuncs(funcs ...interface{}) error { |
| for _, f := range funcs { |
| if err := e.AddFunc(f); err != nil { |
| return err |
| } |
| } |
| return nil |
| } |
| |
| // AddFunc uses func as an equality function: it must take |
| // two parameters of the same type, and return a boolean. |
| func (e Equalities) AddFunc(eqFunc interface{}) error { |
| fv := reflect.ValueOf(eqFunc) |
| ft := fv.Type() |
| if ft.Kind() != reflect.Func { |
| return fmt.Errorf("expected func, got: %v", ft) |
| } |
| if ft.NumIn() != 2 { |
| return fmt.Errorf("expected two 'in' params, got: %v", ft) |
| } |
| if ft.NumOut() != 1 { |
| return fmt.Errorf("expected one 'out' param, got: %v", ft) |
| } |
| if ft.In(0) != ft.In(1) { |
| return fmt.Errorf("expected arg 1 and 2 to have same type, but got %v", ft) |
| } |
| var forReturnType bool |
| boolType := reflect.TypeOf(forReturnType) |
| if ft.Out(0) != boolType { |
| return fmt.Errorf("expected bool return, got: %v", ft) |
| } |
| e[ft.In(0)] = fv |
| return nil |
| } |
| |
| // Below here is forked from go's reflect/deepequal.go |
| |
// visit identifies one comparison that is currently in progress: a pair of
// addresses together with the type being compared at those addresses.
//
// The deep comparison functions record every such pair in a map keyed by
// visit. When a pair is encountered again, the comparison assumes it is
// true, which is what terminates recursion on self-referential values.
type visit struct {
	a1, a2 uintptr
	typ    reflect.Type
}
| |
// unexportedTypePanic is the panic value raised when the deep comparison
// reaches a value it is not allowed to read (an unexported field). The slice
// holds the chain of types leading to that value, outermost first. It signals
// a programmer error rather than a runtime condition, which is why the type
// is unexported and therefore cannot be matched by callers.
type unexportedTypePanic []reflect.Type

// Error makes unexportedTypePanic satisfy the error interface; it returns
// the same text as String.
func (u unexportedTypePanic) Error() string { return u.String() }

// String renders the recorded type chain as "T1 -> T2 -> ..." after a fixed
// explanatory prefix.
func (u unexportedTypePanic) String() string {
	names := make([]string, 0, len(u))
	for _, typ := range u {
		names = append(names, fmt.Sprint(typ))
	}
	chain := strings.Join(names, " -> ")
	return "an unexported field was encountered, nested like this: " + chain
}
| |
| func makeUsefulPanic(v reflect.Value) { |
| if x := recover(); x != nil { |
| if u, ok := x.(unexportedTypePanic); ok { |
| u = append(unexportedTypePanic{v.Type()}, u...) |
| x = u |
| } |
| panic(x) |
| } |
| } |
| |
| // deepValueEqual tests for deep equality using reflected types. The map argument tracks |
| // comparisons that have already been seen, which allows short circuiting on |
| // recursive types. |
| func (e Equalities) deepValueEqual(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool { |
| defer makeUsefulPanic(v1) |
| |
| if !v1.IsValid() || !v2.IsValid() { |
| return v1.IsValid() == v2.IsValid() |
| } |
| if v1.Type() != v2.Type() { |
| return false |
| } |
| if fv, ok := e[v1.Type()]; ok { |
| return fv.Call([]reflect.Value{v1, v2})[0].Bool() |
| } |
| if v1.CanAddr() { |
| if fv, ok := e[v1.Addr().Type()]; ok { |
| return fv.Call([]reflect.Value{v1.Addr(), v2.Addr()})[0].Bool() |
| } |
| } |
| |
| hard := func(k reflect.Kind) bool { |
| switch k { |
| case reflect.Array, reflect.Map, reflect.Slice, reflect.Struct: |
| return true |
| } |
| return false |
| } |
| |
| if v1.CanAddr() && v2.CanAddr() && hard(v1.Kind()) { |
| addr1 := v1.UnsafeAddr() |
| addr2 := v2.UnsafeAddr() |
| if addr1 > addr2 { |
| // Canonicalize order to reduce number of entries in visited. |
| addr1, addr2 = addr2, addr1 |
| } |
| |
| // Short circuit if references are identical ... |
| if addr1 == addr2 { |
| return true |
| } |
| |
| // ... or already seen |
| typ := v1.Type() |
| v := visit{addr1, addr2, typ} |
| if visited[v] { |
| return true |
| } |
| |
| // Remember for later. |
| visited[v] = true |
| } |
| |
| switch v1.Kind() { |
| case reflect.Array: |
| // We don't need to check length here because length is part of |
| // an array's type, which has already been filtered for. |
| for i := 0; i < v1.Len(); i++ { |
| if !e.deepValueEqual(v1.Index(i), v2.Index(i), visited, depth+1) { |
| return false |
| } |
| } |
| return true |
| case reflect.Slice: |
| if (v1.IsNil() || v1.Len() == 0) != (v2.IsNil() || v2.Len() == 0) { |
| return false |
| } |
| if v1.IsNil() || v1.Len() == 0 { |
| return true |
| } |
| if v1.Len() != v2.Len() { |
| return false |
| } |
| if v1.Pointer() == v2.Pointer() { |
| return true |
| } |
| for i := 0; i < v1.Len(); i++ { |
| if !e.deepValueEqual(v1.Index(i), v2.Index(i), visited, depth+1) { |
| return false |
| } |
| } |
| return true |
| case reflect.Interface: |
| if v1.IsNil() || v2.IsNil() { |
| return v1.IsNil() == v2.IsNil() |
| } |
| return e.deepValueEqual(v1.Elem(), v2.Elem(), visited, depth+1) |
| case reflect.Ptr: |
| return e.deepValueEqual(v1.Elem(), v2.Elem(), visited, depth+1) |
| case reflect.Struct: |
| for i, n := 0, v1.NumField(); i < n; i++ { |
| if !e.deepValueEqual(v1.Field(i), v2.Field(i), visited, depth+1) { |
| return false |
| } |
| } |
| return true |
| case reflect.Map: |
| if (v1.IsNil() || v1.Len() == 0) != (v2.IsNil() || v2.Len() == 0) { |
| return false |
| } |
| if v1.IsNil() || v1.Len() == 0 { |
| return true |
| } |
| if v1.Len() != v2.Len() { |
| return false |
| } |
| if v1.Pointer() == v2.Pointer() { |
| return true |
| } |
| for _, k := range v1.MapKeys() { |
| if !e.deepValueEqual(v1.MapIndex(k), v2.MapIndex(k), visited, depth+1) { |
| return false |
| } |
| } |
| return true |
| case reflect.Func: |
| if v1.IsNil() && v2.IsNil() { |
| return true |
| } |
| // Can't do better than this: |
| return false |
| default: |
| // Normal equality suffices |
| if !v1.CanInterface() || !v2.CanInterface() { |
| panic(unexportedTypePanic{}) |
| } |
| return v1.Interface() == v2.Interface() |
| } |
| } |
| |
| // DeepEqual is like reflect.DeepEqual, but focused on semantic equality |
| // instead of memory equality. |
| // |
| // It will use e's equality functions if it finds types that match. |
| // |
| // An empty slice *is* equal to a nil slice for our purposes; same for maps. |
| // |
| // Unexported field members cannot be compared and will cause an informative panic; you must add an Equality |
| // function for these types. |
| func (e Equalities) DeepEqual(a1, a2 interface{}) bool { |
| if a1 == nil || a2 == nil { |
| return a1 == a2 |
| } |
| v1 := reflect.ValueOf(a1) |
| v2 := reflect.ValueOf(a2) |
| if v1.Type() != v2.Type() { |
| return false |
| } |
| return e.deepValueEqual(v1, v2, make(map[visit]bool), 0) |
| } |
| |
| func (e Equalities) deepValueDerive(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool { |
| defer makeUsefulPanic(v1) |
| |
| if !v1.IsValid() || !v2.IsValid() { |
| return v1.IsValid() == v2.IsValid() |
| } |
| if v1.Type() != v2.Type() { |
| return false |
| } |
| if fv, ok := e[v1.Type()]; ok { |
| return fv.Call([]reflect.Value{v1, v2})[0].Bool() |
| } |
| if v1.CanAddr() { |
| if fv, ok := e[v1.Addr().Type()]; ok { |
| return fv.Call([]reflect.Value{v1.Addr(), v2.Addr()})[0].Bool() |
| } |
| } |
| |
| hard := func(k reflect.Kind) bool { |
| switch k { |
| case reflect.Array, reflect.Map, reflect.Slice, reflect.Struct: |
| return true |
| } |
| return false |
| } |
| |
| if v1.CanAddr() && v2.CanAddr() && hard(v1.Kind()) { |
| addr1 := v1.UnsafeAddr() |
| addr2 := v2.UnsafeAddr() |
| if addr1 > addr2 { |
| // Canonicalize order to reduce number of entries in visited. |
| addr1, addr2 = addr2, addr1 |
| } |
| |
| // Short circuit if references are identical ... |
| if addr1 == addr2 { |
| return true |
| } |
| |
| // ... or already seen |
| typ := v1.Type() |
| v := visit{addr1, addr2, typ} |
| if visited[v] { |
| return true |
| } |
| |
| // Remember for later. |
| visited[v] = true |
| } |
| |
| switch v1.Kind() { |
| case reflect.Array: |
| // We don't need to check length here because length is part of |
| // an array's type, which has already been filtered for. |
| for i := 0; i < v1.Len(); i++ { |
| if !e.deepValueDerive(v1.Index(i), v2.Index(i), visited, depth+1) { |
| return false |
| } |
| } |
| return true |
| case reflect.Slice: |
| if v1.IsNil() || v1.Len() == 0 { |
| return true |
| } |
| if v1.Len() > v2.Len() { |
| return false |
| } |
| if v1.Pointer() == v2.Pointer() { |
| return true |
| } |
| for i := 0; i < v1.Len(); i++ { |
| if !e.deepValueDerive(v1.Index(i), v2.Index(i), visited, depth+1) { |
| return false |
| } |
| } |
| return true |
| case reflect.String: |
| if v1.Len() == 0 { |
| return true |
| } |
| if v1.Len() > v2.Len() { |
| return false |
| } |
| return v1.String() == v2.String() |
| case reflect.Interface: |
| if v1.IsNil() { |
| return true |
| } |
| return e.deepValueDerive(v1.Elem(), v2.Elem(), visited, depth+1) |
| case reflect.Ptr: |
| if v1.IsNil() { |
| return true |
| } |
| return e.deepValueDerive(v1.Elem(), v2.Elem(), visited, depth+1) |
| case reflect.Struct: |
| for i, n := 0, v1.NumField(); i < n; i++ { |
| if !e.deepValueDerive(v1.Field(i), v2.Field(i), visited, depth+1) { |
| return false |
| } |
| } |
| return true |
| case reflect.Map: |
| if v1.IsNil() || v1.Len() == 0 { |
| return true |
| } |
| if v1.Len() > v2.Len() { |
| return false |
| } |
| if v1.Pointer() == v2.Pointer() { |
| return true |
| } |
| for _, k := range v1.MapKeys() { |
| if !e.deepValueDerive(v1.MapIndex(k), v2.MapIndex(k), visited, depth+1) { |
| return false |
| } |
| } |
| return true |
| case reflect.Func: |
| if v1.IsNil() && v2.IsNil() { |
| return true |
| } |
| // Can't do better than this: |
| return false |
| default: |
| // Normal equality suffices |
| if !v1.CanInterface() || !v2.CanInterface() { |
| panic(unexportedTypePanic{}) |
| } |
| return v1.Interface() == v2.Interface() |
| } |
| } |
| |
| // DeepDerivative is similar to DeepEqual except that unset fields in a1 are |
| // ignored (not compared). This allows us to focus on the fields that matter to |
| // the semantic comparison. |
| // |
| // The unset fields include a nil pointer and an empty string. |
| func (e Equalities) DeepDerivative(a1, a2 interface{}) bool { |
| if a1 == nil { |
| return true |
| } |
| v1 := reflect.ValueOf(a1) |
| v2 := reflect.ValueOf(a2) |
| if v1.Type() != v2.Type() { |
| return false |
| } |
| return e.deepValueDerive(v1, v2, make(map[visit]bool), 0) |
| } |