You've already forked json-iterator
							
							
				mirror of
				https://github.com/json-iterator/go.git
				synced 2025-10-31 00:07:40 +02:00 
			
		
		
		
	support TextMarshaler as map key
This commit is contained in:
		| @@ -6,6 +6,7 @@ import ( | ||||
| 	"sync/atomic" | ||||
| 	"unsafe" | ||||
| 	"encoding/json" | ||||
| 	"encoding" | ||||
| ) | ||||
|  | ||||
| /* | ||||
| @@ -77,6 +78,7 @@ var jsonRawMessageType reflect.Type | ||||
| var anyType reflect.Type | ||||
| var marshalerType reflect.Type | ||||
| var unmarshalerType reflect.Type | ||||
| var textUnmarshalerType reflect.Type | ||||
|  | ||||
| func init() { | ||||
| 	typeDecoders = map[string]Decoder{} | ||||
| @@ -91,6 +93,7 @@ func init() { | ||||
| 	anyType = reflect.TypeOf((*Any)(nil)).Elem() | ||||
| 	marshalerType = reflect.TypeOf((*json.Marshaler)(nil)).Elem() | ||||
| 	unmarshalerType = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() | ||||
| 	textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() | ||||
| } | ||||
|  | ||||
| func addDecoderToCache(cacheKey reflect.Type, decoder Decoder) { | ||||
|   | ||||
| @@ -25,45 +25,49 @@ func (decoder *mapDecoder) decode(ptr unsafe.Pointer, iter *Iterator) { | ||||
| 	if realVal.IsNil() { | ||||
| 		realVal.Set(reflect.MakeMap(realVal.Type())) | ||||
| 	} | ||||
| 	iter.ReadObjectCB(func(iter *Iterator, keyStr string) bool{ | ||||
| 	iter.ReadObjectCB(func(iter *Iterator, keyStr string) bool { | ||||
| 		elem := reflect.New(decoder.elemType) | ||||
| 		decoder.elemDecoder.decode(unsafe.Pointer(elem.Pointer()), iter) | ||||
| 		// to put into map, we have to use reflection | ||||
| 		realVal.SetMapIndex(decodeMapKey(iter, keyStr, decoder.keyType), elem.Elem()) | ||||
| 		keyType := decoder.keyType | ||||
| 		switch { | ||||
| 		case keyType.Kind() == reflect.String: | ||||
| 			realVal.SetMapIndex(reflect.ValueOf(keyStr), elem.Elem()) | ||||
| 			return true | ||||
| 		case keyType.Implements(textUnmarshalerType): | ||||
| 			textUnmarshaler := reflect.New(keyType.Elem()).Interface().(encoding.TextUnmarshaler) | ||||
| 			err := textUnmarshaler.UnmarshalText([]byte(keyStr)) | ||||
| 			if err != nil { | ||||
| 				iter.reportError("read map key as TextUnmarshaler", err.Error()) | ||||
| 				return false | ||||
| 			} | ||||
| 			realVal.SetMapIndex(reflect.ValueOf(textUnmarshaler), elem.Elem()) | ||||
| 			return true | ||||
| 		default: | ||||
| 			switch keyType.Kind() { | ||||
| 			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | ||||
| 				n, err := strconv.ParseInt(keyStr, 10, 64) | ||||
| 				if err != nil || reflect.Zero(keyType).OverflowInt(n) { | ||||
| 					iter.reportError("read map key as int64", "read int64 failed") | ||||
| 					return false | ||||
| 				} | ||||
| 				realVal.SetMapIndex(reflect.ValueOf(n).Convert(keyType), elem.Elem()) | ||||
| 				return true | ||||
| 			case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: | ||||
| 				n, err := strconv.ParseUint(keyStr, 10, 64) | ||||
| 				if err != nil || reflect.Zero(keyType).OverflowUint(n) { | ||||
| 					iter.reportError("read map key as uint64", "read uint64 failed") | ||||
| 					return false | ||||
| 				} | ||||
| 				realVal.SetMapIndex(reflect.ValueOf(n).Convert(keyType), elem.Elem()) | ||||
| 				return true | ||||
| 			} | ||||
| 		} | ||||
| 		iter.reportError("read map key", "unexpected map key type "+keyType.String()) | ||||
| 		return true | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func decodeMapKey(iter *Iterator, keyStr string, keyType reflect.Type) reflect.Value { | ||||
| 	switch { | ||||
| 	case keyType.Kind() == reflect.String: | ||||
| 		return reflect.ValueOf(keyStr) | ||||
| 	//case reflect.PtrTo(kt).Implements(textUnmarshalerType): | ||||
| 	//	kv = reflect.New(v.Type().Key()) | ||||
| 	//	d.literalStore(item, kv, true) | ||||
| 	//	kv = kv.Elem() | ||||
| 	default: | ||||
| 		switch keyType.Kind() { | ||||
| 		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | ||||
| 			n, err := strconv.ParseInt(keyStr, 10, 64) | ||||
| 			if err != nil || reflect.Zero(keyType).OverflowInt(n) { | ||||
| 				iter.reportError("read map key as int64", "read int64 failed") | ||||
| 				return reflect.ValueOf("") | ||||
| 			} | ||||
| 			return reflect.ValueOf(n).Convert(keyType) | ||||
| 		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: | ||||
| 			n, err := strconv.ParseUint(keyStr, 10, 64) | ||||
| 			if err != nil || reflect.Zero(keyType).OverflowUint(n) { | ||||
| 				iter.reportError("read map key as uint64", "read uint64 failed") | ||||
| 				return reflect.ValueOf("") | ||||
| 			} | ||||
| 			return reflect.ValueOf(n).Convert(keyType) | ||||
| 		} | ||||
| 	} | ||||
| 	iter.reportError("read map key", "json: Unexpected key type") | ||||
| 	return reflect.ValueOf("") | ||||
| } | ||||
|  | ||||
| type mapEncoder struct { | ||||
| 	mapType      reflect.Type | ||||
| 	elemType     reflect.Type | ||||
| @@ -131,4 +135,4 @@ func (encoder *mapEncoder) isEmpty(ptr unsafe.Pointer) bool { | ||||
| 	realInterface := (*interface{})(unsafe.Pointer(&mapInterface)) | ||||
| 	realVal := reflect.ValueOf(*realInterface) | ||||
| 	return realVal.Len() == 0 | ||||
| } | ||||
| } | ||||
|   | ||||
| @@ -3,6 +3,7 @@ package jsoniter | ||||
| import ( | ||||
| 	"testing" | ||||
| 	"github.com/json-iterator/go/require" | ||||
| 	"math/big" | ||||
| ) | ||||
|  | ||||
| func Test_read_map(t *testing.T) { | ||||
| @@ -81,4 +82,22 @@ func Test_decode_int_key_map(t *testing.T) { | ||||
| 	var val map[int]string | ||||
| 	should.Nil(UnmarshalFromString(`{"1":"2"}`, &val)) | ||||
| 	should.Equal(map[int]string{1: "2"}, val) | ||||
| } | ||||
|  | ||||
| func Test_encode_TextMarshaler_key_map(t *testing.T) { | ||||
| 	should := require.New(t) | ||||
| 	f, _, _  := big.ParseFloat("1", 10, 64, big.ToZero) | ||||
| 	val := map[*big.Float]string{f: "2"} | ||||
| 	str, err := MarshalToString(val) | ||||
| 	should.Nil(err) | ||||
| 	should.Equal(`{"1":"2"}`, str) | ||||
| } | ||||
|  | ||||
| func Test_decode_TextMarshaler_key_map(t *testing.T) { | ||||
| 	should := require.New(t) | ||||
| 	var val map[*big.Float]string | ||||
| 	should.Nil(UnmarshalFromString(`{"1":"2"}`, &val)) | ||||
| 	str, err := MarshalToString(val) | ||||
| 	should.Nil(err) | ||||
| 	should.Equal(`{"1":"2"}`, str) | ||||
| } | ||||
		Reference in New Issue
	
	Block a user