diff --git a/feature_stream.go b/feature_stream.go index c770783..765ef70 100644 --- a/feature_stream.go +++ b/feature_stream.go @@ -80,6 +80,19 @@ func (b *Stream) writeByte(c byte) error { return nil } +func (b *Stream) writeTwoBytes(c1 byte, c2 byte) error { + if b.Error != nil { + return b.Error + } + if b.Available() <= 1 && b.Flush() != nil { + return b.Error + } + b.buf[b.n] = c1 + b.buf[b.n + 1] = c2 + b.n += 2 + return nil +} + // Flush writes any buffered data to the underlying io.Writer. func (b *Stream) Flush() error { if b.Error != nil { @@ -118,20 +131,67 @@ func (b *Stream) WriteRaw(s string) { b.n += n } -func (b *Stream) WriteString(s string) { - b.writeByte('"') - for len(s) > b.Available() && b.Error == nil { - n := copy(b.buf[b.n:], s) - b.n += n - s = s[n:] - b.Flush() +func (stream *Stream) WriteString(s string) { + valLen := len(s) + toWriteLen := valLen + bufLengthMinusTwo := len(stream.buf) - 2 // make room for the quotes + if stream.n + toWriteLen > bufLengthMinusTwo { + toWriteLen = bufLengthMinusTwo - stream.n } - if b.Error != nil { + if toWriteLen < 0 { + stream.Flush() + if stream.n + toWriteLen > bufLengthMinusTwo { + toWriteLen = bufLengthMinusTwo - stream.n + } + } + n := stream.n + stream.buf[n] = '"' + n++ + // write string, the fast path, without utf8 and escape support + i := 0 + for ; i < toWriteLen; i++ { + c := s[i] + if c > 31 && c != '"' && c != '\\' { + stream.buf[n] = c + n++ + } else { + break; + } + } + if i == valLen { + stream.buf[n] = '"' + n++ + stream.n = n return } - n := copy(b.buf[b.n:], s) - b.n += n - b.writeByte('"') + stream.n = n + // for the remaining parts, we process them char by char + stream.writeStringSlowPath(s, i, valLen); + stream.writeByte('"') +} + +func (stream *Stream) writeStringSlowPath(s string, i int, valLen int) { + for ; i < valLen; i++ { + c := s[i] + switch (c) { + case '"': + stream.writeTwoBytes('\\', '"') + case '\\': + stream.writeTwoBytes('\\', '\\') + case '\b': + stream.writeTwoBytes('\\', 'b') + case '\f': + stream.writeTwoBytes('\\', 'f') + case '\n': + stream.writeTwoBytes('\\', 'n') + case '\r': + stream.writeTwoBytes('\\', 'r') + case '\t': + stream.writeTwoBytes('\\', 't') + default: + stream.writeByte(c); + } + } } func (stream *Stream) WriteNil() { diff --git a/jsoniter_string_test.go b/jsoniter_string_test.go index 45eff8c..33fbd85 100644 --- a/jsoniter_string_test.go +++ b/jsoniter_string_test.go @@ -67,12 +67,12 @@ func Test_read_string_via_read(t *testing.T) { func Test_write_string(t *testing.T) { should := require.New(t) - buf := &bytes.Buffer{} - stream := NewStream(buf, 4096) - stream.WriteString("hello") - stream.Flush() - should.Nil(stream.Error) - should.Equal(`"hello"`, buf.String()) + str, err := MarshalToString("hello") + should.Equal(`"hello"`, str) + should.Nil(err) + str, err = MarshalToString(`hel"lo`) + should.Equal(`"hel\"lo"`, str) + should.Nil(err) } func Test_write_val_string(t *testing.T) {