diff --git a/feature_iter_skip.go b/feature_iter_skip.go index ad00fdd..79f1672 100644 --- a/feature_iter_skip.go +++ b/feature_iter_skip.go @@ -30,6 +30,17 @@ func (iter *Iterator) ReadBool() (ret bool) { } +func (iter *Iterator) SkipAndReturnBytes() []byte { + if iter.reader != nil { + panic("reader input does not support this api") + } + before := iter.head + iter.Skip() + after := iter.head + return iter.buf[before:after] +} + + // Skip skips a json object and positions to relatively the next json object func (iter *Iterator) Skip() { c := iter.nextToken() diff --git a/feature_reflect.go b/feature_reflect.go index 4329163..0ca5405 100644 --- a/feature_reflect.go +++ b/feature_reflect.go @@ -75,6 +75,7 @@ var fieldEncoders map[string]Encoder var extensions []ExtensionFunc var anyType reflect.Type var marshalerType reflect.Type +var unmarshalerType reflect.Type func init() { typeDecoders = map[string]Decoder{} @@ -86,6 +87,7 @@ func init() { atomic.StorePointer(&ENCODERS, unsafe.Pointer(&map[string]Encoder{})) anyType = reflect.TypeOf((*Any)(nil)).Elem() marshalerType = reflect.TypeOf((*json.Marshaler)(nil)).Elem() + unmarshalerType = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() } func addDecoderToCache(cacheKey reflect.Type, decoder Decoder) { @@ -293,9 +295,6 @@ func (p prefix) addToEncoder(encoder Encoder, err error) (Encoder, error) { } func decoderOfType(typ reflect.Type) (Decoder, error) { - if typ.ConvertibleTo(anyType) { - return &anyCodec{}, nil - } typeName := typ.String() typeDecoder := typeDecoders[typeName] if typeDecoder != nil { @@ -315,6 +314,13 @@ func decoderOfType(typ reflect.Type) (Decoder, error) { } func createDecoderOfType(typ reflect.Type) (Decoder, error) { + if typ.ConvertibleTo(anyType) { + return &anyCodec{}, nil + } + if typ.ConvertibleTo(unmarshalerType) { + templateInterface := reflect.New(typ).Elem().Interface() + return &optionalDecoder{typ, &unmarshalerDecoder{extractInterface(templateInterface)}}, nil + } switch typ.Kind() { case reflect.String: return &stringCodec{}, nil diff --git a/feature_reflect_native.go b/feature_reflect_native.go index ecca90a..8b4dcd3 100644 --- a/feature_reflect_native.go +++ b/feature_reflect_native.go @@ -362,4 +362,20 @@ func (encoder *marshalerEncoder) isEmpty(ptr unsafe.Pointer) bool { } else { return len(bytes) > 0 } +} + +type unmarshalerDecoder struct { + templateInterface emptyInterface +} + +func (decoder *unmarshalerDecoder) decode(ptr unsafe.Pointer, iter *Iterator) { + templateInterface := decoder.templateInterface + templateInterface.word = ptr + realInterface := (*interface{})(unsafe.Pointer(&templateInterface)) + unmarshaler := (*realInterface).(json.Unmarshaler) + bytes := iter.SkipAndReturnBytes() + err := unmarshaler.UnmarshalJSON(bytes) + if err != nil { + iter.reportError("unmarshaler", err.Error()) + } } \ No newline at end of file diff --git a/jsoniter_customize_test.go b/jsoniter_customize_test.go index 054df9c..6bd7246 100644 --- a/jsoniter_customize_test.go +++ b/jsoniter_customize_test.go @@ -149,4 +149,28 @@ func Test_marshaler(t *testing.T) { str, err := MarshalToString(obj) should.Nil(err) should.Equal(`{"Field":"hello"}`, str) -} \ No newline at end of file +} + +type ObjectImplementedUnmarshaler int + +func (obj *ObjectImplementedUnmarshaler) UnmarshalJSON([]byte) error { + *obj = 100 + return nil +} + +func Test_unmarshaler(t *testing.T) { + type TestObject struct { + Field *ObjectImplementedUnmarshaler + Field2 string + } + should := require.New(t) + obj := TestObject{} + val := ObjectImplementedUnmarshaler(0) + obj.Field = &val + err := json.Unmarshal([]byte(`{"Field":"hello"}`), &obj) + should.Nil(err) + should.Equal(100, int(*obj.Field)) + err = Unmarshal([]byte(`{"Field":"hello"}`), &obj) + should.Nil(err) + should.Equal(100, int(*obj.Field)) +}