diff --git a/feature_reflect.go b/feature_reflect.go index 8458d30..5711a5a 100644 --- a/feature_reflect.go +++ b/feature_reflect.go @@ -82,19 +82,22 @@ func (stream *Stream) WriteVal(val interface{}) { encoder.Encode(reflect2.PtrOf(val), stream) } -func decoderOfType(cfg *frozenConfig, prefix string, typ reflect.Type) ValDecoder { +func (cfg *frozenConfig) DecoderOf(typ reflect.Type) ValDecoder { cacheKey := typ decoder := cfg.getDecoderFromCache(cacheKey) if decoder != nil { return decoder } - decoder = getTypeDecoderFromExtension(cfg, typ) + decoder = decoderOfType(cfg, "", typ) + cfg.addDecoderToCache(cacheKey, decoder) + return decoder +} + +func decoderOfType(cfg *frozenConfig, prefix string, typ reflect.Type) ValDecoder { + decoder := getTypeDecoderFromExtension(cfg, typ) if decoder != nil { - cfg.addDecoderToCache(cacheKey, decoder) return decoder } - decoder = &placeholderDecoder{cfg: cfg, cacheKey: cacheKey} - cfg.addDecoderToCache(cacheKey, decoder) decoder = createDecoderOfType(cfg, prefix, typ) for _, extension := range extensions { decoder = extension.DecorateDecoder(typ, decoder) @@ -102,7 +105,6 @@ func decoderOfType(cfg *frozenConfig, prefix string, typ reflect.Type) ValDecode for _, extension := range cfg.extensions { decoder = extension.DecorateDecoder(typ, decoder) } - cfg.addDecoderToCache(cacheKey, decoder) return decoder } @@ -120,30 +122,8 @@ func createDecoderOfType(cfg *frozenConfig, prefix string, typ reflect.Type) Val if typ.AssignableTo(jsoniterNumberType) { return &jsoniterNumberCodec{} } - if typ.Implements(unmarshalerType) { - templateInterface := reflect.New(typ).Elem().Interface() - var decoder ValDecoder = &unmarshalerDecoder{extractInterface(templateInterface)} - if typ.Kind() == reflect.Ptr { - decoder = &OptionalDecoder{typ.Elem(), decoder} - } - return decoder - } - if reflect.PtrTo(typ).Implements(unmarshalerType) { - templateInterface := reflect.New(typ).Interface() - var decoder ValDecoder = &unmarshalerDecoder{extractInterface(templateInterface)} - return decoder - } - if typ.Implements(textUnmarshalerType) { - templateInterface := reflect.New(typ).Elem().Interface() - var decoder ValDecoder = &textUnmarshalerDecoder{extractInterface(templateInterface)} - if typ.Kind() == reflect.Ptr { - decoder = &OptionalDecoder{typ.Elem(), decoder} - } - return decoder - } - if reflect.PtrTo(typ).Implements(textUnmarshalerType) { - templateInterface := reflect.New(typ).Interface() - var decoder ValDecoder = &textUnmarshalerDecoder{extractInterface(templateInterface)} + decoder := createDecoderOfMarshaler(cfg, prefix, typ) + if decoder != nil { return decoder } if typ.Kind() == reflect.Slice && typ.Elem().Kind() == reflect.Uint8 { diff --git a/feature_reflect_map.go b/feature_reflect_map.go index b8b57be..8e34353 100644 --- a/feature_reflect_map.go +++ b/feature_reflect_map.go @@ -1,19 +1,24 @@ package jsoniter import ( - "encoding" "reflect" "sort" - "strconv" "unsafe" "github.com/v2pro/plz/reflect2" "fmt" ) func decoderOfMap(cfg *frozenConfig, prefix string, typ reflect.Type) ValDecoder { - decoder := decoderOfType(cfg, prefix+"[map]->", typ.Elem()) - mapInterface := reflect.New(typ).Interface() - return &mapDecoder{typ, typ.Key(), typ.Elem(), decoder, extractInterface(mapInterface)} + keyDecoder := decoderOfMapKey(cfg, prefix+" [mapKey]", typ.Key()) + elemDecoder := decoderOfType(cfg, prefix+" [mapElem]", typ.Elem()) + mapType := reflect2.Type2(typ).(*reflect2.UnsafeMapType) + return &mapDecoder{ + mapType: mapType, + keyType: mapType.Key(), + elemType: mapType.Elem(), + keyDecoder: keyDecoder, + elemDecoder: elemDecoder, + } } func encoderOfMap(cfg *frozenConfig, prefix string, typ reflect.Type) ValEncoder { @@ -31,6 +36,38 @@ func encoderOfMap(cfg *frozenConfig, prefix string, typ reflect.Type) ValEncoder } } +func decoderOfMapKey(cfg *frozenConfig, prefix string, typ reflect.Type) ValDecoder { + switch typ.Kind() { + case reflect.String: + return decoderOfType(cfg, prefix, reflect2.DefaultTypeOfKind(reflect.String).Type1()) + case reflect.Bool, + reflect.Uint8, reflect.Int8, + reflect.Uint16, reflect.Int16, + reflect.Uint32, reflect.Int32, + reflect.Uint64, reflect.Int64, + reflect.Uint, reflect.Int, + reflect.Float32, reflect.Float64, + reflect.Uintptr: + typ = reflect2.DefaultTypeOfKind(typ.Kind()).Type1() + return &numericMapKeyDecoder{decoderOfType(cfg, prefix, typ)} + default: + ptrType := reflect.PtrTo(typ) + if ptrType.Implements(textMarshalerType) { + return &referenceDecoder{ + &textUnmarshalerDecoder{ + valType: reflect2.Type2(ptrType), + }, + } + } + if typ.Implements(textMarshalerType) { + return &textUnmarshalerDecoder{ + valType: reflect2.Type2(typ), + } + } + return &lazyErrorDecoder{err: fmt.Errorf("unsupported map key type: %v", typ)} + } +} + func encoderOfMapKey(cfg *frozenConfig, prefix string, typ reflect.Type) ValEncoder { switch typ.Kind() { case reflect.String: @@ -53,7 +90,7 @@ func encoderOfMapKey(cfg *frozenConfig, prefix string, typ reflect.Type) ValEnco } if typ.Implements(textMarshalerType) { return &textMarshalerEncoder{ - valType: reflect2.Type2(typ), + valType: reflect2.Type2(typ), stringEncoder: cfg.EncoderOf(reflect.TypeOf("")), } } @@ -62,77 +99,81 @@ func encoderOfMapKey(cfg *frozenConfig, prefix string, typ reflect.Type) ValEnco } type mapDecoder struct { - mapType reflect.Type - keyType reflect.Type - elemType reflect.Type - elemDecoder ValDecoder - mapInterface emptyInterface + mapType *reflect2.UnsafeMapType + keyType reflect2.Type + elemType reflect2.Type + keyDecoder ValDecoder + elemDecoder ValDecoder } func (decoder *mapDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) { - // dark magic to cast unsafe.Pointer back to interface{} using reflect.Type - mapInterface := decoder.mapInterface - mapInterface.word = ptr - realInterface := (*interface{})(unsafe.Pointer(&mapInterface)) - realVal := reflect.ValueOf(*realInterface).Elem() - if iter.ReadNil() { - realVal.Set(reflect.Zero(decoder.mapType)) + mapType := decoder.mapType + c := iter.nextToken() + if c == 'n' { + iter.skipThreeBytes('u', 'l', 'l') + *(*unsafe.Pointer)(ptr) = nil + mapType.UnsafeSet(ptr, mapType.UnsafeNew()) return } - if realVal.IsNil() { - realVal.Set(reflect.MakeMap(realVal.Type())) + if mapType.UnsafeIsNil(ptr) { + mapType.UnsafeSet(ptr, mapType.UnsafeMakeMap(0)) } - iter.ReadMapCB(func(iter *Iterator, keyStr string) bool { - elem := reflect.New(decoder.elemType) - decoder.elemDecoder.Decode(extractInterface(elem.Interface()).word, iter) - // to put into map, we have to use reflection - keyType := decoder.keyType - // TODO: remove this from loop - switch { - case keyType.Kind() == reflect.String: - realVal.SetMapIndex(reflect.ValueOf(keyStr).Convert(keyType), elem.Elem()) - return true - case keyType.Implements(textUnmarshalerType): - textUnmarshaler := reflect.New(keyType.Elem()).Interface().(encoding.TextUnmarshaler) - err := textUnmarshaler.UnmarshalText([]byte(keyStr)) - if err != nil { - iter.ReportError("read map key as TextUnmarshaler", err.Error()) - return false - } - realVal.SetMapIndex(reflect.ValueOf(textUnmarshaler), elem.Elem()) - return true - case reflect.PtrTo(keyType).Implements(textUnmarshalerType): - textUnmarshaler := reflect.New(keyType).Interface().(encoding.TextUnmarshaler) - err := textUnmarshaler.UnmarshalText([]byte(keyStr)) - if err != nil { - iter.ReportError("read map key as TextUnmarshaler", err.Error()) - return false - } - realVal.SetMapIndex(reflect.ValueOf(textUnmarshaler).Elem(), elem.Elem()) - return true - default: - switch keyType.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - n, err := strconv.ParseInt(keyStr, 10, 64) - if err != nil || reflect.Zero(keyType).OverflowInt(n) { - iter.ReportError("read map key as int64", "read int64 failed") - return false - } - realVal.SetMapIndex(reflect.ValueOf(n).Convert(keyType), elem.Elem()) - return true - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - n, err := strconv.ParseUint(keyStr, 10, 64) - if err != nil || reflect.Zero(keyType).OverflowUint(n) { - iter.ReportError("read map key as uint64", "read uint64 failed") - return false - } - realVal.SetMapIndex(reflect.ValueOf(n).Convert(keyType), elem.Elem()) - return true - } + if c != '{' { + iter.ReportError("ReadMapCB", `expect { or n, but found `+string([]byte{c})) + return + } + c = iter.nextToken() + if c == '}' { + return + } + if c != '"' { + iter.ReportError("ReadMapCB", `expect " after }, but found `+string([]byte{c})) + return + } + iter.unreadByte() + key := decoder.keyType.UnsafeNew() + decoder.keyDecoder.Decode(key, iter) + c = iter.nextToken() + if c != ':' { + iter.ReportError("ReadMapCB", "expect : after object field, but found "+string([]byte{c})) + return + } + elem := decoder.elemType.UnsafeNew() + decoder.elemDecoder.Decode(elem, iter) + decoder.mapType.UnsafeSetIndex(ptr, key, elem) + for c = iter.nextToken(); c == ','; c = iter.nextToken() { + key := decoder.keyType.UnsafeNew() + decoder.keyDecoder.Decode(key, iter) + c = iter.nextToken() + if c != ':' { + iter.ReportError("ReadMapCB", "expect : after object field, but found "+string([]byte{c})) + return } - iter.ReportError("read map key", "unexpected map key type "+keyType.String()) - return true - }) + elem := decoder.elemType.UnsafeNew() + decoder.elemDecoder.Decode(elem, iter) + decoder.mapType.UnsafeSetIndex(ptr, key, elem) + } + if c != '}' { + iter.ReportError("ReadMapCB", `expect }, but found `+string([]byte{c})) + } +} + +type numericMapKeyDecoder struct { + decoder ValDecoder +} + +func (decoder *numericMapKeyDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) { + c := iter.nextToken() + if c != '"' { + iter.ReportError("ReadMapCB", `expect ", but found `+string([]byte{c})) + return + } + decoder.decoder.Decode(ptr, iter) + c = iter.nextToken() + if c != '"' { + iter.ReportError("ReadMapCB", `expect ", but found `+string([]byte{c})) + return + } } type numericMapKeyEncoder struct { diff --git a/feature_reflect_marshaler.go b/feature_reflect_marshaler.go index b3a1dcf..000e985 100644 --- a/feature_reflect_marshaler.go +++ b/feature_reflect_marshaler.go @@ -8,6 +8,21 @@ import ( "reflect" ) +func createDecoderOfMarshaler(cfg *frozenConfig, prefix string, typ reflect.Type) ValDecoder { + ptrType := reflect.PtrTo(typ) + if ptrType.Implements(unmarshalerType) { + return &referenceDecoder{ + &unmarshalerDecoder{reflect2.Type2(ptrType)}, + } + } + if ptrType.Implements(textUnmarshalerType) { + return &referenceDecoder{ + &textUnmarshalerDecoder{reflect2.Type2(ptrType)}, + } + } + return nil +} + func createEncoderOfMarshaler(cfg *frozenConfig, prefix string, typ reflect.Type) ValEncoder { if typ == marshalerType { checkIsEmpty := createCheckIsEmpty(cfg, typ) @@ -160,14 +175,13 @@ func (encoder *directTextMarshalerEncoder) IsEmpty(ptr unsafe.Pointer) bool { } type unmarshalerDecoder struct { - templateInterface emptyInterface + valType reflect2.Type } func (decoder *unmarshalerDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) { - templateInterface := decoder.templateInterface - templateInterface.word = ptr - realInterface := (*interface{})(unsafe.Pointer(&templateInterface)) - unmarshaler := (*realInterface).(json.Unmarshaler) + valType := decoder.valType + obj := valType.UnsafeIndirect(ptr) + unmarshaler := obj.(json.Unmarshaler) iter.nextToken() iter.unreadByte() // skip spaces bytes := iter.SkipAndReturnBytes() @@ -178,14 +192,20 @@ func (decoder *unmarshalerDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) { } type textUnmarshalerDecoder struct { - templateInterface emptyInterface + valType reflect2.Type } func (decoder *textUnmarshalerDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) { - templateInterface := decoder.templateInterface - templateInterface.word = ptr - realInterface := (*interface{})(unsafe.Pointer(&templateInterface)) - unmarshaler := (*realInterface).(encoding.TextUnmarshaler) + valType := decoder.valType + obj := valType.UnsafeIndirect(ptr) + if reflect2.IsNil(obj) { + ptrType := valType.(*reflect2.UnsafePtrType) + elemType := ptrType.Elem() + elem := elemType.UnsafeNew() + ptrType.UnsafeSet(ptr, unsafe.Pointer(&elem)) + obj = valType.UnsafeIndirect(ptr) + } + unmarshaler := (obj).(encoding.TextUnmarshaler) str := iter.ReadString() err := unmarshaler.UnmarshalText([]byte(str)) if err != nil { diff --git a/feature_reflect_optional.go b/feature_reflect_optional.go index 0f7fdb4..4b1156d 100644 --- a/feature_reflect_optional.go +++ b/feature_reflect_optional.go @@ -118,4 +118,12 @@ func (encoder *referenceEncoder) Encode(ptr unsafe.Pointer, stream *Stream) { func (encoder *referenceEncoder) IsEmpty(ptr unsafe.Pointer) bool { return encoder.encoder.IsEmpty(unsafe.Pointer(&ptr)) +} + +type referenceDecoder struct { + decoder ValDecoder +} + +func (decoder *referenceDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) { + decoder.decoder.Decode(unsafe.Pointer(&ptr), iter) } \ No newline at end of file diff --git a/value_tests/marshaler_test.go b/value_tests/marshaler_test.go index b6baa00..90e244e 100644 --- a/value_tests/marshaler_test.go +++ b/value_tests/marshaler_test.go @@ -6,61 +6,80 @@ import ( ) func init() { - jsonMarshaler := json.Marshaler(fakeJsonMarshaler{}) - textMarshaler := encoding.TextMarshaler(fakeTextMarshaler{}) - textMarshaler2 := encoding.TextMarshaler(&fakeTextMarshaler2{}) + jm := json.Marshaler(jmOfStruct{}) + tm1 := encoding.TextMarshaler(tmOfStruct{}) + tm2 := encoding.TextMarshaler(&tmOfStructInt{}) marshalCases = append(marshalCases, - fakeJsonMarshaler{}, - &jsonMarshaler, - fakeTextMarshaler{}, - &textMarshaler, - fakeTextMarshaler2{}, - &textMarshaler2, - map[fakeTextMarshaler]int{ - fakeTextMarshaler{}: 100, + jmOfStruct{}, + &jm, + tmOfStruct{}, + &tm1, + tmOfStructInt{}, + &tm2, + map[tmOfStruct]int{ + tmOfStruct{}: 100, }, - map[*fakeTextMarshaler]int{ - &fakeTextMarshaler{}: 100, + map[*tmOfStruct]int{ + &tmOfStruct{}: 100, }, map[encoding.TextMarshaler]int{ - textMarshaler: 100, + tm1: 100, }, ) + unmarshalCases = append(unmarshalCases, unmarshalCase{ + ptr: (*tmOfMap)(nil), + input: `"{1:2}"`, + }, unmarshalCase{ + ptr: (*tmOfMapPtr)(nil), + input: `"{1:2}"`, + }) } -type fakeJsonMarshaler struct { +type jmOfStruct struct { F2 chan []byte } -func (q fakeJsonMarshaler) MarshalJSON() ([]byte, error) { +func (q jmOfStruct) MarshalJSON() ([]byte, error) { return []byte(`""`), nil } -func (q *fakeJsonMarshaler) UnmarshalJSON(value []byte) error { +func (q *jmOfStruct) UnmarshalJSON(value []byte) error { return nil } -type fakeTextMarshaler struct { +type tmOfStruct struct { F2 chan []byte } -func (q fakeTextMarshaler) MarshalText() ([]byte, error) { +func (q tmOfStruct) MarshalText() ([]byte, error) { return []byte(`""`), nil } -func (q *fakeTextMarshaler) UnmarshalText(value []byte) error { +func (q *tmOfStruct) UnmarshalText(value []byte) error { return nil } -type fakeTextMarshaler2 struct { +type tmOfStructInt struct { Field2 int } -func (q *fakeTextMarshaler2) MarshalText() ([]byte, error) { +func (q *tmOfStructInt) MarshalText() ([]byte, error) { return []byte(`"abc"`), nil } -func (q *fakeTextMarshaler2) UnmarshalText(value []byte) error { +func (q *tmOfStructInt) UnmarshalText(value []byte) error { + return nil +} + +type tmOfMap map[int]int + +func (q tmOfMap) UnmarshalText(value []byte) error { + return nil +} + +type tmOfMapPtr map[int]int + +func (q *tmOfMapPtr) UnmarshalText(value []byte) error { return nil }