diff --git a/iter.go b/iter.go index 95ae54f..29b31cf 100644 --- a/iter.go +++ b/iter.go @@ -74,6 +74,7 @@ type Iterator struct { buf []byte head int tail int + depth int captureStartedAt int captured []byte Error error @@ -88,6 +89,7 @@ func NewIterator(cfg API) *Iterator { buf: nil, head: 0, tail: 0, + depth: 0, } } @@ -99,6 +101,7 @@ func Parse(cfg API, reader io.Reader, bufSize int) *Iterator { buf: make([]byte, bufSize), head: 0, tail: 0, + depth: 0, } } @@ -110,6 +113,7 @@ func ParseBytes(cfg API, input []byte) *Iterator { buf: input, head: 0, tail: len(input), + depth: 0, } } @@ -128,6 +132,7 @@ func (iter *Iterator) Reset(reader io.Reader) *Iterator { iter.reader = reader iter.head = 0 iter.tail = 0 + iter.depth = 0 return iter } @@ -137,6 +142,7 @@ func (iter *Iterator) ResetBytes(input []byte) *Iterator { iter.buf = input iter.head = 0 iter.tail = len(input) + iter.depth = 0 return iter } @@ -320,3 +326,24 @@ func (iter *Iterator) Read() interface{} { return nil } } + +// limit maximum depth of nesting, as allowed by https://tools.ietf.org/html/rfc7159#section-9 +const maxDepth = 10000 + +func (iter *Iterator) incrementDepth() (success bool) { + iter.depth++ + if iter.depth <= maxDepth { + return true + } + iter.ReportError("incrementDepth", "exceeded max depth") + return false +} + +func (iter *Iterator) decrementDepth() (success bool) { + iter.depth-- + if iter.depth >= 0 { + return true + } + iter.ReportError("decrementDepth", "unexpected negative nesting") + return false +} diff --git a/iter_array.go b/iter_array.go index 6188cb4..204fe0e 100644 --- a/iter_array.go +++ b/iter_array.go @@ -28,26 +28,32 @@ func (iter *Iterator) ReadArray() (ret bool) { func (iter *Iterator) ReadArrayCB(callback func(*Iterator) bool) (ret bool) { c := iter.nextToken() if c == '[' { + if !iter.incrementDepth() { + return false + } c = iter.nextToken() if c != ']' { iter.unreadByte() if !callback(iter) { + iter.decrementDepth() return false } c = iter.nextToken() for c == ',' { if !callback(iter) { + iter.decrementDepth() return false } c = iter.nextToken() } if c != ']' { iter.ReportError("ReadArrayCB", "expect ] in the end, but found "+string([]byte{c})) + iter.decrementDepth() return false } - return true + return iter.decrementDepth() } - return true + return iter.decrementDepth() } if c == 'n' { iter.skipThreeBytes('u', 'l', 'l') diff --git a/iter_object.go b/iter_object.go index 1c57576..b651371 100644 --- a/iter_object.go +++ b/iter_object.go @@ -112,6 +112,9 @@ func (iter *Iterator) ReadObjectCB(callback func(*Iterator, string) bool) bool { c := iter.nextToken() var field string if c == '{' { + if !iter.incrementDepth() { + return false + } c = iter.nextToken() if c == '"' { iter.unreadByte() @@ -121,6 +124,7 @@ func (iter *Iterator) ReadObjectCB(callback func(*Iterator, string) bool) bool { iter.ReportError("ReadObject", "expect : after object field, but found "+string([]byte{c})) } if !callback(iter, field) { + iter.decrementDepth() return false } c = iter.nextToken() @@ -131,20 +135,23 @@ func (iter *Iterator) ReadObjectCB(callback func(*Iterator, string) bool) bool { iter.ReportError("ReadObject", "expect : after object field, but found "+string([]byte{c})) } if !callback(iter, field) { + iter.decrementDepth() return false } c = iter.nextToken() } if c != '}' { iter.ReportError("ReadObjectCB", `object not ended with }`) + iter.decrementDepth() return false } - return true + return iter.decrementDepth() } if c == '}' { - return true + return iter.decrementDepth() } iter.ReportError("ReadObjectCB", `expect " after }, but found `+string([]byte{c})) + iter.decrementDepth() return false } if c == 'n' { @@ -159,15 +166,20 @@ func (iter *Iterator) ReadObjectCB(callback func(*Iterator, string) bool) bool { func (iter *Iterator) ReadMapCB(callback func(*Iterator, string) bool) bool { c := iter.nextToken() if c == '{' { + if !iter.incrementDepth() { + return false + } c = iter.nextToken() if c == '"' { iter.unreadByte() field := iter.ReadString() if iter.nextToken() != ':' { iter.ReportError("ReadMapCB", "expect : after object field, but found "+string([]byte{c})) + iter.decrementDepth() return false } if !callback(iter, field) { + iter.decrementDepth() return false } c = iter.nextToken() @@ -175,23 +187,27 @@ func (iter *Iterator) ReadMapCB(callback func(*Iterator, string) bool) bool { field = iter.ReadString() if iter.nextToken() != ':' { iter.ReportError("ReadMapCB", "expect : after object field, but found "+string([]byte{c})) + iter.decrementDepth() return false } if !callback(iter, field) { + iter.decrementDepth() return false } c = iter.nextToken() } if c != '}' { iter.ReportError("ReadMapCB", `object not ended with }`) + iter.decrementDepth() return false } - return true + return iter.decrementDepth() } if c == '}' { - return true + return iter.decrementDepth() } iter.ReportError("ReadMapCB", `expect " after }, but found `+string([]byte{c})) + iter.decrementDepth() return false } if c == 'n' { diff --git a/iter_skip_sloppy.go b/iter_skip_sloppy.go index 8fcdc3b..9303de4 100644 --- a/iter_skip_sloppy.go +++ b/iter_skip_sloppy.go @@ -22,6 +22,9 @@ func (iter *Iterator) skipNumber() { func (iter *Iterator) skipArray() { level := 1 + if !iter.incrementDepth() { + return + } for { for i := iter.head; i < iter.tail; i++ { switch iter.buf[i] { @@ -31,8 +34,14 @@ func (iter *Iterator) skipArray() { i = iter.head - 1 // it will be i++ soon case '[': // If open symbol, increase level level++ + if !iter.incrementDepth() { + return + } case ']': // If close symbol, increase level level-- + if !iter.decrementDepth() { + return + } // If we have returned to the original level, we're done if level == 0 { @@ -50,6 +59,10 @@ func (iter *Iterator) skipArray() { func (iter *Iterator) skipObject() { level := 1 + if !iter.incrementDepth() { + return + } + for { for i := iter.head; i < iter.tail; i++ { switch iter.buf[i] { @@ -59,8 +72,14 @@ func (iter *Iterator) skipObject() { i = iter.head - 1 // it will be i++ soon case '{': // If open symbol, increase level level++ + if !iter.incrementDepth() { + return + } case '}': // If close symbol, increase level level-- + if !iter.decrementDepth() { + return + } // If we have returned to the original level, we're done if level == 0 { diff --git a/misc_tests/jsoniter_nested_test.go b/misc_tests/jsoniter_nested_test.go index 1e4994a..fb2e60f 100644 --- a/misc_tests/jsoniter_nested_test.go +++ b/misc_tests/jsoniter_nested_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "github.com/json-iterator/go" "reflect" + "strings" "testing" ) @@ -15,6 +16,243 @@ type Level2 struct { World string } +func Test_deep_nested(t *testing.T) { + type unstructured interface{} + + testcases := []struct { + name string + data []byte + expectError string + }{ + { + name: "array under maxDepth", + data: []byte(`{"a":` + strings.Repeat(`[`, 10000-1) + strings.Repeat(`]`, 10000-1) + `}`), + expectError: "", + }, + { + name: "array over maxDepth", + data: []byte(`{"a":` + strings.Repeat(`[`, 10000) + strings.Repeat(`]`, 10000) + `}`), + expectError: "max depth", + }, + { + name: "object under maxDepth", + data: []byte(`{"a":` + strings.Repeat(`{"a":`, 10000-1) + `0` + strings.Repeat(`}`, 10000-1) + `}`), + expectError: "", + }, + { + name: "object over maxDepth", + data: []byte(`{"a":` + strings.Repeat(`{"a":`, 10000) + `0` + strings.Repeat(`}`, 10000) + `}`), + expectError: "max depth", + }, + } + + targets := []struct { + name string + new func() interface{} + }{ + { + name: "unstructured", + new: func() interface{} { + var v interface{} + return &v + }, + }, + { + name: "typed named field", + new: func() interface{} { + v := struct { + A interface{} `json:"a"` + }{} + return &v + }, + }, + { + name: "typed missing field", + new: func() interface{} { + v := struct { + B interface{} `json:"b"` + }{} + return &v + }, + }, + { + name: "typed 1 field", + new: func() interface{} { + v := struct { + A interface{} `json:"a"` + }{} + return &v + }, + }, + { + name: "typed 2 field", + new: func() interface{} { + v := struct { + A interface{} `json:"a"` + B interface{} `json:"b"` + }{} + return &v + }, + }, + { + name: "typed 3 field", + new: func() interface{} { + v := struct { + A interface{} `json:"a"` + B interface{} `json:"b"` + C interface{} `json:"c"` + }{} + return &v + }, + }, + { + name: "typed 4 field", + new: func() interface{} { + v := struct { + A interface{} `json:"a"` + B interface{} `json:"b"` + C interface{} `json:"c"` + D interface{} `json:"d"` + }{} + return &v + }, + }, + { + name: "typed 5 field", + new: func() interface{} { + v := struct { + A interface{} `json:"a"` + B interface{} `json:"b"` + C interface{} `json:"c"` + D interface{} `json:"d"` + E interface{} `json:"e"` + }{} + return &v + }, + }, + { + name: "typed 6 field", + new: func() interface{} { + v := struct { + A interface{} `json:"a"` + B interface{} `json:"b"` + C interface{} `json:"c"` + D interface{} `json:"d"` + E interface{} `json:"e"` + F interface{} `json:"f"` + }{} + return &v + }, + }, + { + name: "typed 7 field", + new: func() interface{} { + v := struct { + A interface{} `json:"a"` + B interface{} `json:"b"` + C interface{} `json:"c"` + D interface{} `json:"d"` + E interface{} `json:"e"` + F interface{} `json:"f"` + G interface{} `json:"g"` + }{} + return &v + }, + }, + { + name: "typed 8 field", + new: func() interface{} { + v := struct { + A interface{} `json:"a"` + B interface{} `json:"b"` + C interface{} `json:"c"` + D interface{} `json:"d"` + E interface{} `json:"e"` + F interface{} `json:"f"` + G interface{} `json:"g"` + H interface{} `json:"h"` + }{} + return &v + }, + }, + { + name: "typed 9 field", + new: func() interface{} { + v := struct { + A interface{} `json:"a"` + B interface{} `json:"b"` + C interface{} `json:"c"` + D interface{} `json:"d"` + E interface{} `json:"e"` + F interface{} `json:"f"` + G interface{} `json:"g"` + H interface{} `json:"h"` + I interface{} `json:"i"` + }{} + return &v + }, + }, + { + name: "typed 10 field", + new: func() interface{} { + v := struct { + A interface{} `json:"a"` + B interface{} `json:"b"` + C interface{} `json:"c"` + D interface{} `json:"d"` + E interface{} `json:"e"` + F interface{} `json:"f"` + G interface{} `json:"g"` + H interface{} `json:"h"` + I interface{} `json:"i"` + J interface{} `json:"j"` + }{} + return &v + }, + }, + { + name: "typed 11 field", + new: func() interface{} { + v := struct { + A interface{} `json:"a"` + B interface{} `json:"b"` + C interface{} `json:"c"` + D interface{} `json:"d"` + E interface{} `json:"e"` + F interface{} `json:"f"` + G interface{} `json:"g"` + H interface{} `json:"h"` + I interface{} `json:"i"` + J interface{} `json:"j"` + K interface{} `json:"k"` + }{} + return &v + }, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + for _, target := range targets { + t.Run(target.name, func(t *testing.T) { + err := jsoniter.Unmarshal(tc.data, target.new()) + if len(tc.expectError) == 0 { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + } else { + if err == nil { + t.Errorf("expected error, got none") + } else if !strings.Contains(err.Error(), tc.expectError) { + t.Errorf("expected error containing '%s', got: %v", tc.expectError, err) + } + } + }) + } + }) + } +} + func Test_nested(t *testing.T) { iter := jsoniter.ParseString(jsoniter.ConfigDefault, `{"hello": [{"world": "value1"}, {"world": "value2"}]}`) l1 := Level1{} diff --git a/reflect.go b/reflect.go index 4459e20..74974ba 100644 --- a/reflect.go +++ b/reflect.go @@ -60,6 +60,7 @@ func (b *ctx) append(prefix string) *ctx { // ReadVal copy the underlying JSON into go interface, same as json.Unmarshal func (iter *Iterator) ReadVal(obj interface{}) { + depth := iter.depth cacheKey := reflect2.RTypeOf(obj) decoder := iter.cfg.getDecoderFromCache(cacheKey) if decoder == nil { @@ -76,6 +77,10 @@ func (iter *Iterator) ReadVal(obj interface{}) { return } decoder.Decode(ptr, iter) + if iter.depth != depth { + iter.ReportError("ReadVal", "unexpected mismatched nesting") + return + } } // WriteVal copy the go interface into underlying JSON, same as json.Marshal diff --git a/reflect_struct_decoder.go b/reflect_struct_decoder.go index 932641a..5ad5cc5 100644 --- a/reflect_struct_decoder.go +++ b/reflect_struct_decoder.go @@ -500,6 +500,9 @@ func (decoder *generalStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) if !iter.readObjectStart() { return } + if !iter.incrementDepth() { + return + } var c byte for c = ','; c == ','; c = iter.nextToken() { decoder.decodeOneField(ptr, iter) @@ -510,6 +513,7 @@ func (decoder *generalStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) if c != '}' { iter.ReportError("struct Decode", `expect }, but found `+string([]byte{c})) } + iter.decrementDepth() } func (decoder *generalStructDecoder) decodeOneField(ptr unsafe.Pointer, iter *Iterator) { @@ -571,6 +575,9 @@ func (decoder *oneFieldStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) if !iter.readObjectStart() { return } + if !iter.incrementDepth() { + return + } for { if iter.readFieldHash() == decoder.fieldHash { decoder.fieldDecoder.Decode(ptr, iter) @@ -584,6 +591,7 @@ func (decoder *oneFieldStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) if iter.Error != nil && iter.Error != io.EOF { iter.Error = fmt.Errorf("%v.%s", decoder.typ, iter.Error.Error()) } + iter.decrementDepth() } type twoFieldsStructDecoder struct { @@ -598,6 +606,9 @@ func (decoder *twoFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterator if !iter.readObjectStart() { return } + if !iter.incrementDepth() { + return + } for { switch iter.readFieldHash() { case decoder.fieldHash1: @@ -614,6 +625,7 @@ func (decoder *twoFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterator if iter.Error != nil && iter.Error != io.EOF { iter.Error = fmt.Errorf("%v.%s", decoder.typ, iter.Error.Error()) } + iter.decrementDepth() } type threeFieldsStructDecoder struct { @@ -630,6 +642,9 @@ func (decoder *threeFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterat if !iter.readObjectStart() { return } + if !iter.incrementDepth() { + return + } for { switch iter.readFieldHash() { case decoder.fieldHash1: @@ -648,6 +663,7 @@ func (decoder *threeFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterat if iter.Error != nil && iter.Error != io.EOF { iter.Error = fmt.Errorf("%v.%s", decoder.typ, iter.Error.Error()) } + iter.decrementDepth() } type fourFieldsStructDecoder struct { @@ -666,6 +682,9 @@ func (decoder *fourFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterato if !iter.readObjectStart() { return } + if !iter.incrementDepth() { + return + } for { switch iter.readFieldHash() { case decoder.fieldHash1: @@ -686,6 +705,7 @@ func (decoder *fourFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterato if iter.Error != nil && iter.Error != io.EOF { iter.Error = fmt.Errorf("%v.%s", decoder.typ, iter.Error.Error()) } + iter.decrementDepth() } type fiveFieldsStructDecoder struct { @@ -706,6 +726,9 @@ func (decoder *fiveFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterato if !iter.readObjectStart() { return } + if !iter.incrementDepth() { + return + } for { switch iter.readFieldHash() { case decoder.fieldHash1: @@ -728,6 +751,7 @@ func (decoder *fiveFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterato if iter.Error != nil && iter.Error != io.EOF { iter.Error = fmt.Errorf("%v.%s", decoder.typ, iter.Error.Error()) } + iter.decrementDepth() } type sixFieldsStructDecoder struct { @@ -750,6 +774,9 @@ func (decoder *sixFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterator if !iter.readObjectStart() { return } + if !iter.incrementDepth() { + return + } for { switch iter.readFieldHash() { case decoder.fieldHash1: @@ -774,6 +801,7 @@ func (decoder *sixFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterator if iter.Error != nil && iter.Error != io.EOF { iter.Error = fmt.Errorf("%v.%s", decoder.typ, iter.Error.Error()) } + iter.decrementDepth() } type sevenFieldsStructDecoder struct { @@ -798,6 +826,9 @@ func (decoder *sevenFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterat if !iter.readObjectStart() { return } + if !iter.incrementDepth() { + return + } for { switch iter.readFieldHash() { case decoder.fieldHash1: @@ -824,6 +855,7 @@ func (decoder *sevenFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterat if iter.Error != nil && iter.Error != io.EOF { iter.Error = fmt.Errorf("%v.%s", decoder.typ, iter.Error.Error()) } + iter.decrementDepth() } type eightFieldsStructDecoder struct { @@ -850,6 +882,9 @@ func (decoder *eightFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterat if !iter.readObjectStart() { return } + if !iter.incrementDepth() { + return + } for { switch iter.readFieldHash() { case decoder.fieldHash1: @@ -878,6 +913,7 @@ func (decoder *eightFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterat if iter.Error != nil && iter.Error != io.EOF { iter.Error = fmt.Errorf("%v.%s", decoder.typ, iter.Error.Error()) } + iter.decrementDepth() } type nineFieldsStructDecoder struct { @@ -906,6 +942,9 @@ func (decoder *nineFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterato if !iter.readObjectStart() { return } + if !iter.incrementDepth() { + return + } for { switch iter.readFieldHash() { case decoder.fieldHash1: @@ -936,6 +975,7 @@ func (decoder *nineFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterato if iter.Error != nil && iter.Error != io.EOF { iter.Error = fmt.Errorf("%v.%s", decoder.typ, iter.Error.Error()) } + iter.decrementDepth() } type tenFieldsStructDecoder struct { @@ -966,6 +1006,9 @@ func (decoder *tenFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterator if !iter.readObjectStart() { return } + if !iter.incrementDepth() { + return + } for { switch iter.readFieldHash() { case decoder.fieldHash1: @@ -998,6 +1041,7 @@ func (decoder *tenFieldsStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterator if iter.Error != nil && iter.Error != io.EOF { iter.Error = fmt.Errorf("%v.%s", decoder.typ, iter.Error.Error()) } + iter.decrementDepth() } type structFieldDecoder struct {