mirror of
https://github.com/json-iterator/go.git
synced 2024-11-27 08:30:57 +02:00
feature #384: Specify custom map key sort function
This commit is contained in:
parent
27518f6661
commit
3e68b32b9a
@ -4,11 +4,10 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"github.com/json-iterator/go"
|
||||
"testing"
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
)
|
||||
|
||||
|
||||
type Foo struct {
|
||||
Bar interface{}
|
||||
}
|
||||
@ -19,11 +18,10 @@ func (f Foo) MarshalJSON() ([]byte, error) {
|
||||
return buf.Bytes(), err
|
||||
}
|
||||
|
||||
|
||||
// Standard Encoder has trailing newline.
|
||||
func TestEncodeMarshalJSON(t *testing.T) {
|
||||
|
||||
foo := Foo {
|
||||
foo := Foo{
|
||||
Bar: 123,
|
||||
}
|
||||
should := require.New(t)
|
||||
|
@ -61,6 +61,79 @@ func Test_customize_map_key_encoder(t *testing.T) {
|
||||
should.Equal(map[int]int{1: 2}, m)
|
||||
}
|
||||
|
||||
// Test using custom encoder with sorted map keys.
|
||||
// Keys should be numerically sorted AFTER the custom key encoder runs.
|
||||
func Test_customize_map_key_encoder_with_sorted_keys(t *testing.T) {
|
||||
should := require.New(t)
|
||||
cfg := jsoniter.Config{
|
||||
SortMapKeys: true,
|
||||
}.Froze()
|
||||
cfg.RegisterExtension(&testMapKeyExtension{})
|
||||
m := map[int]int{
|
||||
3: 3,
|
||||
1: 9,
|
||||
}
|
||||
output, err := cfg.MarshalToString(m)
|
||||
should.NoError(err)
|
||||
should.Equal(`{"2":9,"4":3}`, output)
|
||||
m2 := map[int]int{}
|
||||
should.NoError(cfg.UnmarshalFromString(output, &m2))
|
||||
should.Equal(map[int]int{
|
||||
1: 9,
|
||||
3: 3,
|
||||
}, m2)
|
||||
}
|
||||
|
||||
func Test_customize_map_key_sorter(t *testing.T) {
|
||||
should := require.New(t)
|
||||
cfg := jsoniter.Config{
|
||||
SortMapKeys: true,
|
||||
}.Froze()
|
||||
|
||||
cfg.RegisterExtension(&testMapKeySorterExtension{
|
||||
sorter: &testKeySorter{},
|
||||
})
|
||||
|
||||
m := map[string]int{
|
||||
"a": 1,
|
||||
"foo": 2,
|
||||
"b": 3,
|
||||
}
|
||||
output, err := cfg.MarshalToString(m)
|
||||
should.NoError(err)
|
||||
should.Equal(`{"foo":2,"a":1,"b":3}`, output)
|
||||
m = map[string]int{}
|
||||
should.NoError(cfg.UnmarshalFromString(output, &m))
|
||||
should.Equal(map[string]int{
|
||||
"foo": 2,
|
||||
"a": 1,
|
||||
"b": 3,
|
||||
}, m)
|
||||
}
|
||||
|
||||
type testKeySorter struct {
|
||||
}
|
||||
|
||||
func (sorter *testKeySorter) Sort(keyA string, keyB string) bool {
|
||||
// Prioritize "foo" over other keys, otherwise alpha-sort
|
||||
if keyA == "foo" {
|
||||
return true
|
||||
} else if keyB == "foo" {
|
||||
return false
|
||||
} else {
|
||||
return keyA < keyB
|
||||
}
|
||||
}
|
||||
|
||||
type testMapKeySorterExtension struct {
|
||||
jsoniter.DummyExtension
|
||||
sorter jsoniter.MapKeySorter
|
||||
}
|
||||
|
||||
func (extension *testMapKeySorterExtension) CreateMapKeySorter() jsoniter.MapKeySorter {
|
||||
return extension.sorter
|
||||
}
|
||||
|
||||
type testMapKeyExtension struct {
|
||||
jsoniter.DummyExtension
|
||||
}
|
||||
|
@ -30,6 +30,11 @@ type ValEncoder interface {
|
||||
Encode(ptr unsafe.Pointer, stream *Stream)
|
||||
}
|
||||
|
||||
// MapKeySorter is used to define a custom function for sorting the keys of maps
|
||||
type MapKeySorter interface {
|
||||
Sort(keyA string, keyB string) bool
|
||||
}
|
||||
|
||||
type checkIsEmpty interface {
|
||||
IsEmpty(ptr unsafe.Pointer) bool
|
||||
}
|
||||
|
@ -49,6 +49,7 @@ type Extension interface {
|
||||
UpdateStructDescriptor(structDescriptor *StructDescriptor)
|
||||
CreateMapKeyDecoder(typ reflect2.Type) ValDecoder
|
||||
CreateMapKeyEncoder(typ reflect2.Type) ValEncoder
|
||||
CreateMapKeySorter() MapKeySorter
|
||||
CreateDecoder(typ reflect2.Type) ValDecoder
|
||||
CreateEncoder(typ reflect2.Type) ValEncoder
|
||||
DecorateDecoder(typ reflect2.Type, decoder ValDecoder) ValDecoder
|
||||
@ -73,6 +74,11 @@ func (extension *DummyExtension) CreateMapKeyEncoder(typ reflect2.Type) ValEncod
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateMapKeySorter No-op
|
||||
func (extension *DummyExtension) CreateMapKeySorter() MapKeySorter {
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateDecoder No-op
|
||||
func (extension *DummyExtension) CreateDecoder(typ reflect2.Type) ValDecoder {
|
||||
return nil
|
||||
@ -119,6 +125,11 @@ func (extension EncoderExtension) CreateMapKeyEncoder(typ reflect2.Type) ValEnco
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateMapKeySorter No-op
|
||||
func (extension EncoderExtension) CreateMapKeySorter() MapKeySorter {
|
||||
return nil
|
||||
}
|
||||
|
||||
// DecorateDecoder No-op
|
||||
func (extension EncoderExtension) DecorateDecoder(typ reflect2.Type, decoder ValDecoder) ValDecoder {
|
||||
return decoder
|
||||
@ -145,6 +156,11 @@ func (extension DecoderExtension) CreateMapKeyEncoder(typ reflect2.Type) ValEnco
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateMapKeySorter No-op
|
||||
func (extension DecoderExtension) CreateMapKeySorter() MapKeySorter {
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateDecoder get decoder from map
|
||||
func (extension DecoderExtension) CreateDecoder(typ reflect2.Type) ValDecoder {
|
||||
return extension[typ]
|
||||
|
@ -28,6 +28,7 @@ func encoderOfMap(ctx *ctx, typ reflect2.Type) ValEncoder {
|
||||
return &sortKeysMapEncoder{
|
||||
mapType: mapType,
|
||||
keyEncoder: encoderOfMapKey(ctx.append("[mapKey]"), mapType.Key()),
|
||||
keySorter: sorterOfMapKey(ctx),
|
||||
elemEncoder: encoderOfType(ctx.append("[mapElem]"), mapType.Elem()),
|
||||
}
|
||||
}
|
||||
@ -38,6 +39,23 @@ func encoderOfMap(ctx *ctx, typ reflect2.Type) ValEncoder {
|
||||
}
|
||||
}
|
||||
|
||||
type defaultMapKeySorter struct {
|
||||
}
|
||||
|
||||
func (sorter defaultMapKeySorter) Sort(keyA string, keyB string) bool {
|
||||
return keyA < keyB
|
||||
}
|
||||
|
||||
func sorterOfMapKey(ctx *ctx) MapKeySorter {
|
||||
for _, extension := range ctx.extraExtensions {
|
||||
sorter := extension.CreateMapKeySorter()
|
||||
if sorter != nil {
|
||||
return sorter
|
||||
}
|
||||
}
|
||||
return defaultMapKeySorter{}
|
||||
}
|
||||
|
||||
func decoderOfMapKey(ctx *ctx, typ reflect2.Type) ValDecoder {
|
||||
decoder := ctx.decoderExtension.CreateMapKeyDecoder(typ)
|
||||
if decoder != nil {
|
||||
@ -275,6 +293,7 @@ func (encoder *mapEncoder) IsEmpty(ptr unsafe.Pointer) bool {
|
||||
type sortKeysMapEncoder struct {
|
||||
mapType *reflect2.UnsafeMapType
|
||||
keyEncoder ValEncoder
|
||||
keySorter MapKeySorter
|
||||
elemEncoder ValEncoder
|
||||
}
|
||||
|
||||
@ -287,7 +306,7 @@ func (encoder *sortKeysMapEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
|
||||
mapIter := encoder.mapType.UnsafeIterate(ptr)
|
||||
subStream := stream.cfg.BorrowStream(nil)
|
||||
subIter := stream.cfg.BorrowIterator(nil)
|
||||
keyValues := encodedKeyValues{}
|
||||
keyValues := []encodedKV{}
|
||||
for mapIter.HasNext() {
|
||||
subStream.buf = make([]byte, 0, 64)
|
||||
key, elem := mapIter.UnsafeNext()
|
||||
@ -309,7 +328,11 @@ func (encoder *sortKeysMapEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
|
||||
keyValue: subStream.Buffer(),
|
||||
})
|
||||
}
|
||||
sort.Sort(keyValues)
|
||||
keyValueWrapper := encodedKeyValues{
|
||||
keySorter: encoder.keySorter,
|
||||
values: keyValues,
|
||||
}
|
||||
sort.Sort(keyValueWrapper)
|
||||
for i, keyValue := range keyValues {
|
||||
if i != 0 {
|
||||
stream.WriteMore()
|
||||
@ -326,13 +349,18 @@ func (encoder *sortKeysMapEncoder) IsEmpty(ptr unsafe.Pointer) bool {
|
||||
return !iter.HasNext()
|
||||
}
|
||||
|
||||
type encodedKeyValues []encodedKV
|
||||
type encodedKeyValues struct {
|
||||
keySorter MapKeySorter
|
||||
values []encodedKV
|
||||
}
|
||||
|
||||
type encodedKV struct {
|
||||
key string
|
||||
keyValue []byte
|
||||
}
|
||||
|
||||
func (sv encodedKeyValues) Len() int { return len(sv) }
|
||||
func (sv encodedKeyValues) Swap(i, j int) { sv[i], sv[j] = sv[j], sv[i] }
|
||||
func (sv encodedKeyValues) Less(i, j int) bool { return sv[i].key < sv[j].key }
|
||||
func (sv encodedKeyValues) Len() int { return len(sv.values) }
|
||||
func (sv encodedKeyValues) Swap(i, j int) { sv.values[i], sv.values[j] = sv.values[j], sv.values[i] }
|
||||
func (sv encodedKeyValues) Less(i, j int) bool {
|
||||
return sv.keySorter.Sort(sv.values[i].key, sv.values[j].key)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user