diff --git a/api_tests/marshal_json_test.go b/api_tests/marshal_json_test.go index 635a24e..7f3a0aa 100644 --- a/api_tests/marshal_json_test.go +++ b/api_tests/marshal_json_test.go @@ -4,11 +4,10 @@ import ( "bytes" "encoding/json" "github.com/json-iterator/go" - "testing" "github.com/stretchr/testify/require" + "testing" ) - type Foo struct { Bar interface{} } @@ -19,11 +18,10 @@ func (f Foo) MarshalJSON() ([]byte, error) { return buf.Bytes(), err } - // Standard Encoder has trailing newline. func TestEncodeMarshalJSON(t *testing.T) { - foo := Foo { + foo := Foo{ Bar: 123, } should := require.New(t) diff --git a/extension_tests/extension_test.go b/extension_tests/extension_test.go index 836db5b..0c379ba 100644 --- a/extension_tests/extension_test.go +++ b/extension_tests/extension_test.go @@ -61,6 +61,79 @@ func Test_customize_map_key_encoder(t *testing.T) { should.Equal(map[int]int{1: 2}, m) } +// Test using custom encoder with sorted map keys. +// Keys should be numerically sorted AFTER the custom key encoder runs. +func Test_customize_map_key_encoder_with_sorted_keys(t *testing.T) { + should := require.New(t) + cfg := jsoniter.Config{ + SortMapKeys: true, + }.Froze() + cfg.RegisterExtension(&testMapKeyExtension{}) + m := map[int]int{ + 3: 3, + 1: 9, + } + output, err := cfg.MarshalToString(m) + should.NoError(err) + should.Equal(`{"2":9,"4":3}`, output) + m2 := map[int]int{} + should.NoError(cfg.UnmarshalFromString(output, &m2)) + should.Equal(map[int]int{ + 1: 9, + 3: 3, + }, m2) +} + +func Test_customize_map_key_sorter(t *testing.T) { + should := require.New(t) + cfg := jsoniter.Config{ + SortMapKeys: true, + }.Froze() + + cfg.RegisterExtension(&testMapKeySorterExtension{ + sorter: &testKeySorter{}, + }) + + m := map[string]int{ + "a": 1, + "foo": 2, + "b": 3, + } + output, err := cfg.MarshalToString(m) + should.NoError(err) + should.Equal(`{"foo":2,"a":1,"b":3}`, output) + m = map[string]int{} + should.NoError(cfg.UnmarshalFromString(output, &m)) + should.Equal(map[string]int{ + "foo": 2, + "a": 1, + "b": 3, + }, m) +} + +type testKeySorter struct { +} + +func (sorter *testKeySorter) Sort(keyA string, keyB string) bool { + // Prioritize "foo" over other keys, otherwise alpha-sort + if keyA == "foo" { + return true + } else if keyB == "foo" { + return false + } else { + return keyA < keyB + } +} + +type testMapKeySorterExtension struct { + jsoniter.DummyExtension + sorter jsoniter.MapKeySorter +} + +func (extension *testMapKeySorterExtension) CreateMapKeySorter() jsoniter.MapKeySorter { + return extension.sorter +} + type testMapKeyExtension struct { jsoniter.DummyExtension } diff --git a/reflect.go b/reflect.go index 4459e20..c787e23 100644 --- a/reflect.go +++ b/reflect.go @@ -30,6 +30,11 @@ type ValEncoder interface { Encode(ptr unsafe.Pointer, stream *Stream) } +// MapKeySorter is used to define a custom function for sorting the keys of maps +type MapKeySorter interface { + Sort(keyA string, keyB string) bool +} + type checkIsEmpty interface { IsEmpty(ptr unsafe.Pointer) bool } diff --git a/reflect_extension.go b/reflect_extension.go index 05e8fbf..5ad3cc8 100644 --- a/reflect_extension.go +++ b/reflect_extension.go @@ -49,6 +49,7 @@ type Extension interface { UpdateStructDescriptor(structDescriptor *StructDescriptor) CreateMapKeyDecoder(typ reflect2.Type) ValDecoder CreateMapKeyEncoder(typ reflect2.Type) ValEncoder + CreateMapKeySorter() MapKeySorter CreateDecoder(typ reflect2.Type) ValDecoder CreateEncoder(typ reflect2.Type) ValEncoder DecorateDecoder(typ reflect2.Type, decoder ValDecoder) ValDecoder @@ -73,6 +74,11 @@ func (extension *DummyExtension) CreateMapKeyEncoder(typ reflect2.Type) ValEncod return nil } +// CreateMapKeySorter No-op +func (extension *DummyExtension) CreateMapKeySorter() MapKeySorter { + return nil +} + // CreateDecoder No-op func (extension *DummyExtension) CreateDecoder(typ reflect2.Type) ValDecoder { return nil @@ -119,6 +125,11 @@ func (extension EncoderExtension) CreateMapKeyEncoder(typ reflect2.Type) ValEnco return nil } +// CreateMapKeySorter No-op +func (extension EncoderExtension) CreateMapKeySorter() MapKeySorter { + return nil +} + // DecorateDecoder No-op func (extension EncoderExtension) DecorateDecoder(typ reflect2.Type, decoder ValDecoder) ValDecoder { return decoder @@ -145,6 +156,11 @@ func (extension DecoderExtension) CreateMapKeyEncoder(typ reflect2.Type) ValEnco return nil } +// CreateMapKeySorter No-op +func (extension DecoderExtension) CreateMapKeySorter() MapKeySorter { + return nil +} + // CreateDecoder get decoder from map func (extension DecoderExtension) CreateDecoder(typ reflect2.Type) ValDecoder { return extension[typ] diff --git a/reflect_map.go b/reflect_map.go index 547b442..45739f6 100644 --- a/reflect_map.go +++ b/reflect_map.go @@ -28,6 +28,7 @@ func encoderOfMap(ctx *ctx, typ reflect2.Type) ValEncoder { return &sortKeysMapEncoder{ mapType: mapType, keyEncoder: encoderOfMapKey(ctx.append("[mapKey]"), mapType.Key()), + keySorter: sorterOfMapKey(ctx), elemEncoder: encoderOfType(ctx.append("[mapElem]"), mapType.Elem()), } } @@ -38,6 +39,23 @@ func encoderOfMap(ctx *ctx, typ reflect2.Type) ValEncoder { } } +type defaultMapKeySorter struct { +} + +func (sorter defaultMapKeySorter) Sort(keyA string, keyB string) bool { + return keyA < keyB +} + +func sorterOfMapKey(ctx *ctx) MapKeySorter { + for _, extension := range ctx.extraExtensions { + sorter := extension.CreateMapKeySorter() + if sorter != nil { + return sorter + } + } + return defaultMapKeySorter{} +} + func decoderOfMapKey(ctx *ctx, typ reflect2.Type) ValDecoder { decoder := ctx.decoderExtension.CreateMapKeyDecoder(typ) if decoder != nil { @@ -275,6 +293,7 @@ func (encoder *mapEncoder) IsEmpty(ptr unsafe.Pointer) bool { type sortKeysMapEncoder struct { mapType *reflect2.UnsafeMapType keyEncoder ValEncoder + keySorter MapKeySorter elemEncoder ValEncoder } @@ -287,7 +306,7 @@ func (encoder *sortKeysMapEncoder) Encode(ptr unsafe.Pointer, stream *Stream) { mapIter := encoder.mapType.UnsafeIterate(ptr) subStream := stream.cfg.BorrowStream(nil) subIter := stream.cfg.BorrowIterator(nil) - keyValues := encodedKeyValues{} + keyValues := []encodedKV{} for mapIter.HasNext() { subStream.buf = make([]byte, 0, 64) key, elem := mapIter.UnsafeNext() @@ -309,7 +328,11 @@ func (encoder *sortKeysMapEncoder) Encode(ptr unsafe.Pointer, stream *Stream) { keyValue: subStream.Buffer(), }) } - sort.Sort(keyValues) + keyValueWrapper := encodedKeyValues{ + keySorter: encoder.keySorter, + values: keyValues, + } + sort.Sort(keyValueWrapper) for i, keyValue := range keyValues { if i != 0 { stream.WriteMore() @@ -326,13 +349,18 @@ func (encoder *sortKeysMapEncoder) IsEmpty(ptr unsafe.Pointer) bool { return !iter.HasNext() } -type encodedKeyValues []encodedKV +type encodedKeyValues struct { + keySorter MapKeySorter + values []encodedKV +} type encodedKV struct { key string keyValue []byte } -func (sv encodedKeyValues) Len() int { return len(sv) } -func (sv encodedKeyValues) Swap(i, j int) { sv[i], sv[j] = sv[j], sv[i] } -func (sv encodedKeyValues) Less(i, j int) bool { return sv[i].key < sv[j].key } +func (sv encodedKeyValues) Len() int { return len(sv.values) } +func (sv encodedKeyValues) Swap(i, j int) { sv.values[i], sv.values[j] = sv.values[j], sv.values[i] } +func (sv encodedKeyValues) Less(i, j int) bool { + return sv.keySorter.Sort(sv.values[i].key, sv.values[j].key) +}