| // Copyright 2020 The LUCI Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| // Package mask provides utility functions for google protobuf field mask |
| // |
| // Supports advanced field mask semantics: |
| // - Refer to fields and map keys using . literals: |
| // - Supported map key types: string, integer, bool. (double, float, enum, |
| // and bytes keys are not supported by protobuf or this implementation) |
| // - Fields: "publisher.name" means field "name" of field "publisher" |
| // - String map keys: "metadata.year" means string key 'year' of map field |
| // metadata |
| // - Integer map keys (e.g. int32): 'year_ratings.0' means integer key 0 of |
| // a map field year_ratings |
| // - Bool map keys: 'access_text.true' means boolean key true of a map field |
| // access_text |
| // - String map keys that cannot be represented as an unquoted string literal, |
| // must be quoted using backticks: metadata.`year.published`, metadata.`17`, |
| // metadata.``. Backtick can be escaped with ``: a.`b``c` means map key "b`c" |
| // of map field a. |
| // - Refer to all map keys using a * literal: "topics.*.archived" means field |
| // "archived" of all map values of map field "topic". |
| // - Refer to all elements of a repeated field using a * literal: authors.*.name |
| // - Refer to all fields of a message using * literal: publisher.*. |
| // - Prohibit addressing a single element in repeated fields: authors.0.name |
| // |
| // FieldMask.paths string grammar: |
| // path = segment {'.' segment} |
| // segment = literal | star | quoted_string; |
| // literal = string | integer | boolean |
| // string = (letter | '_') {letter | '_' | digit} |
| // integer = ['-'] digit {digit}; |
| // boolean = 'true' | 'false'; |
| // quoted_string = '`' { utf8-no-backtick | '``' } '`' |
| // star = '*' |
| package mask |
| |
| import ( |
| "fmt" |
| "sort" |
| "strings" |
| |
| "github.com/golang/protobuf/proto" |
| "google.golang.org/genproto/protobuf/field_mask" |
| protoV2 "google.golang.org/protobuf/proto" |
| "google.golang.org/protobuf/reflect/protoreflect" |
| |
| "go.chromium.org/luci/common/data/stringset" |
| ) |
| |
| // Mask is a tree representation of a field Mask. Serves as a tree node too. |
| // Each node represents a segment of a path, e.g. "bar" in "foo.bar.qux". |
| // A Field Mask with paths ["a","b.c"] is parsed as |
| // <root> |
| // / \ |
| // "a" "b" |
| // / |
| // "c" |
| // |
| // Zero value is not valid. Use IsEmpty() to check if the mask is zero. |
| type Mask struct { |
| // descriptor is the proto descriptor of the message of the field this node |
| // represents. If the field kind is not a message, then descriptor is nil and |
| // the node must be a leaf unless isRepeated is true which denotes a repeated |
| // scalar field. |
| descriptor protoreflect.MessageDescriptor |
| // isRepeated indicates whether the segment represents a repeated field or |
| // not. Children of this node are the field elements. |
| isRepeated bool |
| // children maps segments to its node. e.g. children of the root in the |
| // example above has keys "a" and "b", and values are Mask objects and the |
| // Mask object "b" maps to will have a single child "c". All types of segment |
| // (i.e. int, bool, string, star) will be converted to string. |
| children map[string]*Mask |
| } |
| |
| // FromFieldMask parses a field mask to a mask. |
| // |
| // Trailing stars will be removed, e.g. parses ['a.*'] as ['a']. |
| // Redundant paths will be removed, e.g. parses ['a', 'a.b'] as ['a']. |
| // |
| // If isFieldNameJSON is set to true, json name will be used instead of |
| // canonical name defined in proto during parsing (e.g. "fooBar" instead of |
| // "foo_bar"). However, the child field name in return mask will always be |
| // in canonical form. |
| // |
| // If isUpdateMask is set to true, a repeated field is allowed only as the last |
| // field in a path string. |
| func FromFieldMask(fieldMask *field_mask.FieldMask, targetMsg proto.Message, isFieldNameJSON bool, isUpdateMask bool) (*Mask, error) { |
| descriptor := proto.MessageReflect(targetMsg).Descriptor() |
| parsedPaths := make([]path, len(fieldMask.GetPaths())) |
| for i, p := range fieldMask.GetPaths() { |
| parsedPath, err := parsePath(p, descriptor, isFieldNameJSON) |
| if err != nil { |
| return nil, err |
| } |
| parsedPaths[i] = parsedPath |
| } |
| return fromParsedPaths(parsedPaths, descriptor, isUpdateMask) |
| } |
| |
| // MustFromReadMask is a shortcut FromFieldMask with isFieldNameJSON and |
| // isUpdateMask as false, that accepts field mask a variadic paths and |
| // that panics if the mask is invalid. |
| // It is useful when the mask is hardcoded. |
| func MustFromReadMask(targetMsg proto.Message, paths ...string) *Mask { |
| ret, err := FromFieldMask(&field_mask.FieldMask{Paths: paths}, targetMsg, false, false) |
| if err != nil { |
| panic(err) |
| } |
| return ret |
| } |
| |
| // All returns a field mask that selects all fields. |
| func All(targetMsg proto.Message) *Mask { |
| return MustFromReadMask(targetMsg, "*") |
| } |
| |
| // fromParsedPaths constructs a mask tree from a slice of parsed paths. |
| func fromParsedPaths(parsedPaths []path, desc protoreflect.MessageDescriptor, isUpdateMask bool) (*Mask, error) { |
| root := &Mask{ |
| descriptor: desc, |
| children: make(map[string]*Mask), |
| } |
| for _, p := range normalizePaths(parsedPaths) { |
| curNode := root |
| curNodeName := "" |
| for _, seg := range p { |
| if curNode.isRepeated && isUpdateMask { |
| return nil, fmt.Errorf("update mask allows a repeated field only at the last position; field: %s is not last", curNodeName) |
| } |
| if _, ok := curNode.children[seg]; !ok { |
| child := &Mask{ |
| children: make(map[string]*Mask), |
| } |
| switch curDesc := curNode.descriptor; { |
| case curDesc.IsMapEntry(): |
| child.descriptor = curDesc.Fields().ByName(protoreflect.Name("value")).Message() |
| case curNode.isRepeated: |
| child.descriptor = curDesc |
| default: |
| field := curDesc.Fields().ByName(protoreflect.Name(seg)) |
| child.descriptor = field.Message() |
| child.isRepeated = field.Cardinality() == protoreflect.Repeated |
| } |
| curNode.children[seg] = child |
| } |
| curNode = curNode.children[seg] |
| curNodeName = seg |
| } |
| } |
| return root, nil |
| } |
| |
| // normalizePaths normalizes parsed paths. Returns a new slice of paths. |
| // |
| // Removes trailing stars for all paths, e.g. converts ["a", "*"] to ["a"]. |
| // Removes paths that have a segment prefix already present in paths, |
| // e.g. removes ["a", "b"] from [["a", "b"], ["a",]]. |
| // |
| // The result slice is stable and ordered by the number of segments of each |
| // path. If two paths have same number of segments, break the tie by comparing |
| // the segments at each index lexicographically. |
| func normalizePaths(paths []path) []path { |
| paths = removeTrailingStars(paths) |
| sort.SliceStable(paths, func(i, j int) bool { |
| lenI, lenJ := len(paths[i]), len(paths[j]) |
| if lenI == lenJ { |
| for index, segI := range paths[i] { |
| if segI == paths[j][index] { |
| continue |
| } |
| return segI < paths[j][index] |
| } |
| return true |
| } |
| return lenI < lenJ |
| }) |
| |
| present := stringset.New(len(paths)) |
| delimiter := string(pathDelimiter) |
| ret := make([]path, 0, len(paths)) |
| PATH_LOOP: |
| for _, p := range paths { |
| for i := range p { |
| if present.Has(strings.Join(p[:i+1], delimiter)) { |
| continue PATH_LOOP |
| } |
| } |
| ret = append(ret, p) |
| present.Add(strings.Join(p, delimiter)) |
| } |
| return ret |
| } |
| |
| func removeTrailingStars(paths []path) []path { |
| ret := make([]path, 0, len(paths)) |
| for _, p := range paths { |
| if n := len(p); n > 0 && p[n-1] == "*" { |
| p = p[:n-1] |
| } |
| ret = append(ret, p) |
| } |
| return ret |
| } |
| |
| // Trim clears protobuf message fields that are not in the mask. |
| // |
| // If mask is empty, this is a noop. It returns error when the supplied |
| // message is nil or has a different message descriptor from that of mask. |
| // It uses Includes to decide what to trim, see its doc. |
| func (m *Mask) Trim(msg proto.Message) error { |
| if m.IsEmpty() { |
| return nil |
| } |
| reflectMsg := proto.MessageReflect(msg) |
| if err := checkMsgHaveDesc(reflectMsg, m.descriptor); err != nil { |
| return err |
| } |
| m.trimImpl(reflectMsg) |
| return nil |
| } |
| |
| func (m *Mask) trimImpl(reflectMsg protoreflect.Message) { |
| reflectMsg.Range(func(fieldDesc protoreflect.FieldDescriptor, fieldVal protoreflect.Value) bool { |
| fieldName := string(fieldDesc.Name()) |
| switch incl, _ := m.includesImpl(path{fieldName}); incl { |
| case Exclude: |
| reflectMsg.Clear(fieldDesc) |
| case IncludePartially: |
| // child for this field must exist because the path is included partially |
| switch child := m.children[fieldName]; { |
| case fieldDesc.IsMap(): |
| child.trimMap(fieldVal.Map(), fieldDesc.MapValue().Kind()) |
| case fieldDesc.Kind() != protoreflect.MessageKind: |
| // The field is scalar but the mask does not specify to include |
| // it entirely. Skip it because scalars do not have subfields. |
| // Note that FromFieldMask would fail on such a mask because a |
| // scalar field cannot be followed by other fields. |
| reflectMsg.Clear(fieldDesc) |
| case fieldDesc.IsList(): |
| // star child is the only possible child for list field |
| if starChild, ok := child.children["*"]; ok { |
| for i, list := 0, fieldVal.List(); i < list.Len(); i++ { |
| starChild.trimImpl(list.Get(i).Message()) |
| } |
| } |
| default: |
| child.trimImpl(fieldVal.Message()) |
| } |
| } |
| return true |
| }) |
| } |
| |
| func (m *Mask) trimMap(protoMap protoreflect.Map, valueKind protoreflect.Kind) { |
| protoMap.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool { |
| keyString := k.String() |
| switch incl, _ := m.includesImpl(path{keyString}); { |
| case incl == Exclude: |
| protoMap.Clear(k) |
| case incl == IncludePartially && valueKind != protoreflect.MessageKind: |
| // same reason as comment above that value is scalar |
| protoMap.Clear(k) |
| case incl == IncludePartially: |
| // Mask might not have a child of keyName but it can still partially |
| // include the key because of star child. So, check both key child |
| // and star child. |
| for _, seg := range []string{keyString, "*"} { |
| if child, ok := m.children[seg]; ok { |
| child.trimImpl(v.Message()) |
| } |
| } |
| } |
| return true |
| }) |
| } |
| |
| // Inclusiveness tells if a field value at the given path is included. |
| type Inclusiveness int8 |
| |
| const ( |
| // Exclude indicates the field value is excluded. |
| Exclude Inclusiveness = iota |
| // IncludePartially indicates some subfields of the field value are included. |
| IncludePartially |
| // IncludeEntirely indicates the entire field value is included. |
| IncludeEntirely |
| ) |
| |
| // Includes tells the Inclusiveness of a field value at the given path. |
| // |
| // The path must have canonical field names, i.e. not JSON names. |
| // Returns error if path parsing fails. |
| func (m *Mask) Includes(path string) (Inclusiveness, error) { |
| parsedPath, err := parsePath(path, m.descriptor, false) |
| if err != nil { |
| return Exclude, err |
| } |
| incl, _ := m.includesImpl(parsedPath) |
| return incl, nil |
| } |
| |
| // MustIncludes tells the Inclusiveness of a field value at the given path. |
| // |
| // This is essentially the same as Includes, but panics if the given path is invalid. |
| func (m *Mask) MustIncludes(path string) Inclusiveness { |
| incl, err := m.Includes(path) |
| if err != nil { |
| panic(fmt.Sprintf("MustIncludes(%q): %s", path, err)) |
| } |
| return incl |
| } |
| |
| // includesImpl implements Includes(). It returns the computed inclusiveness |
| // and the leaf mask that includes the path if IncludeEntirely, or the |
| // intermediate mask that the last segment of path represents if |
| // IncludePartially or an empty mask if Exclude. |
| func (m *Mask) includesImpl(p path) (Inclusiveness, *Mask) { |
| if len(m.children) == 0 { |
| return IncludeEntirely, m |
| } |
| if len(p) == 0 { |
| // This node is intermediate and we've exhausted the path. Some of the |
| // value's subfields are included, so includes this value partially. |
| return IncludePartially, m |
| } |
| |
| var incl Inclusiveness |
| var inclMask *Mask |
| // star child should also be examined. |
| // e.g. children are {"a": {"b": {}}, "*": {"c": {}}} |
| // If seg is 'x', we should check the star child. |
| for _, seg := range []string{p[0], "*"} { |
| if child, ok := m.children[seg]; ok { |
| if cIncl, cInclMask := child.includesImpl(p[1:]); cIncl > incl { |
| incl, inclMask = cIncl, cInclMask |
| } |
| } |
| } |
| return incl, inclMask |
| } |
| |
| // Merge merges masked fields from src to dest. |
| // |
| // If mask is empty, this is a noop. It returns error when one of src or dest |
| // message is nil or has different message descriptor from that of mask. |
| // Empty field will be merged as long as they are present in the mask. Repeated |
| // fields or map fields will be overwritten entirely. Partial updates are not |
| // supported for such field. |
| func (m *Mask) Merge(src, dest proto.Message) error { |
| if m.IsEmpty() { |
| return nil |
| } |
| srcReflectMsg := proto.MessageReflect(src) |
| if err := checkMsgHaveDesc(srcReflectMsg, m.descriptor); err != nil { |
| return fmt.Errorf("src message: %s", err.Error()) |
| } |
| destReflectMsg := proto.MessageReflect(dest) |
| if err := checkMsgHaveDesc(destReflectMsg, m.descriptor); err != nil { |
| return fmt.Errorf("dest message: %s", err.Error()) |
| } |
| m.mergeImpl(srcReflectMsg, destReflectMsg) |
| return nil |
| } |
| |
| func (m *Mask) mergeImpl(src, dest protoreflect.Message) { |
| for seg, submask := range m.children { |
| // star field is not supported for update mask so this won't be nil |
| fieldDesc := m.descriptor.Fields().ByName(protoreflect.Name(seg)) |
| switch srcVal, kind := src.Get(fieldDesc), fieldDesc.Kind(); { |
| case fieldDesc.IsList(): |
| newField := dest.NewField(fieldDesc) |
| srcList, destList := srcVal.List(), newField.List() |
| for i := 0; i < srcList.Len(); i++ { |
| destList.Append(cloneValue(srcList.Get(i), kind)) |
| } |
| dest.Set(fieldDesc, newField) |
| |
| case fieldDesc.IsMap(): |
| newField := dest.NewField(fieldDesc) |
| destMap := newField.Map() |
| mapValKind := fieldDesc.MapValue().Kind() |
| srcVal.Map().Range(func(k protoreflect.MapKey, v protoreflect.Value) bool { |
| destMap.Set(k, cloneValue(v, mapValKind)) |
| return true |
| }) |
| dest.Set(fieldDesc, newField) |
| |
| case fieldDesc.Kind() == protoreflect.MessageKind: |
| switch srcNil := !srcVal.Message().IsValid(); { |
| case srcNil && !dest.Get(fieldDesc).Message().IsValid(): |
| // dest is also nil message. No need to proceed merging. |
| case srcNil && len(submask.children) == 0: |
| dest.Clear(fieldDesc) |
| case len(submask.children) == 0: |
| dest.Set(fieldDesc, cloneValue(srcVal, fieldDesc.Kind())) |
| default: |
| // only singular message field can be merged partially |
| submask.mergeImpl(srcVal.Message(), dest.Mutable(fieldDesc).Message()) |
| } |
| |
| default: |
| dest.Set(fieldDesc, srcVal) // scalar value |
| } |
| } |
| } |
| |
| // cloneValue returns a cloned value for message kind or the same instance of |
| // input value for all the other kinds (i.e. scalar). List and map value are not |
| // expected as they have been explicitly handled in mergeImpl. |
| func cloneValue(v protoreflect.Value, kind protoreflect.Kind) protoreflect.Value { |
| if kind == protoreflect.MessageKind { |
| clonedMsg := protoV2.Clone(v.Message().Interface()).ProtoReflect() |
| return protoreflect.ValueOf(clonedMsg) |
| } |
| return v |
| } |
| |
| // Submask returns a sub-mask given a path from the received mask to it. |
| // |
| // For example, for a mask ["a.b.c"], m.submask("a.b") will return a mask ["c"]. |
| // |
| // If the received mask includes the path entirely, returns a Mask that includes |
| // everything. For example, for mask ["a"], m.submask("a.b") returns a mask |
| // without children. |
| // |
| // Returns error if path parsing fails or path is excluded from the received |
| // mask. |
| func (m *Mask) Submask(path string) (*Mask, error) { |
| ctx := &parseCtx{ |
| curDescriptor: m.descriptor, |
| isList: m.isRepeated && !(m.descriptor != nil && m.descriptor.IsMapEntry()), |
| } |
| |
| parsedPath, err := parsePathWithContext(path, ctx, false) |
| if err != nil { |
| return nil, err |
| } |
| switch incl, inclMask := m.includesImpl(parsedPath); incl { |
| case IncludeEntirely: |
| return &Mask{ |
| descriptor: ctx.curDescriptor, |
| isRepeated: ctx.isList || (ctx.curDescriptor != nil && ctx.curDescriptor.IsMapEntry()), |
| }, nil |
| case Exclude: |
| return nil, fmt.Errorf("the given path %q is excluded from mask", path) |
| case IncludePartially: |
| return inclMask, nil |
| default: |
| return nil, fmt.Errorf("unknown Inclusiveness: %d", incl) |
| } |
| } |
| |
| // MustSubmask returns a sub-mask given a path from the received mask to it. |
| // |
| // This is essentially the same as Submask, but panics if the given path is invalid or |
| // exlcuded from the received mask. |
| func (m *Mask) MustSubmask(path string) *Mask { |
| sm, err := m.Submask(path) |
| if err != nil { |
| panic(fmt.Sprintf("MustSubmask(%q): %s", path, err)) |
| } |
| return sm |
| } |
| |
| // IsEmpty reports whether a mask is of empty value. Such mask implies keeping |
| // everything when calling Trim, merging nothing when calling Merge and always |
| // returning IncludeEntirely when calling Includes |
| func (m *Mask) IsEmpty() bool { |
| return m == nil || (m.descriptor == nil && !m.isRepeated && len(m.children) == 0) |
| } |
| |
| // checkMsgHaveDesc validates that the descriptor of given proto reflect message |
| // matches the expected message descriptor. It returns error when the given |
| // message is nil or descriptor of which doesn't match the expectation. |
| func checkMsgHaveDesc(msg protoreflect.Message, expectedDesc protoreflect.MessageDescriptor) error { |
| if msg == nil { |
| return fmt.Errorf("nil message") |
| } |
| if msgDesc := msg.Descriptor(); msgDesc != expectedDesc { |
| return fmt.Errorf("expected message have descriptor: %s; got descriptor: %s", expectedDesc.FullName(), msgDesc.FullName()) |
| } |
| return nil |
| } |