blob: 18db3898ddd93472f5c9894b683dbaf8aba13fc2 [file] [log] [blame]
// Copyright 2023 The LUCI Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proto
import (
"encoding/binary"
"errors"
"fmt"
"hash"
"math"
"sort"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/known/anypb"
)
const anyName = "google.protobuf.Any"
// StableHash walks the proto message `m` in a deterministic, ordered fashion
// and feeds every populated field into `h`.
//
// The resulting hash is stable, which makes it suitable for comparing messages
// that were produced by different sources or runtimes.
//
// For that reason `m` must not contain unknown fields: without the message
// definition there is no way to canonicalize them. If any unknown fields are
// present, this function returns an error.
//
// NOTE:
//   - The hash may ONLY be used to tell whether messages differ across
//     backward compatible protobuf changes. Hashes of messages built from
//     incompatible protobuf specs MUST NOT be compared with each other; they
//     are not guaranteed to differ even when the messages clearly do.
//   - google.protobuf.Any is supported via Any.UnmarshalNew: if the packed
//     message type is not registered, this returns an error.
//   - 0-valued scalar fields are not distinguished from absent fields.
func StableHash(h hash.Hash, m proto.Message) error {
	reflected := m.ProtoReflect()
	return hashMessage(h, reflected)
}
// hashValue feeds a single protoreflect value into `h`.
//
// Numeric kinds (integers, floats via their IEEE 754 bit patterns, enum
// numbers and bools as 0/1) are widened to uint64 and written with hashNumber.
// Strings and byte slices go through hashBytes, while messages, lists and maps
// are delegated to their dedicated hashers.
//
// Returns an error for any value whose Go type is not handled here, or when a
// nested hasher fails.
func hashValue(h hash.Hash, v protoreflect.Value) error {
	// Every numeric case only fills in `num`; it is written once after the
	// switch. Non-numeric cases return directly.
	var num uint64
	switch x := v.Interface().(type) {
	case bool:
		if x {
			num = 1
		}
	case int32:
		num = uint64(x)
	case int64:
		num = uint64(x)
	case uint32:
		num = uint64(x)
	case uint64:
		num = x
	case float32:
		num = uint64(math.Float32bits(x))
	case float64:
		num = math.Float64bits(x)
	case protoreflect.EnumNumber:
		num = uint64(x)
	case string:
		return hashBytes(h, []byte(x))
	case []byte:
		return hashBytes(h, x)
	case protoreflect.Message:
		return hashMessage(h, x)
	case protoreflect.List:
		return hashList(h, x)
	case protoreflect.Map:
		return hashMap(h, x)
	default:
		return fmt.Errorf("unknown type: %T", x)
	}
	return hashNumber(h, num)
}
// hashMessage feeds the populated fields of `m` into `h` in ascending field
// number order, writing each field's number followed by its value.
//
// A google.protobuf.Any is never hashed directly: it is unpacked with
// Any.UnmarshalNew and the contained message is hashed in its place, so the
// type URL itself does not contribute to the hash.
//
// Returns an error if `m` carries unknown fields, if an Any cannot be unpacked
// (or is not backed by *anypb.Any), or if hashing any field fails.
func hashMessage(h hash.Hash, m protoreflect.Message) error {
	if m.Descriptor().FullName() == anyName {
		// A message described as Any but backed by some other Go type cannot
		// use UnmarshalNew; report that as an error instead of panicking on
		// the type assertion.
		packed, ok := m.Interface().(*anypb.Any)
		if !ok {
			return fmt.Errorf("unsupported %s implementation: %T", anyName, m.Interface())
		}
		inner, err := packed.UnmarshalNew()
		if err != nil {
			return err
		}
		return hashMessage(h, inner.ProtoReflect())
	}

	// Check the length rather than comparing against nil: an empty but non-nil
	// RawFields holds no unknown fields and must not be rejected.
	if len(m.GetUnknown()) > 0 {
		return errors.New("unknown fields cannot be hashed")
	}

	// Collect the populated fields; Range does not promise any order, so sort
	// them by field number to keep the walk deterministic.
	var fds []protoreflect.FieldDescriptor
	m.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
		fds = append(fds, fd)
		return true
	})
	sort.Slice(fds, func(i, j int) bool { return fds[i].Number() < fds[j].Number() })

	// Hash each field as (number, value).
	for _, fd := range fds {
		if err := hashNumber(h, uint64(fd.Number())); err != nil {
			return err
		}
		if err := hashValue(h, m.Get(fd)); err != nil {
			return err
		}
	}
	return nil
}
// hashList feeds a repeated field into `h`: first the element count, then
// every element in list order.
//
// Returns the first error reported while hashing the count or an element.
func hashList(h hash.Hash, lv protoreflect.List) error {
	count := lv.Len()
	if err := hashNumber(h, uint64(count)); err != nil {
		return err
	}
	for idx := 0; idx < count; idx++ {
		elem := lv.Get(idx)
		if err := hashValue(h, elem); err != nil {
			return err
		}
	}
	return nil
}
// hashMap feeds a map field into `h`: first the entry count, then each entry
// as (key, value), visited in ascending key order (false before true for bool
// keys).
//
// Returns an error if the key type is not bool, a 32/64-bit integer or a
// string, or if hashing the count, a key or a value fails.
func hashMap(h hash.Hash, mv protoreflect.Map) error {
	size := mv.Len()
	if err := hashNumber(h, uint64(size)); err != nil {
		return err
	}
	if size == 0 {
		return nil
	}

	// Gather every key so the entries can be visited in a sorted order.
	keys := make([]protoreflect.MapKey, 0, size)
	mv.Range(func(key protoreflect.MapKey, _ protoreflect.Value) bool {
		keys = append(keys, key)
		return true
	})

	// Pick the comparator from the Go type of the first key.
	var less func(a, b int) bool
	switch keys[0].Interface().(type) {
	case string:
		less = func(a, b int) bool { return keys[a].String() < keys[b].String() }
	case uint32, uint64:
		less = func(a, b int) bool { return keys[a].Uint() < keys[b].Uint() }
	case int32, int64:
		less = func(a, b int) bool { return keys[a].Int() < keys[b].Int() }
	case bool:
		less = func(a, b int) bool { return keys[b].Bool() && !keys[a].Bool() }
	default:
		return errors.New("invalid map key type")
	}
	sort.Slice(keys, less)

	// Hash each entry as key followed by value.
	for _, key := range keys {
		if err := hashValue(h, key.Value()); err != nil {
			return err
		}
		if err := hashValue(h, mv.Get(key)); err != nil {
			return err
		}
	}
	return nil
}
// hashNumber writes `v` into `h` as a fixed-width, 8-byte little-endian value.
//
// Returns whatever error the Write call on `h` reports.
func hashNumber(h hash.Hash, v uint64) error {
	encoded := make([]byte, 8)
	binary.LittleEndian.PutUint64(encoded, v)
	_, err := h.Write(encoded)
	return err
}
// hashBytes feeds a byte string into `h`, prefixed with its length so that
// adjacent byte strings cannot run together ambiguously.
//
// Returns the first error reported while writing the length or the payload.
func hashBytes(h hash.Hash, v []byte) error {
	err := hashNumber(h, uint64(len(v)))
	if err == nil {
		_, err = h.Write(v)
	}
	return err
}