From ae57d167e8bace7dceb3d03296e534a5d1db0959 Mon Sep 17 00:00:00 2001 From: Oleg Shaldybin Date: Thu, 14 Sep 2017 23:04:54 -0700 Subject: [PATCH] Fix custom marshaler for enum types When MarshalJSON was defined on a pointer receiver custom enum type marshaling/unmarshaling was panicing since the underlying primitive type was treated as a pointer. Since method set for pointer receivers includes value receiver methods we don't really need optionalEncoder and can just use marshalEncoder directly. --- feature_reflect.go | 1 - feature_reflect_native.go | 1 + jsoniter_enum_marshaler_test.go | 50 +++++++++++++++++++++++++++++++++ 3 files changed, 51 insertions(+), 1 deletion(-) create mode 100644 jsoniter_enum_marshaler_test.go diff --git a/feature_reflect.go b/feature_reflect.go index fc14a72..4483e34 100644 --- a/feature_reflect.go +++ b/feature_reflect.go @@ -476,7 +476,6 @@ func createEncoderOfType(cfg *frozenConfig, typ reflect.Type) (ValEncoder, error templateInterface: extractInterface(templateInterface), checkIsEmpty: checkIsEmpty, } - encoder = &optionalEncoder{encoder} return encoder, nil } if typ.Implements(textMarshalerType) { diff --git a/feature_reflect_native.go b/feature_reflect_native.go index 78912c0..cf939bf 100644 --- a/feature_reflect_native.go +++ b/feature_reflect_native.go @@ -661,6 +661,7 @@ func (encoder *marshalerEncoder) Encode(ptr unsafe.Pointer, stream *Stream) { templateInterface.word = ptr realInterface := (*interface{})(unsafe.Pointer(&templateInterface)) marshaler := (*realInterface).(json.Marshaler) + bytes, err := marshaler.MarshalJSON() if err != nil { stream.Error = err diff --git a/jsoniter_enum_marshaler_test.go b/jsoniter_enum_marshaler_test.go new file mode 100644 index 0000000..4f9be8b --- /dev/null +++ b/jsoniter_enum_marshaler_test.go @@ -0,0 +1,50 @@ +package jsoniter + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +type MyEnum int64 + +const ( + MyEnumA MyEnum = iota + MyEnumB +) + +func (m *MyEnum) MarshalJSON() ([]byte, error) { + return []byte(fmt.Sprintf(`"foo-%d"`, int(*m))), nil +} + +func (m *MyEnum) UnmarshalJSON(jb []byte) error { + switch string(jb) { + case `"foo-1"`: + *m = MyEnumB + default: + *m = MyEnumA + } + return nil +} + +func Test_custom_marshaler_on_enum(t *testing.T) { + type Wrapper struct { + Payload interface{} + } + type Wrapper2 struct { + Payload MyEnum + } + should := require.New(t) + + w := Wrapper{Payload: MyEnumB} + + jb, err := Marshal(w) + should.Equal(nil, err) + should.Equal(`{"Payload":"foo-1"}`, string(jb)) + + var w2 Wrapper2 + err = Unmarshal(jb, &w2) + should.Equal(nil, err) + should.Equal(MyEnumB, w2.Payload) +}