From 9f6e5962a9c771ad9b94195093d472dc3d932053 Mon Sep 17 00:00:00 2001 From: Oleg Shaldybin Date: Fri, 15 Sep 2017 14:05:38 -0700 Subject: [PATCH] Improve stdlib compatibility 1. Null values for primitive types no longer clear the original value in the destination object. 2. Dereference multiple levels of pointers in the destination interface{} type before unmarshaling into it. This is needed to match stdlib behavior when working with nested interface{} fields. If the destination object is a pointer to interface{} then the incoming nil value should nil out the destination object but keep the reference to that nil value on its parent object. However if the destination object is an interface{} value it should set the reference to nil but keep the original object intact. 3. Correctly handle typed nil decode destinations. --- feature_reflect_native.go | 113 +++++++++++++++++----------------- jsoniter_interface_test.go | 121 +++++++++++++++++++++++++++++++++++++ jsoniter_null_test.go | 33 +++++++++- 3 files changed, 210 insertions(+), 57 deletions(-) diff --git a/feature_reflect_native.go b/feature_reflect_native.go index 55e4d71..332b9e0 100644 --- a/feature_reflect_native.go +++ b/feature_reflect_native.go @@ -32,11 +32,9 @@ type intCodec struct { } func (codec *intCodec) Decode(ptr unsafe.Pointer, iter *Iterator) { - if iter.ReadNil() { - *((*int)(ptr)) = 0 - return + if !iter.ReadNil() { + *((*int)(ptr)) = iter.ReadInt() } - *((*int)(ptr)) = iter.ReadInt() } func (codec *intCodec) Encode(ptr unsafe.Pointer, stream *Stream) { @@ -55,11 +53,9 @@ type uintptrCodec struct { } func (codec *uintptrCodec) Decode(ptr unsafe.Pointer, iter *Iterator) { - if iter.ReadNil() { - *((*uintptr)(ptr)) = 0 - return + if !iter.ReadNil() { + *((*uintptr)(ptr)) = uintptr(iter.ReadUint64()) } - *((*uintptr)(ptr)) = uintptr(iter.ReadUint64()) } func (codec *uintptrCodec) Encode(ptr unsafe.Pointer, stream *Stream) { @@ -78,11 +74,9 @@ type int8Codec struct { } func (codec *int8Codec) Decode(ptr unsafe.Pointer, iter *Iterator) { - if iter.ReadNil() { - *((*uint8)(ptr)) = 0 - return + if !iter.ReadNil() { + *((*int8)(ptr)) = iter.ReadInt8() } - *((*int8)(ptr)) = iter.ReadInt8() } func (codec *int8Codec) Encode(ptr unsafe.Pointer, stream *Stream) { @@ -101,11 +95,9 @@ type int16Codec struct { } func (codec *int16Codec) Decode(ptr unsafe.Pointer, iter *Iterator) { - if iter.ReadNil() { - *((*int16)(ptr)) = 0 - return + if !iter.ReadNil() { + *((*int16)(ptr)) = iter.ReadInt16() } - *((*int16)(ptr)) = iter.ReadInt16() } func (codec *int16Codec) Encode(ptr unsafe.Pointer, stream *Stream) { @@ -124,11 +116,9 @@ type int32Codec struct { } func (codec *int32Codec) Decode(ptr unsafe.Pointer, iter *Iterator) { - if iter.ReadNil() { - *((*int32)(ptr)) = 0 - return + if !iter.ReadNil() { + *((*int32)(ptr)) = iter.ReadInt32() } - *((*int32)(ptr)) = iter.ReadInt32() } func (codec *int32Codec) Encode(ptr unsafe.Pointer, stream *Stream) { @@ -147,11 +137,9 @@ type int64Codec struct { } func (codec *int64Codec) Decode(ptr unsafe.Pointer, iter *Iterator) { - if iter.ReadNil() { - *((*int64)(ptr)) = 0 - return + if !iter.ReadNil() { + *((*int64)(ptr)) = iter.ReadInt64() } - *((*int64)(ptr)) = iter.ReadInt64() } func (codec *int64Codec) Encode(ptr unsafe.Pointer, stream *Stream) { @@ -170,11 +158,10 @@ type uintCodec struct { } func (codec *uintCodec) Decode(ptr unsafe.Pointer, iter *Iterator) { - if iter.ReadNil() { - *((*uint)(ptr)) = 0 + if !iter.ReadNil() { + *((*uint)(ptr)) = iter.ReadUint() return } - *((*uint)(ptr)) = iter.ReadUint() } func (codec *uintCodec) Encode(ptr unsafe.Pointer, stream *Stream) { @@ -193,11 +180,9 @@ type uint8Codec struct { } func (codec *uint8Codec) Decode(ptr unsafe.Pointer, iter *Iterator) { - if iter.ReadNil() { - *((*uint8)(ptr)) = 0 - return + if !iter.ReadNil() { + *((*uint8)(ptr)) = iter.ReadUint8() } - *((*uint8)(ptr)) = iter.ReadUint8() } func (codec *uint8Codec) Encode(ptr unsafe.Pointer, stream *Stream) { @@ -216,11 +201,9 @@ type uint16Codec struct { } func (codec *uint16Codec) Decode(ptr unsafe.Pointer, iter *Iterator) { - if iter.ReadNil() { - *((*uint16)(ptr)) = 0 - return + if !iter.ReadNil() { + *((*uint16)(ptr)) = iter.ReadUint16() } - *((*uint16)(ptr)) = iter.ReadUint16() } func (codec *uint16Codec) Encode(ptr unsafe.Pointer, stream *Stream) { @@ -239,11 +222,9 @@ type uint32Codec struct { } func (codec *uint32Codec) Decode(ptr unsafe.Pointer, iter *Iterator) { - if iter.ReadNil() { - *((*uint32)(ptr)) = 0 - return + if !iter.ReadNil() { + *((*uint32)(ptr)) = iter.ReadUint32() } - *((*uint32)(ptr)) = iter.ReadUint32() } func (codec *uint32Codec) Encode(ptr unsafe.Pointer, stream *Stream) { @@ -262,11 +243,9 @@ type uint64Codec struct { } func (codec *uint64Codec) Decode(ptr unsafe.Pointer, iter *Iterator) { - if iter.ReadNil() { - *((*uint64)(ptr)) = 0 - return + if !iter.ReadNil() { + *((*uint64)(ptr)) = iter.ReadUint64() } - *((*uint64)(ptr)) = iter.ReadUint64() } func (codec *uint64Codec) Encode(ptr unsafe.Pointer, stream *Stream) { @@ -285,11 +264,9 @@ type float32Codec struct { } func (codec *float32Codec) Decode(ptr unsafe.Pointer, iter *Iterator) { - if iter.ReadNil() { - *((*float32)(ptr)) = 0 - return + if !iter.ReadNil() { + *((*float32)(ptr)) = iter.ReadFloat32() } - *((*float32)(ptr)) = iter.ReadFloat32() } func (codec *float32Codec) Encode(ptr unsafe.Pointer, stream *Stream) { @@ -308,11 +285,9 @@ type float64Codec struct { } func (codec *float64Codec) Decode(ptr unsafe.Pointer, iter *Iterator) { - if iter.ReadNil() { - *((*float64)(ptr)) = 0 - return + if !iter.ReadNil() { + *((*float64)(ptr)) = iter.ReadFloat64() } - *((*float64)(ptr)) = iter.ReadFloat64() } func (codec *float64Codec) Encode(ptr unsafe.Pointer, stream *Stream) { @@ -352,13 +327,39 @@ type emptyInterfaceCodec struct { } func (codec *emptyInterfaceCodec) Decode(ptr unsafe.Pointer, iter *Iterator) { - if iter.ReadNil() { - *((*interface{})(ptr)) = nil + existing := *((*interface{})(ptr)) + + // Checking for both typed and untyped nil pointers. + if existing != nil && + reflect.TypeOf(existing).Kind() == reflect.Ptr && + !reflect.ValueOf(existing).IsNil() { + + var ptrToExisting interface{} + for { + elem := reflect.ValueOf(existing).Elem() + if elem.Kind() != reflect.Ptr || elem.IsNil() { + break + } + ptrToExisting = existing + existing = elem.Interface() + } + + if iter.ReadNil() { + if ptrToExisting != nil { + nilPtr := reflect.Zero(reflect.TypeOf(ptrToExisting).Elem()) + reflect.ValueOf(ptrToExisting).Elem().Set(nilPtr) + } else { + *((*interface{})(ptr)) = nil + } + } else { + iter.ReadVal(existing) + } + return } - existing := *((*interface{})(ptr)) - if existing != nil && reflect.TypeOf(existing).Kind() == reflect.Ptr { - iter.ReadVal(existing) + + if iter.ReadNil() { + *((*interface{})(ptr)) = nil } else { *((*interface{})(ptr)) = iter.Read() } diff --git a/jsoniter_interface_test.go b/jsoniter_interface_test.go index a464f46..5a05288 100644 --- a/jsoniter_interface_test.go +++ b/jsoniter_interface_test.go @@ -437,3 +437,124 @@ func Test_marshal_nil_nonempty_interface(t *testing.T) { should.NoError(err) should.Equal(nil, obj.Field) } + +func Test_overwrite_interface_ptr_value_with_nil(t *testing.T) { + type Wrapper struct { + Payload interface{} `json:"payload,omitempty"` + } + type Payload struct { + Value int `json:"val,omitempty"` + } + + should := require.New(t) + + payload := &Payload{} + wrapper := &Wrapper{ + Payload: &payload, + } + + err := json.Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper) + should.Equal(nil, err) + should.Equal(&payload, wrapper.Payload) + should.Equal(42, (*(wrapper.Payload.(**Payload))).Value) + + err = json.Unmarshal([]byte(`{"payload": null}`), &wrapper) + should.Equal(nil, err) + should.Equal(&payload, wrapper.Payload) + should.Equal((*Payload)(nil), payload) + + payload = &Payload{} + wrapper = &Wrapper{ + Payload: &payload, + } + + err = Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper) + should.Equal(nil, err) + should.Equal(&payload, wrapper.Payload) + should.Equal(42, (*(wrapper.Payload.(**Payload))).Value) + + err = Unmarshal([]byte(`{"payload": null}`), &wrapper) + should.Equal(nil, err) + should.Equal(&payload, wrapper.Payload) + should.Equal((*Payload)(nil), payload) +} + +func Test_overwrite_interface_value_with_nil(t *testing.T) { + type Wrapper struct { + Payload interface{} `json:"payload,omitempty"` + } + type Payload struct { + Value int `json:"val,omitempty"` + } + + should := require.New(t) + + payload := &Payload{} + wrapper := &Wrapper{ + Payload: payload, + } + + err := json.Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper) + should.Equal(nil, err) + should.Equal(42, (*(wrapper.Payload.(*Payload))).Value) + + err = json.Unmarshal([]byte(`{"payload": null}`), &wrapper) + should.Equal(nil, err) + should.Equal(nil, wrapper.Payload) + should.Equal(42, payload.Value) + + payload = &Payload{} + wrapper = &Wrapper{ + Payload: payload, + } + + err = Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper) + should.Equal(nil, err) + should.Equal(42, (*(wrapper.Payload.(*Payload))).Value) + + err = Unmarshal([]byte(`{"payload": null}`), &wrapper) + should.Equal(nil, err) + should.Equal(nil, wrapper.Payload) + should.Equal(42, payload.Value) +} + +func Test_unmarshal_into_nil(t *testing.T) { + type Payload struct { + Value int `json:"val,omitempty"` + } + type Wrapper struct { + Payload interface{} `json:"payload,omitempty"` + } + + should := require.New(t) + + var payload *Payload + wrapper := &Wrapper{ + Payload: payload, + } + + err := json.Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper) + should.Nil(err) + should.NotNil(wrapper.Payload) + should.Nil(payload) + + err = json.Unmarshal([]byte(`{"payload": null}`), &wrapper) + should.Nil(err) + should.Nil(wrapper.Payload) + should.Nil(payload) + + payload = nil + wrapper = &Wrapper{ + Payload: payload, + } + + err = Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper) + should.Nil(err) + should.NotNil(wrapper.Payload) + should.Nil(payload) + + err = Unmarshal([]byte(`{"payload": null}`), &wrapper) + should.Nil(err) + should.Nil(wrapper.Payload) + should.Nil(payload) +} diff --git a/jsoniter_null_test.go b/jsoniter_null_test.go index df71bf2..8c89147 100644 --- a/jsoniter_null_test.go +++ b/jsoniter_null_test.go @@ -3,9 +3,10 @@ package jsoniter import ( "bytes" "encoding/json" - "github.com/stretchr/testify/require" "io" "testing" + + "github.com/stretchr/testify/require" ) func Test_read_null(t *testing.T) { @@ -135,3 +136,33 @@ func Test_encode_nil_array(t *testing.T) { should.Nil(err) should.Equal("null", string(output)) } + +func Test_decode_nil_num(t *testing.T) { + type TestData struct { + Field int `json:"field"` + } + should := require.New(t) + + data1 := []byte(`{"field": 42}`) + data2 := []byte(`{"field": null}`) + + // Checking stdlib behavior as well + obj2 := TestData{} + err := json.Unmarshal(data1, &obj2) + should.Equal(nil, err) + should.Equal(42, obj2.Field) + + err = json.Unmarshal(data2, &obj2) + should.Equal(nil, err) + should.Equal(42, obj2.Field) + + obj := TestData{} + + err = Unmarshal(data1, &obj) + should.Equal(nil, err) + should.Equal(42, obj.Field) + + err = Unmarshal(data2, &obj) + should.Equal(nil, err) + should.Equal(42, obj.Field) +}