diff --git a/feature_reflect_native.go b/feature_reflect_native.go index 55e4d71..332b9e0 100644 --- a/feature_reflect_native.go +++ b/feature_reflect_native.go @@ -32,11 +32,9 @@ type intCodec struct { } func (codec *intCodec) Decode(ptr unsafe.Pointer, iter *Iterator) { - if iter.ReadNil() { - *((*int)(ptr)) = 0 - return + if !iter.ReadNil() { + *((*int)(ptr)) = iter.ReadInt() } - *((*int)(ptr)) = iter.ReadInt() } func (codec *intCodec) Encode(ptr unsafe.Pointer, stream *Stream) { @@ -55,11 +53,9 @@ type uintptrCodec struct { } func (codec *uintptrCodec) Decode(ptr unsafe.Pointer, iter *Iterator) { - if iter.ReadNil() { - *((*uintptr)(ptr)) = 0 - return + if !iter.ReadNil() { + *((*uintptr)(ptr)) = uintptr(iter.ReadUint64()) } - *((*uintptr)(ptr)) = uintptr(iter.ReadUint64()) } func (codec *uintptrCodec) Encode(ptr unsafe.Pointer, stream *Stream) { @@ -78,11 +74,9 @@ type int8Codec struct { } func (codec *int8Codec) Decode(ptr unsafe.Pointer, iter *Iterator) { - if iter.ReadNil() { - *((*uint8)(ptr)) = 0 - return + if !iter.ReadNil() { + *((*int8)(ptr)) = iter.ReadInt8() } - *((*int8)(ptr)) = iter.ReadInt8() } func (codec *int8Codec) Encode(ptr unsafe.Pointer, stream *Stream) { @@ -101,11 +95,9 @@ type int16Codec struct { } func (codec *int16Codec) Decode(ptr unsafe.Pointer, iter *Iterator) { - if iter.ReadNil() { - *((*int16)(ptr)) = 0 - return + if !iter.ReadNil() { + *((*int16)(ptr)) = iter.ReadInt16() } - *((*int16)(ptr)) = iter.ReadInt16() } func (codec *int16Codec) Encode(ptr unsafe.Pointer, stream *Stream) { @@ -124,11 +116,9 @@ type int32Codec struct { } func (codec *int32Codec) Decode(ptr unsafe.Pointer, iter *Iterator) { - if iter.ReadNil() { - *((*int32)(ptr)) = 0 - return + if !iter.ReadNil() { + *((*int32)(ptr)) = iter.ReadInt32() } - *((*int32)(ptr)) = iter.ReadInt32() } func (codec *int32Codec) Encode(ptr unsafe.Pointer, stream *Stream) { @@ -147,11 +137,9 @@ type int64Codec struct { } func (codec *int64Codec) Decode(ptr unsafe.Pointer, iter *Iterator) { - if iter.ReadNil() { - *((*int64)(ptr)) = 0 - return + if !iter.ReadNil() { + *((*int64)(ptr)) = iter.ReadInt64() } - *((*int64)(ptr)) = iter.ReadInt64() } func (codec *int64Codec) Encode(ptr unsafe.Pointer, stream *Stream) { @@ -170,11 +158,10 @@ type uintCodec struct { } func (codec *uintCodec) Decode(ptr unsafe.Pointer, iter *Iterator) { - if iter.ReadNil() { - *((*uint)(ptr)) = 0 + if !iter.ReadNil() { + *((*uint)(ptr)) = iter.ReadUint() return } - *((*uint)(ptr)) = iter.ReadUint() } func (codec *uintCodec) Encode(ptr unsafe.Pointer, stream *Stream) { @@ -193,11 +180,9 @@ type uint8Codec struct { } func (codec *uint8Codec) Decode(ptr unsafe.Pointer, iter *Iterator) { - if iter.ReadNil() { - *((*uint8)(ptr)) = 0 - return + if !iter.ReadNil() { + *((*uint8)(ptr)) = iter.ReadUint8() } - *((*uint8)(ptr)) = iter.ReadUint8() } func (codec *uint8Codec) Encode(ptr unsafe.Pointer, stream *Stream) { @@ -216,11 +201,9 @@ type uint16Codec struct { } func (codec *uint16Codec) Decode(ptr unsafe.Pointer, iter *Iterator) { - if iter.ReadNil() { - *((*uint16)(ptr)) = 0 - return + if !iter.ReadNil() { + *((*uint16)(ptr)) = iter.ReadUint16() } - *((*uint16)(ptr)) = iter.ReadUint16() } func (codec *uint16Codec) Encode(ptr unsafe.Pointer, stream *Stream) { @@ -239,11 +222,9 @@ type uint32Codec struct { } func (codec *uint32Codec) Decode(ptr unsafe.Pointer, iter *Iterator) { - if iter.ReadNil() { - *((*uint32)(ptr)) = 0 - return + if !iter.ReadNil() { + *((*uint32)(ptr)) = iter.ReadUint32() } - *((*uint32)(ptr)) = iter.ReadUint32() } func (codec *uint32Codec) Encode(ptr unsafe.Pointer, stream *Stream) { @@ -262,11 +243,9 @@ type uint64Codec struct { } func (codec *uint64Codec) Decode(ptr unsafe.Pointer, iter *Iterator) { - if iter.ReadNil() { - *((*uint64)(ptr)) = 0 - return + if !iter.ReadNil() { + *((*uint64)(ptr)) = iter.ReadUint64() } - *((*uint64)(ptr)) = iter.ReadUint64() } func (codec *uint64Codec) Encode(ptr unsafe.Pointer, stream *Stream) { @@ -285,11 +264,9 @@ type float32Codec struct { } func (codec *float32Codec) Decode(ptr unsafe.Pointer, iter *Iterator) { - if iter.ReadNil() { - *((*float32)(ptr)) = 0 - return + if !iter.ReadNil() { + *((*float32)(ptr)) = iter.ReadFloat32() } - *((*float32)(ptr)) = iter.ReadFloat32() } func (codec *float32Codec) Encode(ptr unsafe.Pointer, stream *Stream) { @@ -308,11 +285,9 @@ type float64Codec struct { } func (codec *float64Codec) Decode(ptr unsafe.Pointer, iter *Iterator) { - if iter.ReadNil() { - *((*float64)(ptr)) = 0 - return + if !iter.ReadNil() { + *((*float64)(ptr)) = iter.ReadFloat64() } - *((*float64)(ptr)) = iter.ReadFloat64() } func (codec *float64Codec) Encode(ptr unsafe.Pointer, stream *Stream) { @@ -352,13 +327,39 @@ type emptyInterfaceCodec struct { } func (codec *emptyInterfaceCodec) Decode(ptr unsafe.Pointer, iter *Iterator) { - if iter.ReadNil() { - *((*interface{})(ptr)) = nil + existing := *((*interface{})(ptr)) + + // Checking for both typed and untyped nil pointers. + if existing != nil && + reflect.TypeOf(existing).Kind() == reflect.Ptr && + !reflect.ValueOf(existing).IsNil() { + + var ptrToExisting interface{} + for { + elem := reflect.ValueOf(existing).Elem() + if elem.Kind() != reflect.Ptr || elem.IsNil() { + break + } + ptrToExisting = existing + existing = elem.Interface() + } + + if iter.ReadNil() { + if ptrToExisting != nil { + nilPtr := reflect.Zero(reflect.TypeOf(ptrToExisting).Elem()) + reflect.ValueOf(ptrToExisting).Elem().Set(nilPtr) + } else { + *((*interface{})(ptr)) = nil + } + } else { + iter.ReadVal(existing) + } + return } - existing := *((*interface{})(ptr)) - if existing != nil && reflect.TypeOf(existing).Kind() == reflect.Ptr { - iter.ReadVal(existing) + + if iter.ReadNil() { + *((*interface{})(ptr)) = nil } else { *((*interface{})(ptr)) = iter.Read() } diff --git a/jsoniter_interface_test.go b/jsoniter_interface_test.go index a464f46..5a05288 100644 --- a/jsoniter_interface_test.go +++ b/jsoniter_interface_test.go @@ -437,3 +437,124 @@ func Test_marshal_nil_nonempty_interface(t *testing.T) { should.NoError(err) should.Equal(nil, obj.Field) } + +func Test_overwrite_interface_ptr_value_with_nil(t *testing.T) { + type Wrapper struct { + Payload interface{} `json:"payload,omitempty"` + } + type Payload struct { + Value int `json:"val,omitempty"` + } + + should := require.New(t) + + payload := &Payload{} + wrapper := &Wrapper{ + Payload: &payload, + } + + err := json.Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper) + should.Equal(nil, err) + should.Equal(&payload, wrapper.Payload) + should.Equal(42, (*(wrapper.Payload.(**Payload))).Value) + + err = json.Unmarshal([]byte(`{"payload": null}`), &wrapper) + should.Equal(nil, err) + should.Equal(&payload, wrapper.Payload) + should.Equal((*Payload)(nil), payload) + + payload = &Payload{} + wrapper = &Wrapper{ + Payload: &payload, + } + + err = Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper) + should.Equal(nil, err) + should.Equal(&payload, wrapper.Payload) + should.Equal(42, (*(wrapper.Payload.(**Payload))).Value) + + err = Unmarshal([]byte(`{"payload": null}`), &wrapper) + should.Equal(nil, err) + should.Equal(&payload, wrapper.Payload) + should.Equal((*Payload)(nil), payload) +} + +func Test_overwrite_interface_value_with_nil(t *testing.T) { + type Wrapper struct { + Payload interface{} `json:"payload,omitempty"` + } + type Payload struct { + Value int `json:"val,omitempty"` + } + + should := require.New(t) + + payload := &Payload{} + wrapper := &Wrapper{ + Payload: payload, + } + + err := json.Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper) + should.Equal(nil, err) + should.Equal(42, (*(wrapper.Payload.(*Payload))).Value) + + err = json.Unmarshal([]byte(`{"payload": null}`), &wrapper) + should.Equal(nil, err) + should.Equal(nil, wrapper.Payload) + should.Equal(42, payload.Value) + + payload = &Payload{} + wrapper = &Wrapper{ + Payload: payload, + } + + err = Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper) + should.Equal(nil, err) + should.Equal(42, (*(wrapper.Payload.(*Payload))).Value) + + err = Unmarshal([]byte(`{"payload": null}`), &wrapper) + should.Equal(nil, err) + should.Equal(nil, wrapper.Payload) + should.Equal(42, payload.Value) +} + +func Test_unmarshal_into_nil(t *testing.T) { + type Payload struct { + Value int `json:"val,omitempty"` + } + type Wrapper struct { + Payload interface{} `json:"payload,omitempty"` + } + + should := require.New(t) + + var payload *Payload + wrapper := &Wrapper{ + Payload: payload, + } + + err := json.Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper) + should.Nil(err) + should.NotNil(wrapper.Payload) + should.Nil(payload) + + err = json.Unmarshal([]byte(`{"payload": null}`), &wrapper) + should.Nil(err) + should.Nil(wrapper.Payload) + should.Nil(payload) + + payload = nil + wrapper = &Wrapper{ + Payload: payload, + } + + err = Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper) + should.Nil(err) + should.NotNil(wrapper.Payload) + should.Nil(payload) + + err = Unmarshal([]byte(`{"payload": null}`), &wrapper) + should.Nil(err) + should.Nil(wrapper.Payload) + should.Nil(payload) +} diff --git a/jsoniter_null_test.go b/jsoniter_null_test.go index df71bf2..8c89147 100644 --- a/jsoniter_null_test.go +++ b/jsoniter_null_test.go @@ -3,9 +3,10 @@ package jsoniter import ( "bytes" "encoding/json" - "github.com/stretchr/testify/require" "io" "testing" + + "github.com/stretchr/testify/require" ) func Test_read_null(t *testing.T) { @@ -135,3 +136,33 @@ func Test_encode_nil_array(t *testing.T) { should.Nil(err) should.Equal("null", string(output)) } + +func Test_decode_nil_num(t *testing.T) { + type TestData struct { + Field int `json:"field"` + } + should := require.New(t) + + data1 := []byte(`{"field": 42}`) + data2 := []byte(`{"field": null}`) + + // Checking stdlib behavior as well + obj2 := TestData{} + err := json.Unmarshal(data1, &obj2) + should.Equal(nil, err) + should.Equal(42, obj2.Field) + + err = json.Unmarshal(data2, &obj2) + should.Equal(nil, err) + should.Equal(42, obj2.Field) + + obj := TestData{} + + err = Unmarshal(data1, &obj) + should.Equal(nil, err) + should.Equal(42, obj.Field) + + err = Unmarshal(data2, &obj) + should.Equal(nil, err) + should.Equal(42, obj.Field) +}