diff --git a/api_tests/decoder_test.go b/api_tests/decoder_test.go new file mode 100644 index 0000000..8213393 --- /dev/null +++ b/api_tests/decoder_test.go @@ -0,0 +1,17 @@ +package test + +import ( + "bytes" + "github.com/json-iterator/go" + "github.com/stretchr/testify/require" + "testing" +) + +func Test_disallowUnknownFields(t *testing.T) { + should := require.New(t) + type TestObject struct{} + var obj TestObject + decoder := jsoniter.NewDecoder(bytes.NewBufferString(`{"field1":100}`)) + decoder.DisallowUnknownFields() + should.Error(decoder.Decode(&obj)) +} diff --git a/feature_adapter.go b/feature_adapter.go index a322fea..1ba32a4 100644 --- a/feature_adapter.go +++ b/feature_adapter.go @@ -95,13 +95,23 @@ func (adapter *Decoder) Buffered() io.Reader { return bytes.NewReader(remaining) } -// UseNumber for number JSON element, use float64 or json.NumberValue (alias of string) +// UseNumber causes the Decoder to unmarshal a number into an interface{} as a +// Number instead of as a float64. func (adapter *Decoder) UseNumber() { cfg := adapter.iter.cfg.configBeforeFrozen cfg.UseNumber = true adapter.iter.cfg = cfg.frozeWithCacheReuse() } +// DisallowUnknownFields causes the Decoder to return an error when the destination +// is a struct and the input contains object keys which do not match any +// non-ignored, exported fields in the destination. +func (adapter *Decoder) DisallowUnknownFields() { + cfg := adapter.iter.cfg.configBeforeFrozen + cfg.DisallowUnknownFields = true + adapter.iter.cfg = cfg.frozeWithCacheReuse() +} + // NewEncoder same as json.NewEncoder func NewEncoder(writer io.Writer) *Encoder { return ConfigDefault.NewEncoder(writer) diff --git a/feature_config.go b/feature_config.go index 4370716..c5604ea 100644 --- a/feature_config.go +++ b/feature_config.go @@ -16,6 +16,7 @@ type Config struct { EscapeHTML bool SortMapKeys bool UseNumber bool + DisallowUnknownFields bool TagKey string OnlyTaggedField bool ValidateJsonRawMessage bool @@ -65,6 +66,7 @@ func (cfg Config) Froze() API { indentionStep: cfg.IndentionStep, objectFieldMustBeSimpleString: cfg.ObjectFieldMustBeSimpleString, onlyTaggedField: cfg.OnlyTaggedField, + disallowUnknownFields: cfg.DisallowUnknownFields, streamPool: make(chan *Stream, 16), iteratorPool: make(chan *Iterator, 16), } diff --git a/feature_config_without_sync_map.go b/feature_config_without_sync_map.go index d91a583..23e5575 100644 --- a/feature_config_without_sync_map.go +++ b/feature_config_without_sync_map.go @@ -13,6 +13,7 @@ type frozenConfig struct { indentionStep int objectFieldMustBeSimpleString bool onlyTaggedField bool + disallowUnknownFields bool cacheLock *sync.RWMutex decoderCache map[reflect.Type]ValDecoder encoderCache map[reflect.Type]ValEncoder diff --git a/feature_reflect_object.go b/feature_reflect_object.go index 036545c..0e801fb 100644 --- a/feature_reflect_object.go +++ b/feature_reflect_object.go @@ -99,7 +99,7 @@ func decoderOfStruct(cfg *frozenConfig, prefix string, typ reflect.Type) ValDeco for k, binding := range bindings { fields[strings.ToLower(k)] = binding.Decoder.(*structFieldDecoder) } - return createStructDecoder(typ, fields) + return createStructDecoder(cfg, typ, fields) } type structFieldEncoder struct { diff --git a/feature_reflect_struct_decoder.go b/feature_reflect_struct_decoder.go index d308048..bbb3caf 100644 --- a/feature_reflect_struct_decoder.go +++ b/feature_reflect_struct_decoder.go @@ -8,7 +8,10 @@ import ( "unsafe" ) -func createStructDecoder(typ reflect.Type, fields map[string]*structFieldDecoder) ValDecoder { +func createStructDecoder(cfg *frozenConfig, typ reflect.Type, fields map[string]*structFieldDecoder) ValDecoder { + if cfg.disallowUnknownFields { + return &generalStructDecoder{typ: typ, fields: fields, disallowUnknownFields: true} + } knownHash := map[int32]struct{}{ 0: {}, } @@ -20,7 +23,7 @@ func createStructDecoder(typ reflect.Type, fields map[string]*structFieldDecoder fieldHash := calcHash(fieldName) _, known := knownHash[fieldHash] if known { - return &generalStructDecoder{typ, fields} + return &generalStructDecoder{typ, fields, false} } knownHash[fieldHash] = struct{}{} return &oneFieldStructDecoder{typ, fieldHash, fieldDecoder} @@ -34,7 +37,7 @@ func createStructDecoder(typ reflect.Type, fields map[string]*structFieldDecoder fieldHash := calcHash(fieldName) _, known := knownHash[fieldHash] if known { - return &generalStructDecoder{typ, fields} + return &generalStructDecoder{typ, fields, false} } knownHash[fieldHash] = struct{}{} if fieldHash1 == 0 { @@ -57,7 +60,7 @@ func createStructDecoder(typ reflect.Type, fields map[string]*structFieldDecoder fieldHash := calcHash(fieldName) _, known := knownHash[fieldHash] if known { - return &generalStructDecoder{typ, fields} + return &generalStructDecoder{typ, fields, false} } knownHash[fieldHash] = struct{}{} if fieldName1 == 0 { @@ -88,7 +91,7 @@ func createStructDecoder(typ reflect.Type, fields map[string]*structFieldDecoder fieldHash := calcHash(fieldName) _, known := knownHash[fieldHash] if known { - return &generalStructDecoder{typ, fields} + return &generalStructDecoder{typ, fields, false} } knownHash[fieldHash] = struct{}{} if fieldName1 == 0 { @@ -125,7 +128,7 @@ func createStructDecoder(typ reflect.Type, fields map[string]*structFieldDecoder fieldHash := calcHash(fieldName) _, known := knownHash[fieldHash] if known { - return &generalStructDecoder{typ, fields} + return &generalStructDecoder{typ, fields, false} } knownHash[fieldHash] = struct{}{} if fieldName1 == 0 { @@ -168,7 +171,7 @@ func createStructDecoder(typ reflect.Type, fields map[string]*structFieldDecoder fieldHash := calcHash(fieldName) _, known := knownHash[fieldHash] if known { - return &generalStructDecoder{typ, fields} + return &generalStructDecoder{typ, fields, false} } knownHash[fieldHash] = struct{}{} if fieldName1 == 0 { @@ -217,7 +220,7 @@ func createStructDecoder(typ reflect.Type, fields map[string]*structFieldDecoder fieldHash := calcHash(fieldName) _, known := knownHash[fieldHash] if known { - return &generalStructDecoder{typ, fields} + return &generalStructDecoder{typ, fields, false} } knownHash[fieldHash] = struct{}{} if fieldName1 == 0 { @@ -272,7 +275,7 @@ func createStructDecoder(typ reflect.Type, fields map[string]*structFieldDecoder fieldHash := calcHash(fieldName) _, known := knownHash[fieldHash] if known { - return &generalStructDecoder{typ, fields} + return &generalStructDecoder{typ, fields, false} } knownHash[fieldHash] = struct{}{} if fieldName1 == 0 { @@ -333,7 +336,7 @@ func createStructDecoder(typ reflect.Type, fields map[string]*structFieldDecoder fieldHash := calcHash(fieldName) _, known := knownHash[fieldHash] if known { - return &generalStructDecoder{typ, fields} + return &generalStructDecoder{typ, fields, false} } knownHash[fieldHash] = struct{}{} if fieldName1 == 0 { @@ -400,7 +403,7 @@ func createStructDecoder(typ reflect.Type, fields map[string]*structFieldDecoder fieldHash := calcHash(fieldName) _, known := knownHash[fieldHash] if known { - return &generalStructDecoder{typ, fields} + return &generalStructDecoder{typ, fields, false} } knownHash[fieldHash] = struct{}{} if fieldName1 == 0 { @@ -447,12 +450,13 @@ func createStructDecoder(typ reflect.Type, fields map[string]*structFieldDecoder fieldName9, fieldDecoder9, fieldName10, fieldDecoder10} } - return &generalStructDecoder{typ, fields} + return &generalStructDecoder{typ, fields, false} } type generalStructDecoder struct { - typ reflect.Type - fields map[string]*structFieldDecoder + typ reflect.Type + fields map[string]*structFieldDecoder + disallowUnknownFields bool } func (decoder *generalStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) { @@ -473,6 +477,11 @@ func (decoder *generalStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) } fieldDecoder := decoder.fields[strings.ToLower(field)] if fieldDecoder == nil { + if decoder.disallowUnknownFields { + iter.ReportError("ReadObject", "found unknown field: "+field) + iter.Skip() + return + } iter.Skip() } else { fieldDecoder.Decode(ptr, iter) @@ -490,6 +499,11 @@ func (decoder *generalStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) } fieldDecoder = decoder.fields[strings.ToLower(field)] if fieldDecoder == nil { + if decoder.disallowUnknownFields { + iter.ReportError("ReadObject", "found unknown field: "+field) + iter.Skip() + return + } iter.Skip() } else { fieldDecoder.Decode(ptr, iter)