diff --git a/fmutils.go b/fmutils.go index 81e08b3..2f5d101 100644 --- a/fmutils.go +++ b/fmutils.go @@ -183,14 +183,22 @@ func (mask NestedMask) overwrite(srcRft, destRft protoreflect.Message) { } else if srcFD.IsMap() && srcFD.Kind() == protoreflect.MessageKind { srcMap := srcRft.Get(srcFD).Map() destMap := destRft.Get(srcFD).Map() + if !destMap.IsValid() { + destRft.Set(srcFD, protoreflect.ValueOf(srcMap)) + destMap = destRft.Get(srcFD).Map() + } srcMap.Range(func(mk protoreflect.MapKey, mv protoreflect.Value) bool { if mi, ok := submask[mk.String()]; ok { if i, ok := mv.Interface().(protoreflect.Message); ok && len(mi) > 0 { - destMap.Set(mk, mv) - mi.overwrite(i, mv.Message()) + newVal := protoreflect.ValueOf(i.New()) + destMap.Set(mk, newVal) + mi.overwrite(mv.Message(), newVal.Message()) } else { + destMap.Set(mk, mv) } + } else { + destMap.Clear(mk) } return true }) diff --git a/fmutils_test.go b/fmutils_test.go index e655687..3953d5f 100644 --- a/fmutils_test.go +++ b/fmutils_test.go @@ -885,7 +885,7 @@ func TestOverwrite(t *testing.T) { }, { name: "overwrite map with message values", - paths: []string{"attributes.src1.tags", "attributes.src2.tags"}, + paths: []string{"attributes.src1.tags.key1", "attributes.src2"}, src: &testproto.Profile{ User: nil, Attributes: map[string]*testproto.Attribute{ @@ -903,7 +903,7 @@ func TestOverwrite(t *testing.T) { }, Attributes: map[string]*testproto.Attribute{ "dest1": { - Tags: map[string]string{"key4": "value5"}, + Tags: map[string]string{"key4": "value4"}, }, }, }, @@ -913,13 +913,13 @@ func TestOverwrite(t *testing.T) { }, Attributes: map[string]*testproto.Attribute{ "src1": { - Tags: map[string]string{"key1": "value1", "key2": "value2"}, + Tags: map[string]string{"key1": "value1"}, }, "src2": { Tags: map[string]string{"key3": "value3"}, }, "dest1": { - Tags: map[string]string{"key4": "value5"}, + Tags: map[string]string{"key4": "value4"}, }, }, }, @@ -1037,6 +1037,63 @@ func TestOverwrite(t *testing.T) { }, }, }, + { + name: "overwrite repeated message fields to empty list", + paths: []string{"gallery.path"}, + src: &testproto.Profile{ + User: &testproto.User{ + UserId: 567, + Name: "different-name", + }, + Photo: &testproto.Photo{ + Path: "photo-path", + }, + LoginTimestamps: []int64{1, 2, 3}, + Attributes: map[string]*testproto.Attribute{ + "src": {}, + }, + Gallery: []*testproto.Photo{ + { + PhotoId: 123, + Path: "test-path-1", + Dimensions: &testproto.Dimensions{ + Width: 345, + Height: 456, + }, + }, + { + PhotoId: 234, + Path: "test-path-2", + Dimensions: &testproto.Dimensions{ + Width: 3456, + Height: 4567, + }, + }, + { + PhotoId: 345, + Path: "test-path-3", + Dimensions: &testproto.Dimensions{ + Width: 34567, + Height: 45678, + }, + }, + }, + }, + dest: &testproto.Profile{}, + want: &testproto.Profile{ + Gallery: []*testproto.Photo{ + { + Path: "test-path-1", + }, + { + Path: "test-path-2", + }, + { + Path: "test-path-3", + }, + }, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {