diff --git a/feature_reflect.go b/feature_reflect.go index 4b0ab92..e1a7789 100644 --- a/feature_reflect.go +++ b/feature_reflect.go @@ -38,7 +38,7 @@ func writeToStream(val interface{}, stream *Stream, encoder Encoder) { type DecoderFunc func(ptr unsafe.Pointer, iter *Iterator) type EncoderFunc func(ptr unsafe.Pointer, stream *Stream) -type ExtensionFunc func(typ reflect.Type, field *reflect.StructField) ([]string, DecoderFunc) +type ExtensionFunc func(typ reflect.Type, field *reflect.StructField) ([]string, EncoderFunc, DecoderFunc) type funcDecoder struct { fun DecoderFunc diff --git a/feature_reflect_object.go b/feature_reflect_object.go index 722567a..cd7da3c 100644 --- a/feature_reflect_object.go +++ b/feature_reflect_object.go @@ -13,12 +13,16 @@ func encoderOfStruct(typ reflect.Type) (Encoder, error) { structEncoder_ := &structEncoder{} for i := 0; i < typ.NumField(); i++ { field := typ.Field(i) + fieldEncoderKey := fmt.Sprintf("%s/%s", typ.String(), field.Name) var extensionProvidedFieldNames []string for _, extension := range extensions { - alternativeFieldNames, _ := extension(typ, &field) + alternativeFieldNames, fun, _ := extension(typ, &field) if alternativeFieldNames != nil { extensionProvidedFieldNames = alternativeFieldNames } + if fun != nil { + fieldEncoders[fieldEncoderKey] = &funcEncoder{fun} + } } tagParts := strings.Split(field.Tag.Get("json"), ",") // if fieldNames set by extension, use theirs, otherwise try tags @@ -29,9 +33,9 @@ func encoderOfStruct(typ reflect.Type) (Encoder, error) { omitempty = true } } - var encoder Encoder + encoder := fieldEncoders[fieldEncoderKey] var err error - if len(fieldNames) > 0 { + if encoder == nil && len(fieldNames) > 0 { encoder, err = encoderOfType(field.Type) if err != nil { return prefix(fmt.Sprintf("{%s}", field.Name)).addToEncoder(encoder, err) @@ -59,7 +63,7 @@ func decoderOfStruct(typ reflect.Type) (Decoder, error) { fieldDecoderKey := fmt.Sprintf("%s/%s", typ.String(), field.Name) var extensionProviedFieldNames []string for _, extension := range extensions { - alternativeFieldNames, fun := extension(typ, &field) + alternativeFieldNames, _, fun := extension(typ, &field) if alternativeFieldNames != nil { extensionProviedFieldNames = alternativeFieldNames } @@ -112,8 +116,8 @@ func calcFieldNames(originalFieldName string, tagProvidedFieldName string, exten } func EnableUnexportedStructFieldsSupport() { - RegisterExtension(func(type_ reflect.Type, field *reflect.StructField) ([]string, DecoderFunc) { - return []string{field.Name}, nil + RegisterExtension(func(type_ reflect.Type, field *reflect.StructField) ([]string, EncoderFunc, DecoderFunc) { + return []string{field.Name}, nil, nil }) } diff --git a/jsoniter_customize_test.go b/jsoniter_customize_test.go index 9b502a9..8bb9984 100644 --- a/jsoniter_customize_test.go +++ b/jsoniter_customize_test.go @@ -86,22 +86,28 @@ type TestObject1 struct { } func Test_customize_field_by_extension(t *testing.T) { - RegisterExtension(func(type_ reflect.Type, field *reflect.StructField) ([]string, DecoderFunc) { + should := require.New(t) + RegisterExtension(func(type_ reflect.Type, field *reflect.StructField) ([]string, EncoderFunc, DecoderFunc) { if type_.String() == "jsoniter.TestObject1" && field.Name == "field1" { - return []string{"field-1"}, func(ptr unsafe.Pointer, iter *Iterator) { + encode := func(ptr unsafe.Pointer, stream *Stream) { + str := *((*string)(ptr)) + val, _ := strconv.Atoi(str) + stream.WriteInt(val) + } + decode := func(ptr unsafe.Pointer, iter *Iterator) { *((*string)(ptr)) = strconv.Itoa(iter.ReadInt()) } + return []string{"field-1"}, encode, decode } - return nil, nil + return nil, nil, nil }) obj := TestObject1{} - err := Unmarshal([]byte(`{"field-1": 100}`), &obj) - if err != nil { - t.Fatal(err) - } - if obj.field1 != "100" { - t.Fatal(obj.field1) - } + err := UnmarshalFromString(`{"field-1": 100}`, &obj) + should.Nil(err) + should.Equal("100", obj.field1) + str, err := MarshalToString(obj) + should.Nil(err) + should.Equal(`{"field-1":100}`, str) } func Test_unexported_fields(t *testing.T) {