diff --git a/feature_iter_float.go b/feature_iter_float.go index d4060b9..bf9c3a4 100644 --- a/feature_iter_float.go +++ b/feature_iter_float.go @@ -10,7 +10,6 @@ var floatDigits []int8 const invalidCharForNumber = int8(-1) const endOfNumber = int8(-2) const dotInNumber = int8(-3) -const uint64SafeToMultiple10 = uint64(0xffffffffffffffff) / 10 - 10 func init() { floatDigits = make([]int8, 256) diff --git a/feature_iter_int.go b/feature_iter_int.go index 693c777..067b88b 100644 --- a/feature_iter_int.go +++ b/feature_iter_int.go @@ -1,13 +1,15 @@ package jsoniter - var intDigits []int8 -const maxUint64 = (1<<64 - 1) -const cutoffUint64 = maxUint64/10 + 1 -const maxUint32 = (1<<32 - 1) -const cutoffUint32 = maxUint32/10 + 1 -const int32SafeToMultiply10 = uint32(int32(0x7fffffff)/10 - 10) -const uint32SafeToMultiply10 = uint32(0xffffffff)/10 - 10 + +const int8SafeToMultiply10 = uint32(int8(0x7f) / 10 - 10) +const uint8SafeToMultiply10 = uint32(0xff) / 10 - 10 +const int16SafeToMultiply10 = uint32(int16(0x7fff) / 10 - 10) +const uint16SafeToMultiply10 = uint32(0xffff) / 10 - 10 +const int32SafeToMultiply10 = uint32(int32(0x7fffffff) / 10 - 10) +const uint32SafeToMultiply10 = uint32(0xffffffff) / 10 - 10 +const uint64SafeToMultiple10 = uint64(0xffffffffffffffff) / 10 - 10 +const int64SafeToMultiple10 = uint64(int64(0x7fffffffffffffff) / 10 - 10) func init() { intDigits = make([]int8, 256) @@ -19,97 +21,38 @@ func init() { } } -// ReadUint reads a json object as Uint -func (iter *Iterator) ReadUint() (ret uint) { - val := iter.ReadUint64() - converted := uint(val) - if uint64(converted) != val { - iter.reportError("ReadUint", "int overflow") - return - } - return converted +func (iter *Iterator) ReadUint() uint { + return uint(iter.ReadUint64()) } -// ReadUint8 reads a json object as Uint8 -func (iter *Iterator) ReadUint8() (ret uint8) { - val := iter.ReadUint64() - converted := uint8(val) - if uint64(converted) != val { - iter.reportError("ReadUint8", "int overflow") - return - } - return converted +func (iter *Iterator) ReadInt() int { + return int(iter.ReadInt64()) } -// ReadUint16 reads a json object as Uint16 -func (iter *Iterator) ReadUint16() (ret uint16) { - val := iter.ReadUint64() - converted := uint16(val) - if uint64(converted) != val { - iter.reportError("ReadUint16", "int overflow") - return - } - return converted -} - -// ReadUint64 reads a json object as Uint64 -func (iter *Iterator) ReadUint64() (ret uint64) { +func (iter *Iterator) ReadInt8() int8 { c := iter.nextToken() - v := hexDigits[c] - if v == 0 { - return 0 // single zero + if c == '-' { + return -int8(iter.readUint32(int8SafeToMultiply10, iter.readByte())) + } else { + return int8(iter.readUint32(int8SafeToMultiply10, c)) } - if v == 255 { - iter.reportError("ReadUint64", "unexpected character") - return - } - for { - if ret >= cutoffUint64 { - iter.reportError("ReadUint64", "overflow") - return - } - ret = ret*10 + uint64(v) - c = iter.readByte() - v = hexDigits[c] - if v == 255 { - iter.unreadByte() - break - } - } - return ret } -// ReadInt reads a json object as Int -func (iter *Iterator) ReadInt() (ret int) { - val := iter.ReadInt64() - converted := int(val) - if int64(converted) != val { - iter.reportError("ReadInt", "int overflow") - return - } - return converted +func (iter *Iterator) ReadUint8() (ret uint8) { + return uint8(iter.readUint32(uint8SafeToMultiply10, iter.nextToken())) } -// ReadInt8 reads a json object as Int8 -func (iter *Iterator) ReadInt8() (ret int8) { - val := iter.ReadInt64() - converted := int8(val) - if int64(converted) != val { - iter.reportError("ReadInt8", "int overflow") - return +func (iter *Iterator) ReadInt16() int16 { + c := iter.nextToken() + if c == '-' { + return -int16(iter.readUint32(int16SafeToMultiply10, iter.readByte())) + } else { + return int16(iter.readUint32(int16SafeToMultiply10, c)) } - return converted } -// ReadInt16 reads a json object as Int16 -func (iter *Iterator) ReadInt16() (ret int16) { - val := iter.ReadInt64() - converted := int16(val) - if int64(converted) != val { - iter.reportError("ReadInt16", "int overflow") - return - } - return converted +func (iter *Iterator) ReadUint16() uint16 { + return uint16(iter.readUint32(uint16SafeToMultiply10, iter.nextToken())) } func (iter *Iterator) ReadInt32() int32 { @@ -137,15 +80,21 @@ func (iter *Iterator) readUint32(safeToMultiply10 uint32, c byte) (ret uint32) { value := uint32(ind) for { for i := iter.head; i < iter.tail; i++ { - if value > safeToMultiply10 { - iter.reportError("readUint32", "overflow") - return - } ind = intDigits[iter.buf[i]] if ind == invalidCharForNumber { return value } - value = (value << 3) + (value << 1) + uint32(ind) + if value > safeToMultiply10 { + value2 := (value << 3) + (value << 1) + uint32(ind) + if value2 < safeToMultiply10 * 10 { + iter.reportError("readUint32", "overflow") + return + } else { + value = value2 + continue + } + } + value = (value << 3) + (value << 1) + uint32(ind) } if (!iter.loadMore()) { return value @@ -153,19 +102,49 @@ func (iter *Iterator) readUint32(safeToMultiply10 uint32, c byte) (ret uint32) { } } -// ReadInt64 reads a json object as Int64 -func (iter *Iterator) ReadInt64() (ret int64) { +func (iter *Iterator) ReadInt64() int64 { c := iter.nextToken() - if iter.Error != nil { + if c == '-' { + return -int64(iter.readUint64(int64SafeToMultiple10, iter.readByte())) + } else { + return int64(iter.readUint64(int64SafeToMultiple10, c)) + } +} + +func (iter *Iterator) ReadUint64() uint64 { + return iter.readUint64(uint64SafeToMultiple10, iter.nextToken()) +} + +func (iter *Iterator) readUint64(safeToMultiply10 uint64, c byte) (ret uint64) { + ind := intDigits[c] + if ind == 0 { + return 0 // single zero + } + if ind == invalidCharForNumber { + iter.reportError("readUint64", "unexpected character: " + string([]byte{byte(ind)})) return } - - /* optional leading minus */ - if c == '-' { - n := iter.ReadUint64() - return -int64(n) + value := uint64(ind) + for { + for i := iter.head; i < iter.tail; i++ { + ind = intDigits[iter.buf[i]] + if ind == invalidCharForNumber { + return value + } + if value > safeToMultiply10 { + value2 := (value << 3) + (value << 1) + uint64(ind) + if value2 < safeToMultiply10 * 10 { + iter.reportError("readUint64", "overflow") + return + } else { + value = value2 + continue + } + } + value = (value << 3) + (value << 1) + uint64(ind) + } + if (!iter.loadMore()) { + return value + } } - iter.unreadByte() - n := iter.ReadUint64() - return int64(n) } diff --git a/iterator.go b/iterator.go index 528626b..7282b58 100644 --- a/iterator.go +++ b/iterator.go @@ -331,16 +331,8 @@ func (iter *Iterator) readU4() (ret rune) { return } if c >= '0' && c <= '9' { - if ret >= cutoffUint32 { - iter.reportError("readU4", "overflow") - return - } ret = ret*16 + rune(c-'0') } else if c >= 'a' && c <= 'f' { - if ret >= cutoffUint32 { - iter.reportError("readU4", "overflow") - return - } ret = ret*16 + rune(c-'a'+10) } else { iter.reportError("readU4", "expects 0~9 or a~f") diff --git a/jsoniter_int_test.go b/jsoniter_int_test.go index 71982f1..2dae7b2 100644 --- a/jsoniter_int_test.go +++ b/jsoniter_int_test.go @@ -10,70 +10,16 @@ import ( "io/ioutil" ) -func Test_decode_decode_uint64_0(t *testing.T) { - iter := Parse(bytes.NewBufferString("0"), 4096) - val := iter.ReadUint64() - if iter.Error != nil { - t.Fatal(iter.Error) - } - if val != 0 { - t.Fatal(val) - } -} -func Test_decode_uint64_1(t *testing.T) { - iter := Parse(bytes.NewBufferString("1"), 4096) - val := iter.ReadUint64() - if val != 1 { - t.Fatal(val) - } -} - -func Test_decode_uint64_100(t *testing.T) { - iter := Parse(bytes.NewBufferString("100"), 4096) - val := iter.ReadUint64() - if val != 100 { - t.Fatal(val) - } -} - -func Test_decode_uint64_100_comma(t *testing.T) { - iter := Parse(bytes.NewBufferString("100,"), 4096) - val := iter.ReadUint64() - if iter.Error != nil { - t.Fatal(iter.Error) - } - if val != 100 { - t.Fatal(val) - } -} - -func Test_decode_uint64_invalid(t *testing.T) { - iter := Parse(bytes.NewBufferString(","), 4096) +func Test_read_uint64_invalid(t *testing.T) { + should := require.New(t) + iter := ParseString(",") iter.ReadUint64() - if iter.Error == nil { - t.FailNow() - } -} - -func Test_decode_int64_100(t *testing.T) { - iter := Parse(bytes.NewBufferString("100"), 4096) - val := iter.ReadInt64() - if val != 100 { - t.Fatal(val) - } -} - -func Test_decode_int64_minus_100(t *testing.T) { - iter := Parse(bytes.NewBufferString("-100"), 4096) - val := iter.ReadInt64() - if val != -100 { - t.Fatal(val) - } + should.NotNil(iter.Error) } func Test_read_int32(t *testing.T) { - inputs := []string{`1`, `12`, `123`} + inputs := []string{`1`, `12`, `123`, `1234`, `12345`, `123456`, `2147483647`} for _, input := range inputs { t.Run(fmt.Sprintf("%v", input), func(t *testing.T) { should := require.New(t) @@ -82,6 +28,13 @@ func Test_read_int32(t *testing.T) { should.Nil(err) should.Equal(int32(expected), iter.ReadInt32()) }) + t.Run(fmt.Sprintf("%v", input), func(t *testing.T) { + should := require.New(t) + iter := Parse(bytes.NewBufferString(input), 2) + expected, err := strconv.ParseInt(input, 10, 32) + should.Nil(err) + should.Equal(int32(expected), iter.ReadInt32()) + }) } } @@ -93,6 +46,34 @@ func Test_read_int32_overflow(t *testing.T) { should.NotNil(iter.Error) } +func Test_read_int64(t *testing.T) { + inputs := []string{`1`, `12`, `123`, `1234`, `12345`, `123456`, `9223372036854775807`} + for _, input := range inputs { + t.Run(fmt.Sprintf("%v", input), func(t *testing.T) { + should := require.New(t) + iter := ParseString(input) + expected, err := strconv.ParseInt(input, 10, 64) + should.Nil(err) + should.Equal(expected, iter.ReadInt64()) + }) + t.Run(fmt.Sprintf("%v", input), func(t *testing.T) { + should := require.New(t) + iter := Parse(bytes.NewBufferString(input), 2) + expected, err := strconv.ParseInt(input, 10, 64) + should.Nil(err) + should.Equal(expected, iter.ReadInt64()) + }) + } +} + +func Test_read_int64_overflow(t *testing.T) { + should := require.New(t) + input := "123456789123456789" + iter := ParseString(input) + iter.ReadInt64() + should.NotNil(iter.Error) +} + func Test_write_uint8(t *testing.T) { vals := []uint8{0, 1, 11, 111, 255} for _, val := range vals {