From 552afb3625c4254d2690897887ac3ba5c3a1d467 Mon Sep 17 00:00:00 2001 From: Tao Wen Date: Mon, 9 Jan 2017 19:19:48 +0800 Subject: [PATCH] struct encoder --- feature_adapter.go | 1 + feature_reflect.go | 65 ++++++++++++++++------ feature_reflect_native.go | 96 ++++++++++++++++++++++++--------- feature_reflect_object.go | 84 ++++++++++++++++++++++++++++- jsoniter_bool_test.go | 11 ++++ jsoniter_float_test.go | 18 +++++++ jsoniter_int_test.go | 88 +++++++++++++++++++++++++++--- jsoniter_reflect_struct_test.go | 23 +++++++- 8 files changed, 336 insertions(+), 50 deletions(-) diff --git a/feature_adapter.go b/feature_adapter.go index b04e74d..0fd10dc 100644 --- a/feature_adapter.go +++ b/feature_adapter.go @@ -31,6 +31,7 @@ func Marshal(v interface{}) ([]byte, error) { buf := &bytes.Buffer{} stream := NewStream(buf, 4096) stream.WriteVal(v) + stream.Flush() if stream.Error != nil { return nil, stream.Error } diff --git a/feature_reflect.go b/feature_reflect.go index a8fb923..56d7c90 100644 --- a/feature_reflect.go +++ b/feature_reflect.go @@ -304,13 +304,20 @@ func (stream *Stream) WriteVal(val interface{}) { type prefix string -func (p prefix) addTo(decoder Decoder, err error) (Decoder, error) { +func (p prefix) addToDecoder(decoder Decoder, err error) (Decoder, error) { if err != nil { return nil, fmt.Errorf("%s: %s", p, err.Error()) } return decoder, err } +func (p prefix) addToEncoder(encoder Encoder, err error) (Encoder, error) { + if err != nil { + return nil, fmt.Errorf("%s: %s", p, err.Error()) + } + return encoder, err +} + func decoderOfType(typ reflect.Type) (Decoder, error) { typeName := typ.String() if typeName == "jsoniter.Any" { @@ -326,39 +333,39 @@ func decoderOfType(typ reflect.Type) (Decoder, error) { case reflect.Int: return &intCodec{}, nil case reflect.Int8: - return &int8Decoder{}, nil + return &int8Codec{}, nil case reflect.Int16: - return &int16Decoder{}, nil + return &int16Codec{}, nil case reflect.Int32: - return &int32Decoder{}, nil + return &int32Codec{}, nil case reflect.Int64: - return &int64Decoder{}, nil + return &int64Codec{}, nil case reflect.Uint: - return &uintDecoder{}, nil + return &uintCodec{}, nil case reflect.Uint8: - return &uint8Decoder{}, nil + return &uint8Codec{}, nil case reflect.Uint16: - return &uint16Decoder{}, nil + return &uint16Codec{}, nil case reflect.Uint32: - return &uint32Decoder{}, nil + return &uint32Codec{}, nil case reflect.Uint64: - return &uint64Decoder{}, nil + return &uint64Codec{}, nil case reflect.Float32: - return &float32Decoder{}, nil + return &float32Codec{}, nil case reflect.Float64: - return &float64Decoder{}, nil + return &float64Codec{}, nil case reflect.Bool: - return &boolDecoder{}, nil + return &boolCodec{}, nil case reflect.Interface: return &interfaceDecoder{}, nil case reflect.Struct: return decoderOfStruct(typ) case reflect.Slice: - return prefix("[slice]").addTo(decoderOfSlice(typ)) + return prefix("[slice]").addToDecoder(decoderOfSlice(typ)) case reflect.Map: - return prefix("[map]").addTo(decoderOfMap(typ)) + return prefix("[map]").addToDecoder(decoderOfMap(typ)) case reflect.Ptr: - return prefix("[optional]").addTo(decoderOfOptional(typ.Elem())) + return prefix("[optional]").addToDecoder(decoderOfOptional(typ.Elem())) default: return nil, fmt.Errorf("unsupported type: %v", typ) } @@ -372,6 +379,32 @@ func encoderOfType(typ reflect.Type) (Encoder, error) { return &stringCodec{}, nil case reflect.Int: return &intCodec{}, nil + case reflect.Int8: + return &int8Codec{}, nil + case reflect.Int16: + return &int16Codec{}, nil + case reflect.Int32: + return &int32Codec{}, nil + case reflect.Int64: + return &int64Codec{}, nil + case reflect.Uint: + return &uintCodec{}, nil + case reflect.Uint8: + return &uint8Codec{}, nil + case reflect.Uint16: + return &uint16Codec{}, nil + case reflect.Uint32: + return &uint32Codec{}, nil + case reflect.Uint64: + return &uint64Codec{}, nil + case reflect.Float32: + return &float32Codec{}, nil + case reflect.Float64: + return &float64Codec{}, nil + case reflect.Bool: + return &boolCodec{}, nil + case reflect.Struct: + return encoderOfStruct(typ) default: return nil, fmt.Errorf("unsupported type: %v", typ) } diff --git a/feature_reflect_native.go b/feature_reflect_native.go index f5ef546..af0d241 100644 --- a/feature_reflect_native.go +++ b/feature_reflect_native.go @@ -24,90 +24,138 @@ func (codec *intCodec) encode(ptr unsafe.Pointer, stream *Stream) { stream.WriteInt(*((*int)(ptr))) } -type int8Decoder struct { +type int8Codec struct { } -func (decoder *int8Decoder) decode(ptr unsafe.Pointer, iter *Iterator) { +func (codec *int8Codec) decode(ptr unsafe.Pointer, iter *Iterator) { *((*int8)(ptr)) = iter.ReadInt8() } -type int16Decoder struct { +func (codec *int8Codec) encode(ptr unsafe.Pointer, stream *Stream) { + stream.WriteInt8(*((*int8)(ptr))) } -func (decoder *int16Decoder) decode(ptr unsafe.Pointer, iter *Iterator) { +type int16Codec struct { +} + +func (codec *int16Codec) decode(ptr unsafe.Pointer, iter *Iterator) { *((*int16)(ptr)) = iter.ReadInt16() } -type int32Decoder struct { +func (codec *int16Codec) encode(ptr unsafe.Pointer, stream *Stream) { + stream.WriteInt16(*((*int16)(ptr))) } -func (decoder *int32Decoder) decode(ptr unsafe.Pointer, iter *Iterator) { +type int32Codec struct { +} + +func (codec *int32Codec) decode(ptr unsafe.Pointer, iter *Iterator) { *((*int32)(ptr)) = iter.ReadInt32() } -type int64Decoder struct { +func (codec *int32Codec) encode(ptr unsafe.Pointer, stream *Stream) { + stream.WriteInt32(*((*int32)(ptr))) } -func (decoder *int64Decoder) decode(ptr unsafe.Pointer, iter *Iterator) { +type int64Codec struct { +} + +func (codec *int64Codec) decode(ptr unsafe.Pointer, iter *Iterator) { *((*int64)(ptr)) = iter.ReadInt64() } -type uintDecoder struct { +func (codec *int64Codec) encode(ptr unsafe.Pointer, stream *Stream) { + stream.WriteInt64(*((*int64)(ptr))) } -func (decoder *uintDecoder) decode(ptr unsafe.Pointer, iter *Iterator) { +type uintCodec struct { +} + +func (codec *uintCodec) decode(ptr unsafe.Pointer, iter *Iterator) { *((*uint)(ptr)) = iter.ReadUint() } -type uint8Decoder struct { +func (codec *uintCodec) encode(ptr unsafe.Pointer, stream *Stream) { + stream.WriteUint(*((*uint)(ptr))) } -func (decoder *uint8Decoder) decode(ptr unsafe.Pointer, iter *Iterator) { +type uint8Codec struct { +} + +func (codec *uint8Codec) decode(ptr unsafe.Pointer, iter *Iterator) { *((*uint8)(ptr)) = iter.ReadUint8() } -type uint16Decoder struct { +func (codec *uint8Codec) encode(ptr unsafe.Pointer, stream *Stream) { + stream.WriteUint8(*((*uint8)(ptr))) } -func (decoder *uint16Decoder) decode(ptr unsafe.Pointer, iter *Iterator) { +type uint16Codec struct { +} + +func (decoder *uint16Codec) decode(ptr unsafe.Pointer, iter *Iterator) { *((*uint16)(ptr)) = iter.ReadUint16() } -type uint32Decoder struct { +func (decoder *uint16Codec) encode(ptr unsafe.Pointer, stream *Stream) { + stream.WriteUint16(*((*uint16)(ptr))) } -func (decoder *uint32Decoder) decode(ptr unsafe.Pointer, iter *Iterator) { +type uint32Codec struct { +} + +func (codec *uint32Codec) decode(ptr unsafe.Pointer, iter *Iterator) { *((*uint32)(ptr)) = iter.ReadUint32() } -type uint64Decoder struct { +func (codec *uint32Codec) encode(ptr unsafe.Pointer, stream *Stream) { + stream.WriteUint32(*((*uint32)(ptr))) } -func (decoder *uint64Decoder) decode(ptr unsafe.Pointer, iter *Iterator) { +type uint64Codec struct { +} + +func (codec *uint64Codec) decode(ptr unsafe.Pointer, iter *Iterator) { *((*uint64)(ptr)) = iter.ReadUint64() } -type float32Decoder struct { +func (codec *uint64Codec) encode(ptr unsafe.Pointer, stream *Stream) { + stream.WriteUint64(*((*uint64)(ptr))) } -func (decoder *float32Decoder) decode(ptr unsafe.Pointer, iter *Iterator) { +type float32Codec struct { +} + +func (codec *float32Codec) decode(ptr unsafe.Pointer, iter *Iterator) { *((*float32)(ptr)) = iter.ReadFloat32() } -type float64Decoder struct { +func (codec *float32Codec) encode(ptr unsafe.Pointer, stream *Stream) { + stream.WriteFloat32(*((*float32)(ptr))) } -func (decoder *float64Decoder) decode(ptr unsafe.Pointer, iter *Iterator) { +type float64Codec struct { +} + +func (codec *float64Codec) decode(ptr unsafe.Pointer, iter *Iterator) { *((*float64)(ptr)) = iter.ReadFloat64() } -type boolDecoder struct { +func (codec *float64Codec) encode(ptr unsafe.Pointer, stream *Stream) { + stream.WriteFloat64(*((*float64)(ptr))) } -func (decoder *boolDecoder) decode(ptr unsafe.Pointer, iter *Iterator) { +type boolCodec struct { +} + +func (codec *boolCodec) decode(ptr unsafe.Pointer, iter *Iterator) { *((*bool)(ptr)) = iter.ReadBool() } +func (codec *boolCodec) encode(ptr unsafe.Pointer, stream *Stream) { + stream.WriteBool(*((*bool)(ptr))) +} + type interfaceDecoder struct { } diff --git a/feature_reflect_object.go b/feature_reflect_object.go index 8f9637f..c66d853 100644 --- a/feature_reflect_object.go +++ b/feature_reflect_object.go @@ -8,6 +8,49 @@ import ( "strings" ) + +func encoderOfStruct(typ reflect.Type) (Encoder, error) { + structEncoder_ := &structEncoder{} + for i := 0; i < typ.NumField(); i++ { + field := typ.Field(i) + var fieldNames []string + for _, extension := range extensions { + alternativeFieldNames, _ := extension(typ, &field) + if alternativeFieldNames != nil { + fieldNames = alternativeFieldNames + } + } + tagParts := strings.Split(field.Tag.Get("json"), ",") + // if fieldNames set by extension, use theirs, otherwise try tags + if fieldNames == nil { + /// tagParts[0] always present, even if no tags + switch tagParts[0] { + case "": + fieldNames = []string{field.Name} + case "-": + fieldNames = []string{} + default: + fieldNames = []string{tagParts[0]} + } + } + encoder, err := encoderOfType(field.Type) + if err != nil { + return prefix(fmt.Sprintf("{%s}", field.Name)).addToEncoder(encoder, err) + } + for _, fieldName := range fieldNames { + if structEncoder_.firstField == nil { + structEncoder_.firstField = &structFieldEncoder{&field, fieldName, encoder} + } else { + structEncoder_.fields = append(structEncoder_.fields, &structFieldEncoder{&field, fieldName, encoder}) + } + } + } + if structEncoder_.firstField == nil { + return &emptyStructEncoder{}, nil + } + return structEncoder_, nil +} + func decoderOfStruct(typ reflect.Type) (Decoder, error) { fields := map[string]*structFieldDecoder{} for i := 0; i < typ.NumField(); i++ { @@ -41,7 +84,7 @@ func decoderOfStruct(typ reflect.Type) (Decoder, error) { var err error decoder, err = decoderOfType(field.Type) if err != nil { - return prefix(fmt.Sprintf("{%s}", field.Name)).addTo(decoder, err) + return prefix(fmt.Sprintf("{%s}", field.Name)).addToDecoder(decoder, err) } } if len(tagParts) > 1 && tagParts[1] == "string" { @@ -336,3 +379,42 @@ func (decoder *structFieldDecoder) decode(ptr unsafe.Pointer, iter *Iterator) { iter.Error = fmt.Errorf("%s: %s", decoder.field.Name, iter.Error.Error()) } } + +type structFieldEncoder struct { + field *reflect.StructField + fieldName string + fieldEncoder Encoder +} + +func (encoder *structFieldEncoder) encode(ptr unsafe.Pointer, stream *Stream) { + fieldPtr := uintptr(ptr) + encoder.field.Offset + stream.WriteObjectField(encoder.fieldName) + encoder.fieldEncoder.encode(unsafe.Pointer(fieldPtr), stream) + if stream.Error != nil && stream.Error != io.EOF { + stream.Error = fmt.Errorf("%s: %s", encoder.field.Name, stream.Error.Error()) + } +} + + +type structEncoder struct { + firstField *structFieldEncoder + fields []*structFieldEncoder +} + +func (encoder *structEncoder) encode(ptr unsafe.Pointer, stream *Stream) { + stream.WriteObjectStart() + encoder.firstField.encode(ptr, stream) + for _, field := range encoder.fields { + stream.WriteMore() + field.encode(ptr, stream) + } + stream.WriteObjectEnd() +} + +type emptyStructEncoder struct { +} + +func (encoder *emptyStructEncoder) encode(ptr unsafe.Pointer, stream *Stream) { + stream.WriteObjectStart() + stream.WriteObjectEnd() +} \ No newline at end of file diff --git a/jsoniter_bool_test.go b/jsoniter_bool_test.go index 1beee3b..6093e1c 100644 --- a/jsoniter_bool_test.go +++ b/jsoniter_bool_test.go @@ -30,4 +30,15 @@ func Test_write_true_false(t *testing.T) { stream.Flush() should.Nil(stream.Error) should.Equal("truefalse", buf.String()) +} + + +func Test_write_val_bool(t *testing.T) { + should := require.New(t) + buf := &bytes.Buffer{} + stream := NewStream(buf, 4096) + stream.WriteVal(true) + stream.Flush() + should.Nil(stream.Error) + should.Equal("true", buf.String()) } \ No newline at end of file diff --git a/jsoniter_float_test.go b/jsoniter_float_test.go index defa267..5c8862d 100644 --- a/jsoniter_float_test.go +++ b/jsoniter_float_test.go @@ -47,6 +47,15 @@ func Test_write_float32(t *testing.T) { should.Nil(stream.Error) should.Equal(strconv.FormatFloat(float64(val), 'f', -1, 32), buf.String()) }) + t.Run(fmt.Sprintf("%v", val), func(t *testing.T) { + should := require.New(t) + buf := &bytes.Buffer{} + stream := NewStream(buf, 4096) + stream.WriteVal(val) + stream.Flush() + should.Nil(stream.Error) + should.Equal(strconv.FormatFloat(float64(val), 'f', -1, 32), buf.String()) + }) } should := require.New(t) buf := &bytes.Buffer{} @@ -71,6 +80,15 @@ func Test_write_float64(t *testing.T) { should.Nil(stream.Error) should.Equal(strconv.FormatFloat(val, 'f', -1, 64), buf.String()) }) + t.Run(fmt.Sprintf("%v", val), func(t *testing.T) { + should := require.New(t) + buf := &bytes.Buffer{} + stream := NewStream(buf, 4096) + stream.WriteVal(val) + stream.Flush() + should.Nil(stream.Error) + should.Equal(strconv.FormatFloat(val, 'f', -1, 64), buf.String()) + }) } should := require.New(t) buf := &bytes.Buffer{} diff --git a/jsoniter_int_test.go b/jsoniter_int_test.go index ed178dd..63e0a32 100644 --- a/jsoniter_int_test.go +++ b/jsoniter_int_test.go @@ -84,11 +84,20 @@ func Test_write_uint8(t *testing.T) { should.Nil(stream.Error) should.Equal(strconv.FormatUint(uint64(val), 10), buf.String()) }) + t.Run(fmt.Sprintf("%v", val), func(t *testing.T) { + should := require.New(t) + buf := &bytes.Buffer{} + stream := NewStream(buf, 4096) + stream.WriteVal(val) + stream.Flush() + should.Nil(stream.Error) + should.Equal(strconv.FormatUint(uint64(val), 10), buf.String()) + }) } should := require.New(t) buf := &bytes.Buffer{} stream := NewStream(buf, 3) - stream.WriteString("a") + stream.WriteRaw("a") stream.WriteUint8(100) // should clear buffer stream.Flush() should.Nil(stream.Error) @@ -107,11 +116,20 @@ func Test_write_int8(t *testing.T) { should.Nil(stream.Error) should.Equal(strconv.FormatInt(int64(val), 10), buf.String()) }) + t.Run(fmt.Sprintf("%v", val), func(t *testing.T) { + should := require.New(t) + buf := &bytes.Buffer{} + stream := NewStream(buf, 4096) + stream.WriteVal(val) + stream.Flush() + should.Nil(stream.Error) + should.Equal(strconv.FormatInt(int64(val), 10), buf.String()) + }) } should := require.New(t) buf := &bytes.Buffer{} stream := NewStream(buf, 4) - stream.WriteString("a") + stream.WriteRaw("a") stream.WriteInt8(-100) // should clear buffer stream.Flush() should.Nil(stream.Error) @@ -130,11 +148,20 @@ func Test_write_uint16(t *testing.T) { should.Nil(stream.Error) should.Equal(strconv.FormatUint(uint64(val), 10), buf.String()) }) + t.Run(fmt.Sprintf("%v", val), func(t *testing.T) { + should := require.New(t) + buf := &bytes.Buffer{} + stream := NewStream(buf, 4096) + stream.WriteVal(val) + stream.Flush() + should.Nil(stream.Error) + should.Equal(strconv.FormatUint(uint64(val), 10), buf.String()) + }) } should := require.New(t) buf := &bytes.Buffer{} stream := NewStream(buf, 5) - stream.WriteString("a") + stream.WriteRaw("a") stream.WriteUint16(10000) // should clear buffer stream.Flush() should.Nil(stream.Error) @@ -153,11 +180,20 @@ func Test_write_int16(t *testing.T) { should.Nil(stream.Error) should.Equal(strconv.FormatInt(int64(val), 10), buf.String()) }) + t.Run(fmt.Sprintf("%v", val), func(t *testing.T) { + should := require.New(t) + buf := &bytes.Buffer{} + stream := NewStream(buf, 4096) + stream.WriteVal(val) + stream.Flush() + should.Nil(stream.Error) + should.Equal(strconv.FormatInt(int64(val), 10), buf.String()) + }) } should := require.New(t) buf := &bytes.Buffer{} stream := NewStream(buf, 6) - stream.WriteString("a") + stream.WriteRaw("a") stream.WriteInt16(-10000) // should clear buffer stream.Flush() should.Nil(stream.Error) @@ -176,11 +212,20 @@ func Test_write_uint32(t *testing.T) { should.Nil(stream.Error) should.Equal(strconv.FormatUint(uint64(val), 10), buf.String()) }) + t.Run(fmt.Sprintf("%v", val), func(t *testing.T) { + should := require.New(t) + buf := &bytes.Buffer{} + stream := NewStream(buf, 4096) + stream.WriteVal(val) + stream.Flush() + should.Nil(stream.Error) + should.Equal(strconv.FormatUint(uint64(val), 10), buf.String()) + }) } should := require.New(t) buf := &bytes.Buffer{} stream := NewStream(buf, 10) - stream.WriteString("a") + stream.WriteRaw("a") stream.WriteUint32(0xffffffff) // should clear buffer stream.Flush() should.Nil(stream.Error) @@ -199,11 +244,20 @@ func Test_write_int32(t *testing.T) { should.Nil(stream.Error) should.Equal(strconv.FormatInt(int64(val), 10), buf.String()) }) + t.Run(fmt.Sprintf("%v", val), func(t *testing.T) { + should := require.New(t) + buf := &bytes.Buffer{} + stream := NewStream(buf, 4096) + stream.WriteVal(val) + stream.Flush() + should.Nil(stream.Error) + should.Equal(strconv.FormatInt(int64(val), 10), buf.String()) + }) } should := require.New(t) buf := &bytes.Buffer{} stream := NewStream(buf, 11) - stream.WriteString("a") + stream.WriteRaw("a") stream.WriteInt32(-0x7fffffff) // should clear buffer stream.Flush() should.Nil(stream.Error) @@ -224,11 +278,20 @@ func Test_write_uint64(t *testing.T) { should.Nil(stream.Error) should.Equal(strconv.FormatUint(uint64(val), 10), buf.String()) }) + t.Run(fmt.Sprintf("%v", val), func(t *testing.T) { + should := require.New(t) + buf := &bytes.Buffer{} + stream := NewStream(buf, 4096) + stream.WriteVal(val) + stream.Flush() + should.Nil(stream.Error) + should.Equal(strconv.FormatUint(uint64(val), 10), buf.String()) + }) } should := require.New(t) buf := &bytes.Buffer{} stream := NewStream(buf, 10) - stream.WriteString("a") + stream.WriteRaw("a") stream.WriteUint64(0xffffffff) // should clear buffer stream.Flush() should.Nil(stream.Error) @@ -249,11 +312,20 @@ func Test_write_int64(t *testing.T) { should.Nil(stream.Error) should.Equal(strconv.FormatInt(val, 10), buf.String()) }) + t.Run(fmt.Sprintf("%v", val), func(t *testing.T) { + should := require.New(t) + buf := &bytes.Buffer{} + stream := NewStream(buf, 4096) + stream.WriteVal(val) + stream.Flush() + should.Nil(stream.Error) + should.Equal(strconv.FormatInt(val, 10), buf.String()) + }) } should := require.New(t) buf := &bytes.Buffer{} stream := NewStream(buf, 10) - stream.WriteString("a") + stream.WriteRaw("a") stream.WriteInt64(0xffffffff) // should clear buffer stream.Flush() should.Nil(stream.Error) diff --git a/jsoniter_reflect_struct_test.go b/jsoniter_reflect_struct_test.go index 3130b70..f70d728 100644 --- a/jsoniter_reflect_struct_test.go +++ b/jsoniter_reflect_struct_test.go @@ -109,4 +109,25 @@ func Test_decode_struct_field_with_tag(t *testing.T) { should.Equal("hello", obj.Field1) should.Equal("world", obj.Field2) should.Equal(100, obj.Field3) -} \ No newline at end of file +} + +func Test_write_val_zero_field_struct(t *testing.T) { + should := require.New(t) + type TestObject struct { + } + obj := TestObject{} + str, err := MarshalToString(obj) + should.Nil(err) + should.Equal(`{}`, str) +} + +func Test_write_val_one_field_struct(t *testing.T) { + should := require.New(t) + type TestObject struct { + Field1 string `json:"field-1"` + } + obj := TestObject{"hello"} + str, err := MarshalToString(obj) + should.Nil(err) + should.Equal(`{"field-1":"hello"}`, str) +}