You've already forked json-iterator
							
							
				mirror of
				https://github.com/json-iterator/go.git
				synced 2025-10-31 00:07:40 +02:00 
			
		
		
		
	#28 extension should support specifying encoder
This commit is contained in:
		| @@ -38,7 +38,7 @@ func writeToStream(val interface{}, stream *Stream, encoder Encoder) { | |||||||
|  |  | ||||||
| type DecoderFunc func(ptr unsafe.Pointer, iter *Iterator) | type DecoderFunc func(ptr unsafe.Pointer, iter *Iterator) | ||||||
| type EncoderFunc func(ptr unsafe.Pointer, stream *Stream) | type EncoderFunc func(ptr unsafe.Pointer, stream *Stream) | ||||||
| type ExtensionFunc func(typ reflect.Type, field *reflect.StructField) ([]string, DecoderFunc) | type ExtensionFunc func(typ reflect.Type, field *reflect.StructField) ([]string, EncoderFunc, DecoderFunc) | ||||||
|  |  | ||||||
| type funcDecoder struct { | type funcDecoder struct { | ||||||
| 	fun DecoderFunc | 	fun DecoderFunc | ||||||
|   | |||||||
| @@ -13,12 +13,16 @@ func encoderOfStruct(typ reflect.Type) (Encoder, error) { | |||||||
| 	structEncoder_ := &structEncoder{} | 	structEncoder_ := &structEncoder{} | ||||||
| 	for i := 0; i < typ.NumField(); i++ { | 	for i := 0; i < typ.NumField(); i++ { | ||||||
| 		field := typ.Field(i) | 		field := typ.Field(i) | ||||||
|  | 		fieldEncoderKey := fmt.Sprintf("%s/%s", typ.String(), field.Name) | ||||||
| 		var extensionProvidedFieldNames []string | 		var extensionProvidedFieldNames []string | ||||||
| 		for _, extension := range extensions { | 		for _, extension := range extensions { | ||||||
| 			alternativeFieldNames, _ := extension(typ, &field) | 			alternativeFieldNames, fun, _ := extension(typ, &field) | ||||||
| 			if alternativeFieldNames != nil { | 			if alternativeFieldNames != nil { | ||||||
| 				extensionProvidedFieldNames = alternativeFieldNames | 				extensionProvidedFieldNames = alternativeFieldNames | ||||||
| 			} | 			} | ||||||
|  | 			if fun != nil { | ||||||
|  | 				fieldEncoders[fieldEncoderKey] = &funcEncoder{fun} | ||||||
|  | 			} | ||||||
| 		} | 		} | ||||||
| 		tagParts := strings.Split(field.Tag.Get("json"), ",") | 		tagParts := strings.Split(field.Tag.Get("json"), ",") | ||||||
| 		// if fieldNames set by extension, use theirs, otherwise try tags | 		// if fieldNames set by extension, use theirs, otherwise try tags | ||||||
| @@ -29,9 +33,9 @@ func encoderOfStruct(typ reflect.Type) (Encoder, error) { | |||||||
| 				omitempty = true | 				omitempty = true | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 		var encoder Encoder | 		encoder := fieldEncoders[fieldEncoderKey] | ||||||
| 		var err error | 		var err error | ||||||
| 		if len(fieldNames) > 0 { | 		if encoder == nil && len(fieldNames) > 0 { | ||||||
| 			encoder, err = encoderOfType(field.Type) | 			encoder, err = encoderOfType(field.Type) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return prefix(fmt.Sprintf("{%s}", field.Name)).addToEncoder(encoder, err) | 				return prefix(fmt.Sprintf("{%s}", field.Name)).addToEncoder(encoder, err) | ||||||
| @@ -59,7 +63,7 @@ func decoderOfStruct(typ reflect.Type) (Decoder, error) { | |||||||
| 		fieldDecoderKey := fmt.Sprintf("%s/%s", typ.String(), field.Name) | 		fieldDecoderKey := fmt.Sprintf("%s/%s", typ.String(), field.Name) | ||||||
| 		var extensionProviedFieldNames []string | 		var extensionProviedFieldNames []string | ||||||
| 		for _, extension := range extensions { | 		for _, extension := range extensions { | ||||||
| 			alternativeFieldNames, fun := extension(typ, &field) | 			alternativeFieldNames, _, fun := extension(typ, &field) | ||||||
| 			if alternativeFieldNames != nil { | 			if alternativeFieldNames != nil { | ||||||
| 				extensionProviedFieldNames = alternativeFieldNames | 				extensionProviedFieldNames = alternativeFieldNames | ||||||
| 			} | 			} | ||||||
| @@ -112,8 +116,8 @@ func calcFieldNames(originalFieldName string, tagProvidedFieldName string, exten | |||||||
| } | } | ||||||
|  |  | ||||||
| func EnableUnexportedStructFieldsSupport() { | func EnableUnexportedStructFieldsSupport() { | ||||||
| 	RegisterExtension(func(type_ reflect.Type, field *reflect.StructField) ([]string, DecoderFunc) { | 	RegisterExtension(func(type_ reflect.Type, field *reflect.StructField) ([]string, EncoderFunc, DecoderFunc) { | ||||||
| 		return []string{field.Name}, nil | 		return []string{field.Name}, nil, nil | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -86,22 +86,28 @@ type TestObject1 struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| func Test_customize_field_by_extension(t *testing.T) { | func Test_customize_field_by_extension(t *testing.T) { | ||||||
| 	RegisterExtension(func(type_ reflect.Type, field *reflect.StructField) ([]string, DecoderFunc) { | 	should := require.New(t) | ||||||
|  | 	RegisterExtension(func(type_ reflect.Type, field *reflect.StructField) ([]string, EncoderFunc, DecoderFunc) { | ||||||
| 		if type_.String() == "jsoniter.TestObject1" && field.Name == "field1" { | 		if type_.String() == "jsoniter.TestObject1" && field.Name == "field1" { | ||||||
| 			return []string{"field-1"}, func(ptr unsafe.Pointer, iter *Iterator) { | 			encode := func(ptr unsafe.Pointer, stream *Stream) { | ||||||
|  | 				str := *((*string)(ptr)) | ||||||
|  | 				val, _ := strconv.Atoi(str) | ||||||
|  | 				stream.WriteInt(val) | ||||||
|  | 			} | ||||||
|  | 			decode := func(ptr unsafe.Pointer, iter *Iterator) { | ||||||
| 				*((*string)(ptr)) = strconv.Itoa(iter.ReadInt()) | 				*((*string)(ptr)) = strconv.Itoa(iter.ReadInt()) | ||||||
| 			} | 			} | ||||||
|  | 			return []string{"field-1"}, encode, decode | ||||||
| 		} | 		} | ||||||
| 		return nil, nil | 		return nil, nil, nil | ||||||
| 	}) | 	}) | ||||||
| 	obj := TestObject1{} | 	obj := TestObject1{} | ||||||
| 	err := Unmarshal([]byte(`{"field-1": 100}`), &obj) | 	err := UnmarshalFromString(`{"field-1": 100}`, &obj) | ||||||
| 	if err != nil { | 	should.Nil(err) | ||||||
| 		t.Fatal(err) | 	should.Equal("100", obj.field1) | ||||||
| 	} | 	str, err := MarshalToString(obj) | ||||||
| 	if obj.field1 != "100" { | 	should.Nil(err) | ||||||
| 		t.Fatal(obj.field1) | 	should.Equal(`{"field-1":100}`, str) | ||||||
| 	} |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func Test_unexported_fields(t *testing.T) { | func Test_unexported_fields(t *testing.T) { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user