1
0
mirror of https://github.com/json-iterator/go.git synced 2025-04-20 11:28:49 +02:00

support TextMarshaler as map key

This commit is contained in:
Tao Wen 2018-02-18 22:49:06 +08:00
parent 577ddede74
commit d8e64aa825
6 changed files with 147 additions and 131 deletions

View File

@ -358,6 +358,7 @@ func createEncoderOfType(cfg *frozenConfig, prefix string, typ reflect.Type) Val
checkIsEmpty := createCheckIsEmpty(cfg, typ)
var encoder ValEncoder = &directTextMarshalerEncoder{
checkIsEmpty: checkIsEmpty,
stringEncoder: cfg.EncoderOf(reflect.TypeOf("")),
}
return encoder
}
@ -365,14 +366,16 @@ func createEncoderOfType(cfg *frozenConfig, prefix string, typ reflect.Type) Val
checkIsEmpty := createCheckIsEmpty(cfg, typ)
var encoder ValEncoder = &textMarshalerEncoder{
valType: reflect2.Type2(typ),
stringEncoder: cfg.EncoderOf(reflect.TypeOf("")),
checkIsEmpty: checkIsEmpty,
}
return encoder
}
if ptrType.Implements(textMarshalerType) {
if typ.Kind() == reflect.Map && ptrType.Implements(textMarshalerType) {
checkIsEmpty := createCheckIsEmpty(cfg, ptrType)
var encoder ValEncoder = &textMarshalerEncoder{
valType: reflect2.Type2(ptrType),
stringEncoder: cfg.EncoderOf(reflect.TypeOf("")),
checkIsEmpty: checkIsEmpty,
}
return &referenceEncoder{encoder}

View File

@ -2,11 +2,12 @@ package jsoniter
import (
"encoding"
"encoding/json"
"reflect"
"sort"
"strconv"
"unsafe"
"github.com/v2pro/plz/reflect2"
"fmt"
)
func decoderOfMap(cfg *frozenConfig, prefix string, typ reflect.Type) ValDecoder {
@ -16,13 +17,48 @@ func decoderOfMap(cfg *frozenConfig, prefix string, typ reflect.Type) ValDecoder
}
func encoderOfMap(cfg *frozenConfig, prefix string, typ reflect.Type) ValEncoder {
elemType := typ.Elem()
encoder := &emptyInterfaceCodec{}
mapInterface := reflect.New(typ).Elem().Interface()
if cfg.sortMapKeys {
return &sortKeysMapEncoder{typ, elemType, encoder, *((*emptyInterface)(unsafe.Pointer(&mapInterface)))}
return &sortKeysMapEncoder{
mapType: reflect2.Type2(typ).(*reflect2.UnsafeMapType),
keyEncoder: encoderOfMapKey(cfg, prefix+" [mapKey]", typ.Key()),
elemEncoder: encoderOfType(cfg, prefix+" [mapElem]", typ.Elem()),
}
}
return &mapEncoder{
mapType: reflect2.Type2(typ).(*reflect2.UnsafeMapType),
keyEncoder: encoderOfMapKey(cfg, prefix+" [mapKey]", typ.Key()),
elemEncoder: encoderOfType(cfg, prefix+" [mapElem]", typ.Elem()),
}
}
func encoderOfMapKey(cfg *frozenConfig, prefix string, typ reflect.Type) ValEncoder {
switch typ.Kind() {
case reflect.String:
return encoderOfType(cfg, prefix, reflect2.DefaultTypeOfKind(reflect.String).Type1())
case reflect.Bool,
reflect.Uint8, reflect.Int8,
reflect.Uint16, reflect.Int16,
reflect.Uint32, reflect.Int32,
reflect.Uint64, reflect.Int64,
reflect.Uint, reflect.Int,
reflect.Float32, reflect.Float64,
reflect.Uintptr:
typ = reflect2.DefaultTypeOfKind(typ.Kind()).Type1()
return &numericMapKeyEncoder{encoderOfType(cfg, prefix, typ)}
default:
if typ == textMarshalerType {
return &directTextMarshalerEncoder{
stringEncoder: cfg.EncoderOf(reflect.TypeOf("")),
}
}
if typ.Implements(textMarshalerType) {
return &textMarshalerEncoder{
valType: reflect2.Type2(typ),
stringEncoder: cfg.EncoderOf(reflect.TypeOf("")),
}
}
return &lazyErrorEncoder{err: fmt.Errorf("unsupported map key type: %v", typ)}
}
return &mapEncoder{typ, elemType, encoder, *((*emptyInterface)(unsafe.Pointer(&mapInterface)))}
}
type mapDecoder struct {
@ -99,159 +135,108 @@ func (decoder *mapDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) {
})
}
type numericMapKeyEncoder struct {
encoder ValEncoder
}
func (encoder *numericMapKeyEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
stream.writeByte('"')
encoder.encoder.Encode(ptr, stream)
stream.writeByte('"')
}
func (encoder *numericMapKeyEncoder) IsEmpty(ptr unsafe.Pointer) bool {
return false
}
type mapEncoder struct {
mapType reflect.Type
elemType reflect.Type
elemEncoder ValEncoder
mapInterface emptyInterface
mapType *reflect2.UnsafeMapType
keyEncoder ValEncoder
elemEncoder ValEncoder
}
func (encoder *mapEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
mapInterface := encoder.mapInterface
mapInterface.word = ptr
realInterface := (*interface{})(unsafe.Pointer(&mapInterface))
realVal := reflect.ValueOf(*realInterface)
stream.WriteObjectStart()
for i, key := range realVal.MapKeys() {
iter := encoder.mapType.UnsafeIterate(ptr)
for i := 0; iter.HasNext(); i++ {
if i != 0 {
stream.WriteMore()
}
encodeMapKey(key, stream)
key, elem := iter.UnsafeNext()
encoder.keyEncoder.Encode(key, stream)
if stream.indention > 0 {
stream.writeTwoBytes(byte(':'), byte(' '))
} else {
stream.writeByte(':')
}
val := realVal.MapIndex(key).Interface()
encoder.elemEncoder.Encode(unsafe.Pointer(&val), stream)
encoder.elemEncoder.Encode(elem, stream)
}
stream.WriteObjectEnd()
}
func encodeMapKey(key reflect.Value, stream *Stream) {
if key.Kind() == reflect.String {
stream.WriteString(key.String())
return
}
if tm, ok := key.Interface().(encoding.TextMarshaler); ok {
buf, err := tm.MarshalText()
if err != nil {
stream.Error = err
return
}
stream.writeByte('"')
stream.Write(buf)
stream.writeByte('"')
return
}
switch key.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
stream.writeByte('"')
stream.WriteInt64(key.Int())
stream.writeByte('"')
return
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
stream.writeByte('"')
stream.WriteUint64(key.Uint())
stream.writeByte('"')
return
}
stream.Error = &json.UnsupportedTypeError{Type: key.Type()}
}
func (encoder *mapEncoder) IsEmpty(ptr unsafe.Pointer) bool {
mapInterface := encoder.mapInterface
mapInterface.word = ptr
realInterface := (*interface{})(unsafe.Pointer(&mapInterface))
realVal := reflect.ValueOf(*realInterface)
return realVal.Len() == 0
iter := encoder.mapType.UnsafeIterate(ptr)
return !iter.HasNext()
}
type sortKeysMapEncoder struct {
mapType reflect.Type
elemType reflect.Type
elemEncoder ValEncoder
mapInterface emptyInterface
mapType *reflect2.UnsafeMapType
keyEncoder ValEncoder
elemEncoder ValEncoder
}
func (encoder *sortKeysMapEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
ptr = *(*unsafe.Pointer)(ptr)
if ptr == nil {
if *(*unsafe.Pointer)(ptr) == nil {
stream.WriteNil()
return
}
mapInterface := encoder.mapInterface
mapInterface.word = ptr
realInterface := (*interface{})(unsafe.Pointer(&mapInterface))
realVal := reflect.ValueOf(*realInterface)
// Extract and sort the keys.
keys := realVal.MapKeys()
sv := stringValues(make([]reflectWithString, len(keys)))
for i, v := range keys {
sv[i].v = v
if err := sv[i].resolve(); err != nil {
stream.Error = err
return
}
}
sort.Sort(sv)
stream.WriteObjectStart()
for i, key := range sv {
mapIter := encoder.mapType.UnsafeIterate(ptr)
subStream := stream.cfg.BorrowStream(nil)
subIter := stream.cfg.BorrowIterator(nil)
keyValues := encodedKeyValues{}
for mapIter.HasNext() {
subStream.buf = make([]byte, 0, 64)
key, elem := mapIter.UnsafeNext()
encoder.keyEncoder.Encode(key, subStream)
encodedKey := subStream.Buffer()
subIter.ResetBytes(encodedKey)
decodedKey := subIter.ReadString()
if stream.indention > 0 {
subStream.writeTwoBytes(byte(':'), byte(' '))
} else {
subStream.writeByte(':')
}
encoder.elemEncoder.Encode(elem, subStream)
keyValues = append(keyValues, encodedKV{
key: decodedKey,
keyValue: subStream.Buffer(),
})
}
sort.Sort(keyValues)
for i, keyValue := range keyValues {
if i != 0 {
stream.WriteMore()
}
stream.WriteVal(key.s) // might need html escape, so can not WriteString directly
if stream.indention > 0 {
stream.writeTwoBytes(byte(':'), byte(' '))
} else {
stream.writeByte(':')
}
val := realVal.MapIndex(key.v).Interface()
encoder.elemEncoder.Encode(unsafe.Pointer(&val), stream)
stream.Write(keyValue.keyValue)
}
stream.WriteObjectEnd()
stream.cfg.ReturnStream(subStream)
stream.cfg.ReturnIterator(subIter)
}
// stringValues is a slice of reflect.Value holding *reflect.StringValue.
// It implements the methods to sort by string.
type stringValues []reflectWithString
type reflectWithString struct {
v reflect.Value
s string
}
func (w *reflectWithString) resolve() error {
if w.v.Kind() == reflect.String {
w.s = w.v.String()
return nil
}
if tm, ok := w.v.Interface().(encoding.TextMarshaler); ok {
buf, err := tm.MarshalText()
w.s = string(buf)
return err
}
switch w.v.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
w.s = strconv.FormatInt(w.v.Int(), 10)
return nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
w.s = strconv.FormatUint(w.v.Uint(), 10)
return nil
}
return &json.UnsupportedTypeError{Type: w.v.Type()}
}
func (sv stringValues) Len() int { return len(sv) }
func (sv stringValues) Swap(i, j int) { sv[i], sv[j] = sv[j], sv[i] }
func (sv stringValues) Less(i, j int) bool { return sv[i].s < sv[j].s }
func (encoder *sortKeysMapEncoder) IsEmpty(ptr unsafe.Pointer) bool {
mapInterface := encoder.mapInterface
mapInterface.word = ptr
realInterface := (*interface{})(unsafe.Pointer(&mapInterface))
realVal := reflect.ValueOf(*realInterface)
return realVal.Len() == 0
iter := encoder.mapType.UnsafeIterate(ptr)
return !iter.HasNext()
}
type encodedKeyValues []encodedKV
type encodedKV struct {
key string
keyValue []byte
}
func (sv encodedKeyValues) Len() int { return len(sv) }
func (sv encodedKeyValues) Swap(i, j int) { sv[i], sv[j] = sv[j], sv[i] }
func (sv encodedKeyValues) Less(i, j int) bool { return sv[i].key < sv[j].key }

View File

@ -55,6 +55,7 @@ func (encoder *directMarshalerEncoder) IsEmpty(ptr unsafe.Pointer) bool {
type textMarshalerEncoder struct {
valType reflect2.Type
stringEncoder ValEncoder
checkIsEmpty checkIsEmpty
}
@ -69,7 +70,8 @@ func (encoder *textMarshalerEncoder) Encode(ptr unsafe.Pointer, stream *Stream)
if err != nil {
stream.Error = err
} else {
stream.WriteString(string(bytes))
str := string(bytes)
encoder.stringEncoder.Encode(unsafe.Pointer(&str), stream)
}
}
@ -78,6 +80,7 @@ func (encoder *textMarshalerEncoder) IsEmpty(ptr unsafe.Pointer) bool {
}
type directTextMarshalerEncoder struct {
stringEncoder ValEncoder
checkIsEmpty checkIsEmpty
}
@ -91,7 +94,8 @@ func (encoder *directTextMarshalerEncoder) Encode(ptr unsafe.Pointer, stream *St
if err != nil {
stream.Error = err
} else {
stream.WriteString(string(bytes))
str := string(bytes)
encoder.stringEncoder.Encode(unsafe.Pointer(&str), stream)
}
}

View File

@ -74,7 +74,7 @@ func init() {
(*[]jsonMarshaler)(nil),
(*[]jsonMarshalerMap)(nil),
(*[]textMarshaler)(nil),
(*[]textMarshalerMap)(nil),
selectedSymmetricCase{(*[]textMarshalerMap)(nil)},
)
}

View File

@ -8,11 +8,23 @@ import (
func init() {
jsonMarshaler := json.Marshaler(fakeJsonMarshaler{})
textMarshaler := encoding.TextMarshaler(fakeTextMarshaler{})
textMarshaler2 := encoding.TextMarshaler(&fakeTextMarshaler2{})
marshalCases = append(marshalCases,
fakeJsonMarshaler{},
&jsonMarshaler,
fakeTextMarshaler{},
&textMarshaler,
fakeTextMarshaler2{},
&textMarshaler2,
map[fakeTextMarshaler]int{
fakeTextMarshaler{}: 100,
},
map[*fakeTextMarshaler]int{
&fakeTextMarshaler{}: 100,
},
map[encoding.TextMarshaler]int{
textMarshaler: 100,
},
)
}
@ -40,3 +52,15 @@ func (q fakeTextMarshaler) MarshalText() ([]byte, error) {
func (q *fakeTextMarshaler) UnmarshalText(value []byte) error {
return nil
}
type fakeTextMarshaler2 struct {
Field2 int
}
func (q *fakeTextMarshaler2) MarshalText() ([]byte, error) {
return []byte(`"abc"`), nil
}
func (q *fakeTextMarshaler2) UnmarshalText(value []byte) error {
return nil
}

View File

@ -54,9 +54,9 @@ func Test_marshal(t *testing.T) {
t.Run(name, func(t *testing.T) {
should := require.New(t)
output1, err1 := json.Marshal(testCase)
should.NoError(err1)
should.NoError(err1, "json")
output2, err2 := jsoniter.ConfigCompatibleWithStandardLibrary.Marshal(testCase)
should.NoError(err2)
should.NoError(err2, "jsoniter")
should.Equal(string(output1), string(output2))
})
}