From 8961be9c211ef8fb82db75119e05ebe43e3f993a Mon Sep 17 00:00:00 2001 From: Sai To Yeung Date: Thu, 7 May 2020 22:21:59 -0400 Subject: [PATCH] Map keys of custom types should serialize using MarshalText when available (#461) * Map keys of custom types should serialize/deserialize using MarshalText/UnmarshalText when available - this brings marshaling/unmarshaling behavior in line with encoding/json - in general, any types that implement the interfaces from the encoding package (TextUnmarshaler, TextMarshaler, etc.) should use the provided method when available --- reflect_map.go | 76 ++++++++++++++++++++++------------------- value_tests/map_test.go | 15 ++++++++ 2 files changed, 55 insertions(+), 36 deletions(-) diff --git a/reflect_map.go b/reflect_map.go index 9e2b623..13fb67e 100644 --- a/reflect_map.go +++ b/reflect_map.go @@ -49,6 +49,33 @@ func decoderOfMapKey(ctx *ctx, typ reflect2.Type) ValDecoder { return decoder } } + + ptrType := reflect2.PtrTo(typ) + if ptrType.Implements(unmarshalerType) { + return &referenceDecoder{ + &unmarshalerDecoder{ + valType: ptrType, + }, + } + } + if typ.Implements(unmarshalerType) { + return &unmarshalerDecoder{ + valType: typ, + } + } + if ptrType.Implements(textUnmarshalerType) { + return &referenceDecoder{ + &textUnmarshalerDecoder{ + valType: ptrType, + }, + } + } + if typ.Implements(textUnmarshalerType) { + return &textUnmarshalerDecoder{ + valType: typ, + } + } + switch typ.Kind() { case reflect.String: return decoderOfType(ctx, reflect2.DefaultTypeOfKind(reflect.String)) @@ -63,31 +90,6 @@ func decoderOfMapKey(ctx *ctx, typ reflect2.Type) ValDecoder { typ = reflect2.DefaultTypeOfKind(typ.Kind()) return &numericMapKeyDecoder{decoderOfType(ctx, typ)} default: - ptrType := reflect2.PtrTo(typ) - if ptrType.Implements(unmarshalerType) { - return &referenceDecoder{ - &unmarshalerDecoder{ - valType: ptrType, - }, - } - } - if typ.Implements(unmarshalerType) { - return &unmarshalerDecoder{ - valType: typ, - } - } - if ptrType.Implements(textUnmarshalerType) { - return &referenceDecoder{ - &textUnmarshalerDecoder{ - valType: ptrType, - }, - } - } - if typ.Implements(textUnmarshalerType) { - return &textUnmarshalerDecoder{ - valType: typ, - } - } return &lazyErrorDecoder{err: fmt.Errorf("unsupported map key type: %v", typ)} } } @@ -103,6 +105,19 @@ func encoderOfMapKey(ctx *ctx, typ reflect2.Type) ValEncoder { return encoder } } + + if typ == textMarshalerType { + return &directTextMarshalerEncoder{ + stringEncoder: ctx.EncoderOf(reflect2.TypeOf("")), + } + } + if typ.Implements(textMarshalerType) { + return &textMarshalerEncoder{ + valType: typ, + stringEncoder: ctx.EncoderOf(reflect2.TypeOf("")), + } + } + switch typ.Kind() { case reflect.String: return encoderOfType(ctx, reflect2.DefaultTypeOfKind(reflect.String)) @@ -117,17 +132,6 @@ func encoderOfMapKey(ctx *ctx, typ reflect2.Type) ValEncoder { typ = reflect2.DefaultTypeOfKind(typ.Kind()) return &numericMapKeyEncoder{encoderOfType(ctx, typ)} default: - if typ == textMarshalerType { - return &directTextMarshalerEncoder{ - stringEncoder: ctx.EncoderOf(reflect2.TypeOf("")), - } - } - if typ.Implements(textMarshalerType) { - return &textMarshalerEncoder{ - valType: typ, - stringEncoder: ctx.EncoderOf(reflect2.TypeOf("")), - } - } if typ.Kind() == reflect.Interface { return &dynamicMapKeyEncoder{ctx, typ} } diff --git a/value_tests/map_test.go b/value_tests/map_test.go index f8ffa5a..02a1895 100644 --- a/value_tests/map_test.go +++ b/value_tests/map_test.go @@ -31,6 +31,7 @@ func init() { map[string]*json.RawMessage{"hello": pRawMessage(json.RawMessage("[]"))}, map[Date]bool{{}: true}, map[Date2]bool{{}: true}, + map[customKey]string{customKey(1): "bar"}, ) unmarshalCases = append(unmarshalCases, unmarshalCase{ ptr: (*map[string]string)(nil), @@ -55,6 +56,9 @@ func init() { "2018-12-13": true, "2018-12-14": true }`, + }, unmarshalCase{ + ptr: (*map[customKey]string)(nil), + input: `{"foo": "bar"}`, }) } @@ -115,3 +119,14 @@ func (d Date2) UnmarshalJSON(b []byte) error { func (d Date2) MarshalJSON() ([]byte, error) { return []byte(d.Time.Format("2006-01-02")), nil } + +type customKey int32 + +func (c customKey) MarshalText() ([]byte, error) { + return []byte("foo"), nil +} + +func (c *customKey) UnmarshalText(value []byte) error { + *c = 1 + return nil +}