| // Copyright 2010 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package proto |
| |
| /* |
| * Support for message sets. |
| */ |
| |
| import ( |
| "errors" |
| |
| "github.com/golang/protobuf/v2/reflect/protoreflect" |
| ) |
| |
| // errNoMessageTypeID occurs when a protocol buffer does not have a message type ID. |
| // A message type ID is required for storing a protocol buffer in a message set. |
| var errNoMessageTypeID = errors.New("proto does not have a message type ID") |
| |
| // The first two types (_MessageSet_Item and messageSet) |
| // model what the protocol compiler produces for the following protocol message: |
| // message MessageSet { |
| // repeated group Item = 1 { |
| // required int32 type_id = 2; |
| // required string message = 3; |
| // }; |
| // } |
| // That is the MessageSet wire format. We can't use a proto to generate these |
| // because that would introduce a circular dependency between it and this package. |
| |
| type _MessageSet_Item struct { |
| TypeId *int32 `protobuf:"varint,2,req,name=type_id"` |
| Message []byte `protobuf:"bytes,3,req,name=message"` |
| } |
| |
| type messageSet struct { |
| Item []*_MessageSet_Item `protobuf:"group,1,rep"` |
| XXX_unrecognized []byte |
| // TODO: caching? |
| } |
| |
| // Make sure messageSet is a Message. |
| var _ Message = (*messageSet)(nil) |
| |
| // messageTypeIder is an interface satisfied by a protocol buffer type |
| // that may be stored in a MessageSet. |
| type messageTypeIder interface { |
| MessageTypeId() int32 |
| } |
| |
| func (ms *messageSet) find(pb Message) *_MessageSet_Item { |
| mti, ok := pb.(messageTypeIder) |
| if !ok { |
| return nil |
| } |
| id := mti.MessageTypeId() |
| for _, item := range ms.Item { |
| if *item.TypeId == id { |
| return item |
| } |
| } |
| return nil |
| } |
| |
| func (ms *messageSet) Has(pb Message) bool { |
| return ms.find(pb) != nil |
| } |
| |
| func (ms *messageSet) Unmarshal(pb Message) error { |
| if item := ms.find(pb); item != nil { |
| return Unmarshal(item.Message, pb) |
| } |
| if _, ok := pb.(messageTypeIder); !ok { |
| return errNoMessageTypeID |
| } |
| return nil // TODO: return error instead? |
| } |
| |
| func (ms *messageSet) Marshal(pb Message) error { |
| msg, err := Marshal(pb) |
| if err != nil { |
| return err |
| } |
| if item := ms.find(pb); item != nil { |
| // reuse existing item |
| item.Message = msg |
| return nil |
| } |
| |
| mti, ok := pb.(messageTypeIder) |
| if !ok { |
| return errNoMessageTypeID |
| } |
| |
| mtid := mti.MessageTypeId() |
| ms.Item = append(ms.Item, &_MessageSet_Item{ |
| TypeId: &mtid, |
| Message: msg, |
| }) |
| return nil |
| } |
| |
| func (ms *messageSet) Reset() { *ms = messageSet{} } |
| func (ms *messageSet) String() string { return CompactTextString(ms) } |
| func (*messageSet) ProtoMessage() {} |
| |
| // Support for the message_set_wire_format message option. |
| |
| func skipVarint(buf []byte) []byte { |
| i := 0 |
| for ; buf[i]&0x80 != 0; i++ { |
| } |
| return buf[i+1:] |
| } |
| |
| // unmarshalMessageSet decodes the extension map encoded in buf in the message set wire format. |
| // It is called by Unmarshal methods on protocol buffer messages with the message_set_wire_format option. |
| func unmarshalMessageSet(buf []byte, exts interface{}) error { |
| m := extensionFieldsOf(exts) |
| |
| ms := new(messageSet) |
| if err := Unmarshal(buf, ms); err != nil { |
| return err |
| } |
| for _, item := range ms.Item { |
| id := protoreflect.FieldNumber(*item.TypeId) |
| msg := item.Message |
| |
| // Restore wire type and field number varint, plus length varint. |
| // Be careful to preserve duplicate items. |
| b := EncodeVarint(uint64(id)<<3 | WireBytes) |
| if m.Has(id) { |
| ext := m.Get(id) |
| |
| // Existing data; rip off the tag and length varint |
| // so we join the new data correctly. |
| // We can assume that ext.Raw is set because we are unmarshaling. |
| o := ext.Raw[len(b):] // skip wire type and field number |
| _, n := DecodeVarint(o) // calculate length of length varint |
| o = o[n:] // skip length varint |
| msg = append(o, msg...) // join old data and new data |
| } |
| b = append(b, EncodeVarint(uint64(len(msg)))...) |
| b = append(b, msg...) |
| |
| m.Set(id, Extension{Raw: b}) |
| } |
| return nil |
| } |