From d8e64aa8254448fc7932765afd49f3070f48dbcc Mon Sep 17 00:00:00 2001 From: Tao Wen Date: Sun, 18 Feb 2018 22:49:06 +0800 Subject: [PATCH] support TextMarshaler as map key --- feature_reflect.go | 5 +- feature_reflect_map.go | 235 ++++++++++++++++------------------ feature_reflect_marshaler.go | 8 +- type_tests/slice_test.go | 2 +- value_tests/marshaler_test.go | 24 ++++ value_tests/value_test.go | 4 +- 6 files changed, 147 insertions(+), 131 deletions(-) diff --git a/feature_reflect.go b/feature_reflect.go index 070d65a..56f77e0 100644 --- a/feature_reflect.go +++ b/feature_reflect.go @@ -358,6 +358,7 @@ func createEncoderOfType(cfg *frozenConfig, prefix string, typ reflect.Type) Val checkIsEmpty := createCheckIsEmpty(cfg, typ) var encoder ValEncoder = &directTextMarshalerEncoder{ checkIsEmpty: checkIsEmpty, + stringEncoder: cfg.EncoderOf(reflect.TypeOf("")), } return encoder } @@ -365,14 +366,16 @@ func createEncoderOfType(cfg *frozenConfig, prefix string, typ reflect.Type) Val checkIsEmpty := createCheckIsEmpty(cfg, typ) var encoder ValEncoder = &textMarshalerEncoder{ valType: reflect2.Type2(typ), + stringEncoder: cfg.EncoderOf(reflect.TypeOf("")), checkIsEmpty: checkIsEmpty, } return encoder } - if ptrType.Implements(textMarshalerType) { + if typ.Kind() == reflect.Map && ptrType.Implements(textMarshalerType) { checkIsEmpty := createCheckIsEmpty(cfg, ptrType) var encoder ValEncoder = &textMarshalerEncoder{ valType: reflect2.Type2(ptrType), + stringEncoder: cfg.EncoderOf(reflect.TypeOf("")), checkIsEmpty: checkIsEmpty, } return &referenceEncoder{encoder} diff --git a/feature_reflect_map.go b/feature_reflect_map.go index 47559ac..b8b57be 100644 --- a/feature_reflect_map.go +++ b/feature_reflect_map.go @@ -2,11 +2,12 @@ package jsoniter import ( "encoding" - "encoding/json" "reflect" "sort" "strconv" "unsafe" + "github.com/v2pro/plz/reflect2" + "fmt" ) func decoderOfMap(cfg *frozenConfig, prefix string, typ reflect.Type) ValDecoder { @@ -16,13 +17,48 @@ func decoderOfMap(cfg *frozenConfig, prefix string, typ reflect.Type) ValDecoder } func encoderOfMap(cfg *frozenConfig, prefix string, typ reflect.Type) ValEncoder { - elemType := typ.Elem() - encoder := &emptyInterfaceCodec{} - mapInterface := reflect.New(typ).Elem().Interface() if cfg.sortMapKeys { - return &sortKeysMapEncoder{typ, elemType, encoder, *((*emptyInterface)(unsafe.Pointer(&mapInterface)))} + return &sortKeysMapEncoder{ + mapType: reflect2.Type2(typ).(*reflect2.UnsafeMapType), + keyEncoder: encoderOfMapKey(cfg, prefix+" [mapKey]", typ.Key()), + elemEncoder: encoderOfType(cfg, prefix+" [mapElem]", typ.Elem()), + } + } + return &mapEncoder{ + mapType: reflect2.Type2(typ).(*reflect2.UnsafeMapType), + keyEncoder: encoderOfMapKey(cfg, prefix+" [mapKey]", typ.Key()), + elemEncoder: encoderOfType(cfg, prefix+" [mapElem]", typ.Elem()), + } +} + +func encoderOfMapKey(cfg *frozenConfig, prefix string, typ reflect.Type) ValEncoder { + switch typ.Kind() { + case reflect.String: + return encoderOfType(cfg, prefix, reflect2.DefaultTypeOfKind(reflect.String).Type1()) + case reflect.Bool, + reflect.Uint8, reflect.Int8, + reflect.Uint16, reflect.Int16, + reflect.Uint32, reflect.Int32, + reflect.Uint64, reflect.Int64, + reflect.Uint, reflect.Int, + reflect.Float32, reflect.Float64, + reflect.Uintptr: + typ = reflect2.DefaultTypeOfKind(typ.Kind()).Type1() + return &numericMapKeyEncoder{encoderOfType(cfg, prefix, typ)} + default: + if typ == textMarshalerType { + return &directTextMarshalerEncoder{ + stringEncoder: cfg.EncoderOf(reflect.TypeOf("")), + } + } + if typ.Implements(textMarshalerType) { + return &textMarshalerEncoder{ + valType: reflect2.Type2(typ), + stringEncoder: cfg.EncoderOf(reflect.TypeOf("")), + } + } + return &lazyErrorEncoder{err: fmt.Errorf("unsupported map key type: %v", typ)} } - return &mapEncoder{typ, elemType, encoder, *((*emptyInterface)(unsafe.Pointer(&mapInterface)))} } type mapDecoder struct { @@ -99,159 +135,108 @@ func (decoder *mapDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) { }) } +type numericMapKeyEncoder struct { + encoder ValEncoder +} + +func (encoder *numericMapKeyEncoder) Encode(ptr unsafe.Pointer, stream *Stream) { + stream.writeByte('"') + encoder.encoder.Encode(ptr, stream) + stream.writeByte('"') +} + +func (encoder *numericMapKeyEncoder) IsEmpty(ptr unsafe.Pointer) bool { + return false +} + type mapEncoder struct { - mapType reflect.Type - elemType reflect.Type - elemEncoder ValEncoder - mapInterface emptyInterface + mapType *reflect2.UnsafeMapType + keyEncoder ValEncoder + elemEncoder ValEncoder } func (encoder *mapEncoder) Encode(ptr unsafe.Pointer, stream *Stream) { - mapInterface := encoder.mapInterface - mapInterface.word = ptr - realInterface := (*interface{})(unsafe.Pointer(&mapInterface)) - realVal := reflect.ValueOf(*realInterface) stream.WriteObjectStart() - for i, key := range realVal.MapKeys() { + iter := encoder.mapType.UnsafeIterate(ptr) + for i := 0; iter.HasNext(); i++ { if i != 0 { stream.WriteMore() } - encodeMapKey(key, stream) + key, elem := iter.UnsafeNext() + encoder.keyEncoder.Encode(key, stream) if stream.indention > 0 { stream.writeTwoBytes(byte(':'), byte(' ')) } else { stream.writeByte(':') } - val := realVal.MapIndex(key).Interface() - encoder.elemEncoder.Encode(unsafe.Pointer(&val), stream) + encoder.elemEncoder.Encode(elem, stream) } stream.WriteObjectEnd() } -func encodeMapKey(key reflect.Value, stream *Stream) { - if key.Kind() == reflect.String { - stream.WriteString(key.String()) - return - } - if tm, ok := key.Interface().(encoding.TextMarshaler); ok { - buf, err := tm.MarshalText() - if err != nil { - stream.Error = err - return - } - stream.writeByte('"') - stream.Write(buf) - stream.writeByte('"') - return - } - switch key.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - stream.writeByte('"') - stream.WriteInt64(key.Int()) - stream.writeByte('"') - return - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - stream.writeByte('"') - stream.WriteUint64(key.Uint()) - stream.writeByte('"') - return - } - stream.Error = &json.UnsupportedTypeError{Type: key.Type()} -} - func (encoder *mapEncoder) IsEmpty(ptr unsafe.Pointer) bool { - mapInterface := encoder.mapInterface - mapInterface.word = ptr - realInterface := (*interface{})(unsafe.Pointer(&mapInterface)) - realVal := reflect.ValueOf(*realInterface) - return realVal.Len() == 0 + iter := encoder.mapType.UnsafeIterate(ptr) + return !iter.HasNext() } type sortKeysMapEncoder struct { - mapType reflect.Type - elemType reflect.Type - elemEncoder ValEncoder - mapInterface emptyInterface + mapType *reflect2.UnsafeMapType + keyEncoder ValEncoder + elemEncoder ValEncoder } func (encoder *sortKeysMapEncoder) Encode(ptr unsafe.Pointer, stream *Stream) { - ptr = *(*unsafe.Pointer)(ptr) - if ptr == nil { + if *(*unsafe.Pointer)(ptr) == nil { stream.WriteNil() return } - mapInterface := encoder.mapInterface - mapInterface.word = ptr - realInterface := (*interface{})(unsafe.Pointer(&mapInterface)) - realVal := reflect.ValueOf(*realInterface) - - // Extract and sort the keys. - keys := realVal.MapKeys() - sv := stringValues(make([]reflectWithString, len(keys))) - for i, v := range keys { - sv[i].v = v - if err := sv[i].resolve(); err != nil { - stream.Error = err - return - } - } - sort.Sort(sv) - stream.WriteObjectStart() - for i, key := range sv { + mapIter := encoder.mapType.UnsafeIterate(ptr) + subStream := stream.cfg.BorrowStream(nil) + subIter := stream.cfg.BorrowIterator(nil) + keyValues := encodedKeyValues{} + for mapIter.HasNext() { + subStream.buf = make([]byte, 0, 64) + key, elem := mapIter.UnsafeNext() + encoder.keyEncoder.Encode(key, subStream) + encodedKey := subStream.Buffer() + subIter.ResetBytes(encodedKey) + decodedKey := subIter.ReadString() + if stream.indention > 0 { + subStream.writeTwoBytes(byte(':'), byte(' ')) + } else { + subStream.writeByte(':') + } + encoder.elemEncoder.Encode(elem, subStream) + keyValues = append(keyValues, encodedKV{ + key: decodedKey, + keyValue: subStream.Buffer(), + }) + } + sort.Sort(keyValues) + for i, keyValue := range keyValues { if i != 0 { stream.WriteMore() } - stream.WriteVal(key.s) // might need html escape, so can not WriteString directly - if stream.indention > 0 { - stream.writeTwoBytes(byte(':'), byte(' ')) - } else { - stream.writeByte(':') - } - val := realVal.MapIndex(key.v).Interface() - encoder.elemEncoder.Encode(unsafe.Pointer(&val), stream) + stream.Write(keyValue.keyValue) } stream.WriteObjectEnd() + stream.cfg.ReturnStream(subStream) + stream.cfg.ReturnIterator(subIter) } -// stringValues is a slice of reflect.Value holding *reflect.StringValue. -// It implements the methods to sort by string. -type stringValues []reflectWithString - -type reflectWithString struct { - v reflect.Value - s string -} - -func (w *reflectWithString) resolve() error { - if w.v.Kind() == reflect.String { - w.s = w.v.String() - return nil - } - if tm, ok := w.v.Interface().(encoding.TextMarshaler); ok { - buf, err := tm.MarshalText() - w.s = string(buf) - return err - } - switch w.v.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - w.s = strconv.FormatInt(w.v.Int(), 10) - return nil - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - w.s = strconv.FormatUint(w.v.Uint(), 10) - return nil - } - return &json.UnsupportedTypeError{Type: w.v.Type()} -} - -func (sv stringValues) Len() int { return len(sv) } -func (sv stringValues) Swap(i, j int) { sv[i], sv[j] = sv[j], sv[i] } -func (sv stringValues) Less(i, j int) bool { return sv[i].s < sv[j].s } - func (encoder *sortKeysMapEncoder) IsEmpty(ptr unsafe.Pointer) bool { - mapInterface := encoder.mapInterface - mapInterface.word = ptr - realInterface := (*interface{})(unsafe.Pointer(&mapInterface)) - realVal := reflect.ValueOf(*realInterface) - return realVal.Len() == 0 + iter := encoder.mapType.UnsafeIterate(ptr) + return !iter.HasNext() } + +type encodedKeyValues []encodedKV + +type encodedKV struct { + key string + keyValue []byte +} + +func (sv encodedKeyValues) Len() int { return len(sv) } +func (sv encodedKeyValues) Swap(i, j int) { sv[i], sv[j] = sv[j], sv[i] } +func (sv encodedKeyValues) Less(i, j int) bool { return sv[i].key < sv[j].key } diff --git a/feature_reflect_marshaler.go b/feature_reflect_marshaler.go index b262b48..56c93d3 100644 --- a/feature_reflect_marshaler.go +++ b/feature_reflect_marshaler.go @@ -55,6 +55,7 @@ func (encoder *directMarshalerEncoder) IsEmpty(ptr unsafe.Pointer) bool { type textMarshalerEncoder struct { valType reflect2.Type + stringEncoder ValEncoder checkIsEmpty checkIsEmpty } @@ -69,7 +70,8 @@ func (encoder *textMarshalerEncoder) Encode(ptr unsafe.Pointer, stream *Stream) if err != nil { stream.Error = err } else { - stream.WriteString(string(bytes)) + str := string(bytes) + encoder.stringEncoder.Encode(unsafe.Pointer(&str), stream) } } @@ -78,6 +80,7 @@ func (encoder *textMarshalerEncoder) IsEmpty(ptr unsafe.Pointer) bool { } type directTextMarshalerEncoder struct { + stringEncoder ValEncoder checkIsEmpty checkIsEmpty } @@ -91,7 +94,8 @@ func (encoder *directTextMarshalerEncoder) Encode(ptr unsafe.Pointer, stream *St if err != nil { stream.Error = err } else { - stream.WriteString(string(bytes)) + str := string(bytes) + encoder.stringEncoder.Encode(unsafe.Pointer(&str), stream) } } diff --git a/type_tests/slice_test.go b/type_tests/slice_test.go index d63726e..cb1d67f 100644 --- a/type_tests/slice_test.go +++ b/type_tests/slice_test.go @@ -74,7 +74,7 @@ func init() { (*[]jsonMarshaler)(nil), (*[]jsonMarshalerMap)(nil), (*[]textMarshaler)(nil), - (*[]textMarshalerMap)(nil), + selectedSymmetricCase{(*[]textMarshalerMap)(nil)}, ) } diff --git a/value_tests/marshaler_test.go b/value_tests/marshaler_test.go index e3f6546..b6baa00 100644 --- a/value_tests/marshaler_test.go +++ b/value_tests/marshaler_test.go @@ -8,11 +8,23 @@ import ( func init() { jsonMarshaler := json.Marshaler(fakeJsonMarshaler{}) textMarshaler := encoding.TextMarshaler(fakeTextMarshaler{}) + textMarshaler2 := encoding.TextMarshaler(&fakeTextMarshaler2{}) marshalCases = append(marshalCases, fakeJsonMarshaler{}, &jsonMarshaler, fakeTextMarshaler{}, &textMarshaler, + fakeTextMarshaler2{}, + &textMarshaler2, + map[fakeTextMarshaler]int{ + fakeTextMarshaler{}: 100, + }, + map[*fakeTextMarshaler]int{ + &fakeTextMarshaler{}: 100, + }, + map[encoding.TextMarshaler]int{ + textMarshaler: 100, + }, ) } @@ -40,3 +52,15 @@ func (q fakeTextMarshaler) MarshalText() ([]byte, error) { func (q *fakeTextMarshaler) UnmarshalText(value []byte) error { return nil } + +type fakeTextMarshaler2 struct { + Field2 int +} + +func (q *fakeTextMarshaler2) MarshalText() ([]byte, error) { + return []byte(`"abc"`), nil +} + +func (q *fakeTextMarshaler2) UnmarshalText(value []byte) error { + return nil +} diff --git a/value_tests/value_test.go b/value_tests/value_test.go index ce42ae4..72c83ce 100644 --- a/value_tests/value_test.go +++ b/value_tests/value_test.go @@ -54,9 +54,9 @@ func Test_marshal(t *testing.T) { t.Run(name, func(t *testing.T) { should := require.New(t) output1, err1 := json.Marshal(testCase) - should.NoError(err1) + should.NoError(err1, "json") output2, err2 := jsoniter.ConfigCompatibleWithStandardLibrary.Marshal(testCase) - should.NoError(err2) + should.NoError(err2, "jsoniter") should.Equal(string(output1), string(output2)) }) }