mirror of
https://github.com/json-iterator/go.git
synced 2025-04-20 11:28:49 +02:00
support TextMarshaler as map key
This commit is contained in:
parent
577ddede74
commit
d8e64aa825
@ -358,6 +358,7 @@ func createEncoderOfType(cfg *frozenConfig, prefix string, typ reflect.Type) Val
|
||||
checkIsEmpty := createCheckIsEmpty(cfg, typ)
|
||||
var encoder ValEncoder = &directTextMarshalerEncoder{
|
||||
checkIsEmpty: checkIsEmpty,
|
||||
stringEncoder: cfg.EncoderOf(reflect.TypeOf("")),
|
||||
}
|
||||
return encoder
|
||||
}
|
||||
@ -365,14 +366,16 @@ func createEncoderOfType(cfg *frozenConfig, prefix string, typ reflect.Type) Val
|
||||
checkIsEmpty := createCheckIsEmpty(cfg, typ)
|
||||
var encoder ValEncoder = &textMarshalerEncoder{
|
||||
valType: reflect2.Type2(typ),
|
||||
stringEncoder: cfg.EncoderOf(reflect.TypeOf("")),
|
||||
checkIsEmpty: checkIsEmpty,
|
||||
}
|
||||
return encoder
|
||||
}
|
||||
if ptrType.Implements(textMarshalerType) {
|
||||
if typ.Kind() == reflect.Map && ptrType.Implements(textMarshalerType) {
|
||||
checkIsEmpty := createCheckIsEmpty(cfg, ptrType)
|
||||
var encoder ValEncoder = &textMarshalerEncoder{
|
||||
valType: reflect2.Type2(ptrType),
|
||||
stringEncoder: cfg.EncoderOf(reflect.TypeOf("")),
|
||||
checkIsEmpty: checkIsEmpty,
|
||||
}
|
||||
return &referenceEncoder{encoder}
|
||||
|
@ -2,11 +2,12 @@ package jsoniter
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"unsafe"
|
||||
"github.com/v2pro/plz/reflect2"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func decoderOfMap(cfg *frozenConfig, prefix string, typ reflect.Type) ValDecoder {
|
||||
@ -16,13 +17,48 @@ func decoderOfMap(cfg *frozenConfig, prefix string, typ reflect.Type) ValDecoder
|
||||
}
|
||||
|
||||
func encoderOfMap(cfg *frozenConfig, prefix string, typ reflect.Type) ValEncoder {
|
||||
elemType := typ.Elem()
|
||||
encoder := &emptyInterfaceCodec{}
|
||||
mapInterface := reflect.New(typ).Elem().Interface()
|
||||
if cfg.sortMapKeys {
|
||||
return &sortKeysMapEncoder{typ, elemType, encoder, *((*emptyInterface)(unsafe.Pointer(&mapInterface)))}
|
||||
return &sortKeysMapEncoder{
|
||||
mapType: reflect2.Type2(typ).(*reflect2.UnsafeMapType),
|
||||
keyEncoder: encoderOfMapKey(cfg, prefix+" [mapKey]", typ.Key()),
|
||||
elemEncoder: encoderOfType(cfg, prefix+" [mapElem]", typ.Elem()),
|
||||
}
|
||||
}
|
||||
return &mapEncoder{
|
||||
mapType: reflect2.Type2(typ).(*reflect2.UnsafeMapType),
|
||||
keyEncoder: encoderOfMapKey(cfg, prefix+" [mapKey]", typ.Key()),
|
||||
elemEncoder: encoderOfType(cfg, prefix+" [mapElem]", typ.Elem()),
|
||||
}
|
||||
}
|
||||
|
||||
func encoderOfMapKey(cfg *frozenConfig, prefix string, typ reflect.Type) ValEncoder {
|
||||
switch typ.Kind() {
|
||||
case reflect.String:
|
||||
return encoderOfType(cfg, prefix, reflect2.DefaultTypeOfKind(reflect.String).Type1())
|
||||
case reflect.Bool,
|
||||
reflect.Uint8, reflect.Int8,
|
||||
reflect.Uint16, reflect.Int16,
|
||||
reflect.Uint32, reflect.Int32,
|
||||
reflect.Uint64, reflect.Int64,
|
||||
reflect.Uint, reflect.Int,
|
||||
reflect.Float32, reflect.Float64,
|
||||
reflect.Uintptr:
|
||||
typ = reflect2.DefaultTypeOfKind(typ.Kind()).Type1()
|
||||
return &numericMapKeyEncoder{encoderOfType(cfg, prefix, typ)}
|
||||
default:
|
||||
if typ == textMarshalerType {
|
||||
return &directTextMarshalerEncoder{
|
||||
stringEncoder: cfg.EncoderOf(reflect.TypeOf("")),
|
||||
}
|
||||
}
|
||||
if typ.Implements(textMarshalerType) {
|
||||
return &textMarshalerEncoder{
|
||||
valType: reflect2.Type2(typ),
|
||||
stringEncoder: cfg.EncoderOf(reflect.TypeOf("")),
|
||||
}
|
||||
}
|
||||
return &lazyErrorEncoder{err: fmt.Errorf("unsupported map key type: %v", typ)}
|
||||
}
|
||||
return &mapEncoder{typ, elemType, encoder, *((*emptyInterface)(unsafe.Pointer(&mapInterface)))}
|
||||
}
|
||||
|
||||
type mapDecoder struct {
|
||||
@ -99,159 +135,108 @@ func (decoder *mapDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) {
|
||||
})
|
||||
}
|
||||
|
||||
type numericMapKeyEncoder struct {
|
||||
encoder ValEncoder
|
||||
}
|
||||
|
||||
func (encoder *numericMapKeyEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
|
||||
stream.writeByte('"')
|
||||
encoder.encoder.Encode(ptr, stream)
|
||||
stream.writeByte('"')
|
||||
}
|
||||
|
||||
func (encoder *numericMapKeyEncoder) IsEmpty(ptr unsafe.Pointer) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
type mapEncoder struct {
|
||||
mapType reflect.Type
|
||||
elemType reflect.Type
|
||||
mapType *reflect2.UnsafeMapType
|
||||
keyEncoder ValEncoder
|
||||
elemEncoder ValEncoder
|
||||
mapInterface emptyInterface
|
||||
}
|
||||
|
||||
func (encoder *mapEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
|
||||
mapInterface := encoder.mapInterface
|
||||
mapInterface.word = ptr
|
||||
realInterface := (*interface{})(unsafe.Pointer(&mapInterface))
|
||||
realVal := reflect.ValueOf(*realInterface)
|
||||
stream.WriteObjectStart()
|
||||
for i, key := range realVal.MapKeys() {
|
||||
iter := encoder.mapType.UnsafeIterate(ptr)
|
||||
for i := 0; iter.HasNext(); i++ {
|
||||
if i != 0 {
|
||||
stream.WriteMore()
|
||||
}
|
||||
encodeMapKey(key, stream)
|
||||
key, elem := iter.UnsafeNext()
|
||||
encoder.keyEncoder.Encode(key, stream)
|
||||
if stream.indention > 0 {
|
||||
stream.writeTwoBytes(byte(':'), byte(' '))
|
||||
} else {
|
||||
stream.writeByte(':')
|
||||
}
|
||||
val := realVal.MapIndex(key).Interface()
|
||||
encoder.elemEncoder.Encode(unsafe.Pointer(&val), stream)
|
||||
encoder.elemEncoder.Encode(elem, stream)
|
||||
}
|
||||
stream.WriteObjectEnd()
|
||||
}
|
||||
|
||||
func encodeMapKey(key reflect.Value, stream *Stream) {
|
||||
if key.Kind() == reflect.String {
|
||||
stream.WriteString(key.String())
|
||||
return
|
||||
}
|
||||
if tm, ok := key.Interface().(encoding.TextMarshaler); ok {
|
||||
buf, err := tm.MarshalText()
|
||||
if err != nil {
|
||||
stream.Error = err
|
||||
return
|
||||
}
|
||||
stream.writeByte('"')
|
||||
stream.Write(buf)
|
||||
stream.writeByte('"')
|
||||
return
|
||||
}
|
||||
switch key.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
stream.writeByte('"')
|
||||
stream.WriteInt64(key.Int())
|
||||
stream.writeByte('"')
|
||||
return
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
stream.writeByte('"')
|
||||
stream.WriteUint64(key.Uint())
|
||||
stream.writeByte('"')
|
||||
return
|
||||
}
|
||||
stream.Error = &json.UnsupportedTypeError{Type: key.Type()}
|
||||
}
|
||||
|
||||
func (encoder *mapEncoder) IsEmpty(ptr unsafe.Pointer) bool {
|
||||
mapInterface := encoder.mapInterface
|
||||
mapInterface.word = ptr
|
||||
realInterface := (*interface{})(unsafe.Pointer(&mapInterface))
|
||||
realVal := reflect.ValueOf(*realInterface)
|
||||
return realVal.Len() == 0
|
||||
iter := encoder.mapType.UnsafeIterate(ptr)
|
||||
return !iter.HasNext()
|
||||
}
|
||||
|
||||
type sortKeysMapEncoder struct {
|
||||
mapType reflect.Type
|
||||
elemType reflect.Type
|
||||
mapType *reflect2.UnsafeMapType
|
||||
keyEncoder ValEncoder
|
||||
elemEncoder ValEncoder
|
||||
mapInterface emptyInterface
|
||||
}
|
||||
|
||||
func (encoder *sortKeysMapEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
|
||||
ptr = *(*unsafe.Pointer)(ptr)
|
||||
if ptr == nil {
|
||||
if *(*unsafe.Pointer)(ptr) == nil {
|
||||
stream.WriteNil()
|
||||
return
|
||||
}
|
||||
mapInterface := encoder.mapInterface
|
||||
mapInterface.word = ptr
|
||||
realInterface := (*interface{})(unsafe.Pointer(&mapInterface))
|
||||
realVal := reflect.ValueOf(*realInterface)
|
||||
|
||||
// Extract and sort the keys.
|
||||
keys := realVal.MapKeys()
|
||||
sv := stringValues(make([]reflectWithString, len(keys)))
|
||||
for i, v := range keys {
|
||||
sv[i].v = v
|
||||
if err := sv[i].resolve(); err != nil {
|
||||
stream.Error = err
|
||||
return
|
||||
}
|
||||
}
|
||||
sort.Sort(sv)
|
||||
|
||||
stream.WriteObjectStart()
|
||||
for i, key := range sv {
|
||||
mapIter := encoder.mapType.UnsafeIterate(ptr)
|
||||
subStream := stream.cfg.BorrowStream(nil)
|
||||
subIter := stream.cfg.BorrowIterator(nil)
|
||||
keyValues := encodedKeyValues{}
|
||||
for mapIter.HasNext() {
|
||||
subStream.buf = make([]byte, 0, 64)
|
||||
key, elem := mapIter.UnsafeNext()
|
||||
encoder.keyEncoder.Encode(key, subStream)
|
||||
encodedKey := subStream.Buffer()
|
||||
subIter.ResetBytes(encodedKey)
|
||||
decodedKey := subIter.ReadString()
|
||||
if stream.indention > 0 {
|
||||
subStream.writeTwoBytes(byte(':'), byte(' '))
|
||||
} else {
|
||||
subStream.writeByte(':')
|
||||
}
|
||||
encoder.elemEncoder.Encode(elem, subStream)
|
||||
keyValues = append(keyValues, encodedKV{
|
||||
key: decodedKey,
|
||||
keyValue: subStream.Buffer(),
|
||||
})
|
||||
}
|
||||
sort.Sort(keyValues)
|
||||
for i, keyValue := range keyValues {
|
||||
if i != 0 {
|
||||
stream.WriteMore()
|
||||
}
|
||||
stream.WriteVal(key.s) // might need html escape, so can not WriteString directly
|
||||
if stream.indention > 0 {
|
||||
stream.writeTwoBytes(byte(':'), byte(' '))
|
||||
} else {
|
||||
stream.writeByte(':')
|
||||
}
|
||||
val := realVal.MapIndex(key.v).Interface()
|
||||
encoder.elemEncoder.Encode(unsafe.Pointer(&val), stream)
|
||||
stream.Write(keyValue.keyValue)
|
||||
}
|
||||
stream.WriteObjectEnd()
|
||||
stream.cfg.ReturnStream(subStream)
|
||||
stream.cfg.ReturnIterator(subIter)
|
||||
}
|
||||
|
||||
// stringValues is a slice of reflect.Value holding *reflect.StringValue.
|
||||
// It implements the methods to sort by string.
|
||||
type stringValues []reflectWithString
|
||||
|
||||
type reflectWithString struct {
|
||||
v reflect.Value
|
||||
s string
|
||||
}
|
||||
|
||||
func (w *reflectWithString) resolve() error {
|
||||
if w.v.Kind() == reflect.String {
|
||||
w.s = w.v.String()
|
||||
return nil
|
||||
}
|
||||
if tm, ok := w.v.Interface().(encoding.TextMarshaler); ok {
|
||||
buf, err := tm.MarshalText()
|
||||
w.s = string(buf)
|
||||
return err
|
||||
}
|
||||
switch w.v.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
w.s = strconv.FormatInt(w.v.Int(), 10)
|
||||
return nil
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
w.s = strconv.FormatUint(w.v.Uint(), 10)
|
||||
return nil
|
||||
}
|
||||
return &json.UnsupportedTypeError{Type: w.v.Type()}
|
||||
}
|
||||
|
||||
func (sv stringValues) Len() int { return len(sv) }
|
||||
func (sv stringValues) Swap(i, j int) { sv[i], sv[j] = sv[j], sv[i] }
|
||||
func (sv stringValues) Less(i, j int) bool { return sv[i].s < sv[j].s }
|
||||
|
||||
func (encoder *sortKeysMapEncoder) IsEmpty(ptr unsafe.Pointer) bool {
|
||||
mapInterface := encoder.mapInterface
|
||||
mapInterface.word = ptr
|
||||
realInterface := (*interface{})(unsafe.Pointer(&mapInterface))
|
||||
realVal := reflect.ValueOf(*realInterface)
|
||||
return realVal.Len() == 0
|
||||
iter := encoder.mapType.UnsafeIterate(ptr)
|
||||
return !iter.HasNext()
|
||||
}
|
||||
|
||||
type encodedKeyValues []encodedKV
|
||||
|
||||
type encodedKV struct {
|
||||
key string
|
||||
keyValue []byte
|
||||
}
|
||||
|
||||
func (sv encodedKeyValues) Len() int { return len(sv) }
|
||||
func (sv encodedKeyValues) Swap(i, j int) { sv[i], sv[j] = sv[j], sv[i] }
|
||||
func (sv encodedKeyValues) Less(i, j int) bool { return sv[i].key < sv[j].key }
|
||||
|
@ -55,6 +55,7 @@ func (encoder *directMarshalerEncoder) IsEmpty(ptr unsafe.Pointer) bool {
|
||||
|
||||
type textMarshalerEncoder struct {
|
||||
valType reflect2.Type
|
||||
stringEncoder ValEncoder
|
||||
checkIsEmpty checkIsEmpty
|
||||
}
|
||||
|
||||
@ -69,7 +70,8 @@ func (encoder *textMarshalerEncoder) Encode(ptr unsafe.Pointer, stream *Stream)
|
||||
if err != nil {
|
||||
stream.Error = err
|
||||
} else {
|
||||
stream.WriteString(string(bytes))
|
||||
str := string(bytes)
|
||||
encoder.stringEncoder.Encode(unsafe.Pointer(&str), stream)
|
||||
}
|
||||
}
|
||||
|
||||
@ -78,6 +80,7 @@ func (encoder *textMarshalerEncoder) IsEmpty(ptr unsafe.Pointer) bool {
|
||||
}
|
||||
|
||||
type directTextMarshalerEncoder struct {
|
||||
stringEncoder ValEncoder
|
||||
checkIsEmpty checkIsEmpty
|
||||
}
|
||||
|
||||
@ -91,7 +94,8 @@ func (encoder *directTextMarshalerEncoder) Encode(ptr unsafe.Pointer, stream *St
|
||||
if err != nil {
|
||||
stream.Error = err
|
||||
} else {
|
||||
stream.WriteString(string(bytes))
|
||||
str := string(bytes)
|
||||
encoder.stringEncoder.Encode(unsafe.Pointer(&str), stream)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -74,7 +74,7 @@ func init() {
|
||||
(*[]jsonMarshaler)(nil),
|
||||
(*[]jsonMarshalerMap)(nil),
|
||||
(*[]textMarshaler)(nil),
|
||||
(*[]textMarshalerMap)(nil),
|
||||
selectedSymmetricCase{(*[]textMarshalerMap)(nil)},
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -8,11 +8,23 @@ import (
|
||||
func init() {
|
||||
jsonMarshaler := json.Marshaler(fakeJsonMarshaler{})
|
||||
textMarshaler := encoding.TextMarshaler(fakeTextMarshaler{})
|
||||
textMarshaler2 := encoding.TextMarshaler(&fakeTextMarshaler2{})
|
||||
marshalCases = append(marshalCases,
|
||||
fakeJsonMarshaler{},
|
||||
&jsonMarshaler,
|
||||
fakeTextMarshaler{},
|
||||
&textMarshaler,
|
||||
fakeTextMarshaler2{},
|
||||
&textMarshaler2,
|
||||
map[fakeTextMarshaler]int{
|
||||
fakeTextMarshaler{}: 100,
|
||||
},
|
||||
map[*fakeTextMarshaler]int{
|
||||
&fakeTextMarshaler{}: 100,
|
||||
},
|
||||
map[encoding.TextMarshaler]int{
|
||||
textMarshaler: 100,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@ -40,3 +52,15 @@ func (q fakeTextMarshaler) MarshalText() ([]byte, error) {
|
||||
func (q *fakeTextMarshaler) UnmarshalText(value []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type fakeTextMarshaler2 struct {
|
||||
Field2 int
|
||||
}
|
||||
|
||||
func (q *fakeTextMarshaler2) MarshalText() ([]byte, error) {
|
||||
return []byte(`"abc"`), nil
|
||||
}
|
||||
|
||||
func (q *fakeTextMarshaler2) UnmarshalText(value []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
@ -54,9 +54,9 @@ func Test_marshal(t *testing.T) {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
should := require.New(t)
|
||||
output1, err1 := json.Marshal(testCase)
|
||||
should.NoError(err1)
|
||||
should.NoError(err1, "json")
|
||||
output2, err2 := jsoniter.ConfigCompatibleWithStandardLibrary.Marshal(testCase)
|
||||
should.NoError(err2)
|
||||
should.NoError(err2, "jsoniter")
|
||||
should.Equal(string(output1), string(output2))
|
||||
})
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user