From 8f6a840c63594b300f7a980300558400e59d3ae1 Mon Sep 17 00:00:00 2001 From: Tao Wen Date: Tue, 20 Jun 2017 13:33:40 +0800 Subject: [PATCH] fix anonymous struct --- feature_reflect_extension.go | 120 +++++++++++++++++++++-------------- feature_reflect_object.go | 41 ++++++------ jsoniter_customize_test.go | 2 +- jsoniter_object_test.go | 73 +++++++++++++++++++++ 4 files changed, 169 insertions(+), 67 deletions(-) diff --git a/feature_reflect_extension.go b/feature_reflect_extension.go index 3f21b91..8e572f9 100644 --- a/feature_reflect_extension.go +++ b/feature_reflect_extension.go @@ -16,14 +16,22 @@ var extensions = []Extension{} type StructDescriptor struct { Type reflect.Type - Fields map[string]*Binding + Fields []*Binding +} + +func (structDescriptor *StructDescriptor) GetField(fieldName string) *Binding { + for _, binding := range structDescriptor.Fields { + if binding.Field.Name == fieldName { + return binding + } + } + return nil } type Binding struct { Field *reflect.StructField FromNames []string ToNames []string - ShouldOmitEmpty bool Encoder ValEncoder Decoder ValDecoder } @@ -131,47 +139,75 @@ func getTypeEncoderFromExtension(typ reflect.Type) ValEncoder { } func describeStruct(cfg *frozenConfig, typ reflect.Type) (*StructDescriptor, error) { - bindings := map[string]*Binding{} - for _, field := range listStructFields(typ) { - tagParts := strings.Split(field.Tag.Get("json"), ",") - fieldNames := calcFieldNames(field.Name, tagParts[0]) - fieldCacheKey := fmt.Sprintf("%s/%s", typ.String(), field.Name) - decoder := fieldDecoders[fieldCacheKey] - if decoder == nil && len(fieldNames) > 0 { - var err error - decoder, err = decoderOfType(cfg, field.Type) - if err != nil { - return nil, err + bindings := []*Binding{} + for i := 0; i < typ.NumField(); i++ { + field := typ.Field(i) + if field.Anonymous { + if field.Type.Kind() == reflect.Struct { + structDescriptor, err := describeStruct(cfg, field.Type) + if err != nil { + return nil, err + } + for _, binding := range structDescriptor.Fields { + bindings = append(bindings, binding) + } + } else if field.Type.Kind() == reflect.Ptr && field.Type.Elem().Kind() == reflect.Struct { + structDescriptor, err := describeStruct(cfg, field.Type.Elem()) + if err != nil { + return nil, err + } + for _, binding := range structDescriptor.Fields { + binding.Encoder = &optionalEncoder{binding.Encoder} + binding.Encoder = &structFieldEncoder{&field, binding.Encoder, false} + binding.Decoder = &optionalDecoder{field.Type, binding.Decoder} + binding.Decoder = &structFieldDecoder{&field, binding.Decoder} + bindings = append(bindings, binding) + } } - } - encoder := fieldEncoders[fieldCacheKey] - if encoder == nil && len(fieldNames) > 0 { - var err error - encoder, err = encoderOfType(cfg, field.Type) - if err != nil { - return nil, err + } else { + tagParts := strings.Split(field.Tag.Get("json"), ",") + fieldNames := calcFieldNames(field.Name, tagParts[0]) + fieldCacheKey := fmt.Sprintf("%s/%s", typ.String(), field.Name) + decoder := fieldDecoders[fieldCacheKey] + if decoder == nil && len(fieldNames) > 0 { + var err error + decoder, err = decoderOfType(cfg, field.Type) + if err != nil { + return nil, err + } } - // map is stored as pointer in the struct - if field.Type.Kind() == reflect.Map { - encoder = &optionalEncoder{encoder} + encoder := fieldEncoders[fieldCacheKey] + if encoder == nil && len(fieldNames) > 0 { + var err error + encoder, err = encoderOfType(cfg, field.Type) + if err != nil { + return nil, err + } + // map is stored as pointer in the struct + if field.Type.Kind() == reflect.Map { + encoder = &optionalEncoder{encoder} + } } - } - binding := &Binding{ - Field: field, - FromNames: fieldNames, - ToNames: fieldNames, - Decoder: decoder, - Encoder: encoder, - } - for _, tagPart := range tagParts[1:] { - if tagPart == "omitempty" { - binding.ShouldOmitEmpty = true - } else if tagPart == "string" { - binding.Decoder = &stringModeDecoder{binding.Decoder} - binding.Encoder = &stringModeEncoder{binding.Encoder} + binding := &Binding{ + Field: &field, + FromNames: fieldNames, + ToNames: fieldNames, + Decoder: decoder, + Encoder: encoder, } + shouldOmitEmpty := false + for _, tagPart := range tagParts[1:] { + if tagPart == "omitempty" { + shouldOmitEmpty = true + } else if tagPart == "string" { + binding.Decoder = &stringModeDecoder{binding.Decoder} + binding.Encoder = &stringModeEncoder{binding.Encoder} + } + } + binding.Decoder = &structFieldDecoder{&field, binding.Decoder} + binding.Encoder = &structFieldEncoder{&field, binding.Encoder, shouldOmitEmpty} + bindings = append(bindings, binding) } - bindings[field.Name] = binding } structDescriptor := &StructDescriptor{ Type: typ, @@ -185,14 +221,6 @@ func describeStruct(cfg *frozenConfig, typ reflect.Type) (*StructDescriptor, err func listStructFields(typ reflect.Type) []*reflect.StructField { fields := []*reflect.StructField{} - for i := 0; i < typ.NumField(); i++ { - field := typ.Field(i) - if field.Anonymous { - fields = append(fields, listStructFields(field.Type)...) - } else { - fields = append(fields, &field) - } - } return fields } diff --git a/feature_reflect_object.go b/feature_reflect_object.go index 9ca9aa4..e355308 100644 --- a/feature_reflect_object.go +++ b/feature_reflect_object.go @@ -8,24 +8,20 @@ import ( ) func encoderOfStruct(cfg *frozenConfig, typ reflect.Type) (ValEncoder, error) { - structEncoder_ := &structEncoder{} fields := map[string]*structFieldEncoder{} structDescriptor, err := describeStruct(cfg, typ) if err != nil { return nil, err } for _, binding := range structDescriptor.Fields { - for _, fieldName := range binding.ToNames { - fields[fieldName] = &structFieldEncoder{binding.Field, fieldName, binding.Encoder, binding.ShouldOmitEmpty} + for _, toName := range binding.ToNames { + fields[toName] = binding.Encoder.(*structFieldEncoder) } } if len(fields) == 0 { return &emptyStructEncoder{}, nil } - for _, field := range fields { - structEncoder_.fields = append(structEncoder_.fields, field) - } - return structEncoder_, nil + return &structEncoder{fields}, nil } func decoderOfStruct(cfg *frozenConfig, typ reflect.Type) (ValDecoder, error) { @@ -35,8 +31,8 @@ func decoderOfStruct(cfg *frozenConfig, typ reflect.Type) (ValDecoder, error) { return nil, err } for _, binding := range structDescriptor.Fields { - for _, fieldName := range binding.FromNames { - fields[fieldName] = &structFieldDecoder{binding.Field, binding.Decoder} + for _, fromName := range binding.FromNames { + fields[fromName] = binding.Decoder.(*structFieldDecoder) } } return createStructDecoder(typ, fields) @@ -959,14 +955,12 @@ func (decoder *structFieldDecoder) decode(ptr unsafe.Pointer, iter *Iterator) { type structFieldEncoder struct { field *reflect.StructField - fieldName string fieldEncoder ValEncoder omitempty bool } func (encoder *structFieldEncoder) encode(ptr unsafe.Pointer, stream *Stream) { fieldPtr := uintptr(ptr) + encoder.field.Offset - stream.WriteObjectField(encoder.fieldName) encoder.fieldEncoder.encode(unsafe.Pointer(fieldPtr), stream) if stream.Error != nil && stream.Error != io.EOF { stream.Error = fmt.Errorf("%s: %s", encoder.field.Name, stream.Error.Error()) @@ -983,19 +977,20 @@ func (encoder *structFieldEncoder) isEmpty(ptr unsafe.Pointer) bool { } type structEncoder struct { - fields []*structFieldEncoder + fields map[string]*structFieldEncoder } func (encoder *structEncoder) encode(ptr unsafe.Pointer, stream *Stream) { stream.WriteObjectStart() isNotFirst := false - for _, field := range encoder.fields { + for fieldName, field := range encoder.fields { if field.omitempty && field.isEmpty(ptr) { continue } if isNotFirst { stream.WriteMore() } + stream.WriteObjectField(fieldName) field.encode(ptr, stream) isNotFirst = true } @@ -1006,17 +1001,23 @@ func (encoder *structEncoder) encodeInterface(val interface{}, stream *Stream) { var encoderToUse ValEncoder encoderToUse = encoder if len(encoder.fields) == 1 { - firstEncoder := encoder.fields[0].fieldEncoder + var firstField *structFieldEncoder + var firstFieldName string + for fieldName, field := range encoder.fields { + firstFieldName = fieldName + firstField = field + } + firstEncoder := firstField.fieldEncoder firstEncoderName := reflect.TypeOf(firstEncoder).String() // interface{} has inline optimization for this case if firstEncoderName == "*jsoniter.optionalEncoder" { encoderToUse = &structEncoder{ - fields: []*structFieldEncoder{{ - field: encoder.fields[0].field, - fieldName: encoder.fields[0].fieldName, - fieldEncoder: firstEncoder.(*optionalEncoder).valueEncoder, - omitempty: encoder.fields[0].omitempty, - }}, + fields: map[string]*structFieldEncoder{ + firstFieldName: { + field: firstField.field, + fieldEncoder: firstEncoder.(*optionalEncoder).valueEncoder, + omitempty: firstField.omitempty, + }}, } } } diff --git a/jsoniter_customize_test.go b/jsoniter_customize_test.go index c6b47f4..35d0a6e 100644 --- a/jsoniter_customize_test.go +++ b/jsoniter_customize_test.go @@ -93,7 +93,7 @@ func (extension *testExtension) UpdateStructDescriptor(structDescriptor *StructD if structDescriptor.Type.String() != "jsoniter.TestObject1" { return } - binding := structDescriptor.Fields["field1"] + binding := structDescriptor.GetField("field1") binding.Encoder = &funcEncoder{fun: func(ptr unsafe.Pointer, stream *Stream) { str := *((*string)(ptr)) val, _ := strconv.Atoi(str) diff --git a/jsoniter_object_test.go b/jsoniter_object_test.go index 07436f3..26ce20a 100644 --- a/jsoniter_object_test.go +++ b/jsoniter_object_test.go @@ -323,6 +323,79 @@ func Test_decode_anonymous_struct(t *testing.T) { should.Equal("value", outer.Key) } +func Test_multiple_level_anonymous_struct(t *testing.T) { + type Level1 struct { + Field1 string + } + type Level2 struct { + Level1 + Field2 string + } + type Level3 struct { + Level2 + Field3 string + } + should := require.New(t) + output, err := MarshalToString(Level3{Level2{Level1{"1"}, "2"}, "3"}) + should.Nil(err) + should.Contains(output, `"Field1":"1"`) + should.Contains(output, `"Field2":"2"`) + should.Contains(output, `"Field3":"3"`) +} + +func Test_multiple_level_anonymous_struct_with_ptr(t *testing.T) { + type Level1 struct { + Field1 string + Field2 string + Field4 string + } + type Level2 struct { + *Level1 + Field2 string + Field3 string + } + type Level3 struct { + *Level2 + Field3 string + } + should := require.New(t) + output, err := MarshalToString(Level3{&Level2{&Level1{"1", "", "4"}, "2", ""}, "3"}) + should.Nil(err) + should.Contains(output, `"Field1":"1"`) + should.Contains(output, `"Field2":"2"`) + should.Contains(output, `"Field3":"3"`) + should.Contains(output, `"Field4":"4"`) +} + + + +func Test_shadow_struct_field(t *testing.T) { + should := require.New(t) + type omit *struct{} + type CacheItem struct { + Key string `json:"key"` + MaxAge int `json:"cacheAge"` + } + output, err := MarshalToString(struct { + *CacheItem + + // Omit bad keys + OmitMaxAge omit `json:"cacheAge,omitempty"` + + // Add nice keys + MaxAge int `json:"max_age"` + }{ + CacheItem: &CacheItem{ + Key: "value", + MaxAge: 100, + }, + MaxAge: 20, + }) + should.Nil(err) + should.Contains(output, `"key":"value"`) + should.Contains(output, `"max_age":20`) +} + func Test_decode_nested(t *testing.T) { type StructOfString struct { Field1 string