1
0
mirror of https://github.com/zhashkevych/go-sqlxmock.git synced 2024-11-16 17:41:57 +02:00

Invalidate memory scanned into sql.RawBytes

The intention of sql.RawBytes is for it to hold memory owned by the
 database. When used, it's content is only valid until the `Next`,
 `Scan` or `Close` is called on the `Rows`

To ensure that we meet this behaviour, when `[]byte` is used in a
 column, it's value is copied to a buffer that we keep track of for
 later invalidation. By doing this, incorrect use of `sql.RawBytes`
 values is exposed in tests that use go-sqlmock. Without this, when a
 real database is used and it's driver does share memory, then those
 issues would not be exposed until runtime (and in non-obvious ways)
This commit is contained in:
David Ackroyd 2019-06-21 13:28:10 +10:00
parent 6c8a572d09
commit d5879ee4b7
4 changed files with 486 additions and 0 deletions

35
rows.go
View File

@ -1,6 +1,7 @@
package sqlmock
import (
"bytes"
"database/sql/driver"
"encoding/csv"
"fmt"
@ -8,6 +9,8 @@ import (
"strings"
)
const invalidate = "☠☠☠ MEMORY OVERWRITTEN ☠☠☠ "
// CSVColumnParser is a function which converts trimmed csv
// column string to a []byte representation. Currently
// transforms NULL to nil
@ -23,6 +26,7 @@ type rowSets struct {
sets []*Rows
pos int
ex *ExpectedQuery
raw [][]byte
}
func (rs *rowSets) Columns() []string {
@ -30,6 +34,7 @@ func (rs *rowSets) Columns() []string {
}
func (rs *rowSets) Close() error {
rs.invalidateRaw()
rs.ex.rowsWereClosed = true
return rs.sets[rs.pos].closeErr
}
@ -38,11 +43,17 @@ func (rs *rowSets) Close() error {
func (rs *rowSets) Next(dest []driver.Value) error {
r := rs.sets[rs.pos]
r.pos++
rs.invalidateRaw()
if r.pos > len(r.rows) {
return io.EOF // per interface spec
}
for i, col := range r.rows[r.pos-1] {
if b, ok := rawBytes(col); ok {
rs.raw = append(rs.raw, b)
dest[i] = b
continue
}
dest[i] = col
}
@ -80,6 +91,30 @@ func (rs *rowSets) empty() bool {
return true
}
func rawBytes(col driver.Value) (_ []byte, ok bool) {
val, ok := col.([]byte)
if !ok || len(val) == 0 {
return nil, false
}
// Copy the bytes from the mocked row into a shared raw buffer, which we'll replace the content of later
// This allows scanning into sql.RawBytes to correctly become invalid on subsequent calls to Next(), Scan() or Close()
b := make([]byte, len(val))
copy(b, val)
return b, true
}
// Bytes that could have been scanned as sql.RawBytes are only valid until the next call to Next, Scan or Close.
// If those occur, we must replace their content to simulate the shared memory to expose misuse of sql.RawBytes
func (rs *rowSets) invalidateRaw() {
// Replace the content of slices previously returned
b := []byte(invalidate)
for _, r := range rs.raw {
copy(r, bytes.Repeat(b, len(r)/len(b)+1))
}
// Start with new slices for the next scan
rs.raw = nil
}
// Rows is a mocked collection of rows to
// return for Query result
type Rows struct {

31
rows_go13_test.go Normal file
View File

@ -0,0 +1,31 @@
// +build go1.3
package sqlmock
import (
"database/sql"
"testing"
)
func TestQueryRowBytesNotInvalidatedByNext_stringIntoRawBytes(t *testing.T) {
t.Parallel()
rows := NewRows([]string{"raw"}).
AddRow(`one binary value with some text!`).
AddRow(`two binary value with even more text than the first one`)
scan := func(rs *sql.Rows) ([]byte, error) {
var raw sql.RawBytes
return raw, rs.Scan(&raw)
}
want := [][]byte{[]byte(`one binary value with some text!`), []byte(`two binary value with even more text than the first one`)}
queryRowBytesNotInvalidatedByNext(t, rows, scan, want)
}
func TestQueryRowBytesNotInvalidatedByClose_stringIntoRawBytes(t *testing.T) {
t.Parallel()
rows := NewRows([]string{"raw"}).AddRow(`one binary value with some text!`)
scan := func(rs *sql.Rows) ([]byte, error) {
var raw sql.RawBytes
return raw, rs.Scan(&raw)
}
queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`one binary value with some text!`))
}

View File

@ -3,6 +3,8 @@
package sqlmock
import (
"database/sql"
"encoding/json"
"fmt"
"testing"
)
@ -90,3 +92,114 @@ func TestQueryMultiRows(t *testing.T) {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}
func TestQueryRowBytesInvalidatedByNext_jsonRawMessageIntoRawBytes(t *testing.T) {
t.Parallel()
replace := []byte(invalid)
rows := NewRows([]string{"raw"}).
AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)).
AddRow(json.RawMessage(`{"that": "foo", "this": "bar"}`))
scan := func(rs *sql.Rows) ([]byte, error) {
var raw sql.RawBytes
return raw, rs.Scan(&raw)
}
want := []struct {
Initial []byte
Replaced []byte
}{
{Initial: []byte(`{"thing": "one", "thing2": "two"}`), Replaced: replace[:len(replace)-6]},
{Initial: []byte(`{"that": "foo", "this": "bar"}`), Replaced: replace[:len(replace)-9]},
}
queryRowBytesInvalidatedByNext(t, rows, scan, want)
}
func TestQueryRowBytesNotInvalidatedByNext_jsonRawMessageIntoBytes(t *testing.T) {
t.Parallel()
rows := NewRows([]string{"raw"}).
AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)).
AddRow(json.RawMessage(`{"that": "foo", "this": "bar"}`))
scan := func(rs *sql.Rows) ([]byte, error) {
var b []byte
return b, rs.Scan(&b)
}
want := [][]byte{[]byte(`{"thing": "one", "thing2": "two"}`), []byte(`{"that": "foo", "this": "bar"}`)}
queryRowBytesNotInvalidatedByNext(t, rows, scan, want)
}
func TestQueryRowBytesNotInvalidatedByNext_bytesIntoCustomBytes(t *testing.T) {
t.Parallel()
rows := NewRows([]string{"raw"}).
AddRow([]byte(`one binary value with some text!`)).
AddRow([]byte(`two binary value with even more text than the first one`))
scan := func(rs *sql.Rows) ([]byte, error) {
type customBytes []byte
var b customBytes
return b, rs.Scan(&b)
}
want := [][]byte{[]byte(`one binary value with some text!`), []byte(`two binary value with even more text than the first one`)}
queryRowBytesNotInvalidatedByNext(t, rows, scan, want)
}
func TestQueryRowBytesNotInvalidatedByNext_jsonRawMessageIntoCustomBytes(t *testing.T) {
t.Parallel()
rows := NewRows([]string{"raw"}).
AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)).
AddRow(json.RawMessage(`{"that": "foo", "this": "bar"}`))
scan := func(rs *sql.Rows) ([]byte, error) {
type customBytes []byte
var b customBytes
return b, rs.Scan(&b)
}
want := [][]byte{[]byte(`{"thing": "one", "thing2": "two"}`), []byte(`{"that": "foo", "this": "bar"}`)}
queryRowBytesNotInvalidatedByNext(t, rows, scan, want)
}
func TestQueryRowBytesNotInvalidatedByClose_bytesIntoCustomBytes(t *testing.T) {
t.Parallel()
rows := NewRows([]string{"raw"}).AddRow([]byte(`one binary value with some text!`))
scan := func(rs *sql.Rows) ([]byte, error) {
type customBytes []byte
var b customBytes
return b, rs.Scan(&b)
}
queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`one binary value with some text!`))
}
func TestQueryRowBytesInvalidatedByClose_jsonRawMessageIntoRawBytes(t *testing.T) {
t.Parallel()
replace := []byte(invalid)
rows := NewRows([]string{"raw"}).AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`))
scan := func(rs *sql.Rows) ([]byte, error) {
var raw sql.RawBytes
return raw, rs.Scan(&raw)
}
want := struct {
Initial []byte
Replaced []byte
}{
Initial: []byte(`{"thing": "one", "thing2": "two"}`),
Replaced: replace[:len(replace)-6],
}
queryRowBytesInvalidatedByClose(t, rows, scan, want)
}
func TestQueryRowBytesNotInvalidatedByClose_jsonRawMessageIntoBytes(t *testing.T) {
t.Parallel()
rows := NewRows([]string{"raw"}).AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`))
scan := func(rs *sql.Rows) ([]byte, error) {
var b []byte
return b, rs.Scan(&b)
}
queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`{"thing": "one", "thing2": "two"}`))
}
func TestQueryRowBytesNotInvalidatedByClose_jsonRawMessageIntoCustomBytes(t *testing.T) {
t.Parallel()
rows := NewRows([]string{"raw"}).AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`))
scan := func(rs *sql.Rows) ([]byte, error) {
type customBytes []byte
var b customBytes
return b, rs.Scan(&b)
}
queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`{"thing": "one", "thing2": "two"}`))
}

View File

@ -1,11 +1,14 @@
package sqlmock
import (
"bytes"
"database/sql"
"fmt"
"testing"
)
const invalid = `☠☠☠ MEMORY OVERWRITTEN ☠☠☠ `
func ExampleRows() {
db, mock, err := New()
if err != nil {
@ -88,6 +91,52 @@ func ExampleRows_closeError() {
// Output: got error: close error
}
func ExampleRows_rawBytes() {
db, mock, err := New()
if err != nil {
fmt.Println("failed to open sqlmock database:", err)
}
defer db.Close()
rows := NewRows([]string{"id", "binary"}).
AddRow(1, []byte(`one binary value with some text!`)).
AddRow(2, []byte(`two binary value with even more text than the first one`))
mock.ExpectQuery("SELECT").WillReturnRows(rows)
rs, _ := db.Query("SELECT")
defer rs.Close()
type scanned struct {
id int
raw sql.RawBytes
}
fmt.Println("initial read...")
var ss []scanned
for rs.Next() {
var s scanned
rs.Scan(&s.id, &s.raw)
ss = append(ss, s)
fmt.Println("scanned id:", s.id, "and raw:", string(s.raw))
}
if rs.Err() != nil {
fmt.Println("got rows error:", rs.Err())
}
fmt.Println("after reading all...")
for _, s := range ss {
fmt.Println("scanned id:", s.id, "and raw:", string(s.raw))
}
// Output:
// initial read...
// scanned id: 1 and raw: one binary value with some text!
// scanned id: 2 and raw: two binary value with even more text than the first one
// after reading all...
// scanned id: 1 and raw: ☠☠☠ MEMORY OVERWRITTEN ☠
// scanned id: 2 and raw: ☠☠☠ MEMORY OVERWRITTEN ☠☠☠ ☠☠☠ MEMORY
}
func ExampleRows_expectToBeClosed() {
db, mock, err := New()
if err != nil {
@ -260,6 +309,90 @@ func TestQuerySingleRow(t *testing.T) {
}
}
func TestQueryRowBytesInvalidatedByNext_bytesIntoRawBytes(t *testing.T) {
t.Parallel()
replace := []byte(invalid)
rows := NewRows([]string{"raw"}).
AddRow([]byte(`one binary value with some text!`)).
AddRow([]byte(`two binary value with even more text than the first one`))
scan := func(rs *sql.Rows) ([]byte, error) {
var raw sql.RawBytes
return raw, rs.Scan(&raw)
}
want := []struct {
Initial []byte
Replaced []byte
}{
{Initial: []byte(`one binary value with some text!`), Replaced: replace[:len(replace)-7]},
{Initial: []byte(`two binary value with even more text than the first one`), Replaced: bytes.Join([][]byte{replace, replace[:len(replace)-23]}, nil)},
}
queryRowBytesInvalidatedByNext(t, rows, scan, want)
}
func TestQueryRowBytesNotInvalidatedByNext_bytesIntoBytes(t *testing.T) {
t.Parallel()
rows := NewRows([]string{"raw"}).
AddRow([]byte(`one binary value with some text!`)).
AddRow([]byte(`two binary value with even more text than the first one`))
scan := func(rs *sql.Rows) ([]byte, error) {
var b []byte
return b, rs.Scan(&b)
}
want := [][]byte{[]byte(`one binary value with some text!`), []byte(`two binary value with even more text than the first one`)}
queryRowBytesNotInvalidatedByNext(t, rows, scan, want)
}
func TestQueryRowBytesNotInvalidatedByNext_stringIntoBytes(t *testing.T) {
t.Parallel()
rows := NewRows([]string{"raw"}).
AddRow(`one binary value with some text!`).
AddRow(`two binary value with even more text than the first one`)
scan := func(rs *sql.Rows) ([]byte, error) {
var b []byte
return b, rs.Scan(&b)
}
want := [][]byte{[]byte(`one binary value with some text!`), []byte(`two binary value with even more text than the first one`)}
queryRowBytesNotInvalidatedByNext(t, rows, scan, want)
}
func TestQueryRowBytesInvalidatedByClose_bytesIntoRawBytes(t *testing.T) {
t.Parallel()
replace := []byte(invalid)
rows := NewRows([]string{"raw"}).AddRow([]byte(`one binary value with some text!`))
scan := func(rs *sql.Rows) ([]byte, error) {
var raw sql.RawBytes
return raw, rs.Scan(&raw)
}
want := struct {
Initial []byte
Replaced []byte
}{
Initial: []byte(`one binary value with some text!`),
Replaced: replace[:len(replace)-7],
}
queryRowBytesInvalidatedByClose(t, rows, scan, want)
}
func TestQueryRowBytesNotInvalidatedByClose_bytesIntoBytes(t *testing.T) {
t.Parallel()
rows := NewRows([]string{"raw"}).AddRow([]byte(`one binary value with some text!`))
scan := func(rs *sql.Rows) ([]byte, error) {
var b []byte
return b, rs.Scan(&b)
}
queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`one binary value with some text!`))
}
func TestQueryRowBytesNotInvalidatedByClose_stringIntoBytes(t *testing.T) {
t.Parallel()
rows := NewRows([]string{"raw"}).AddRow(`one binary value with some text!`)
scan := func(rs *sql.Rows) ([]byte, error) {
var b []byte
return b, rs.Scan(&b)
}
queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`one binary value with some text!`))
}
func TestRowsScanError(t *testing.T) {
t.Parallel()
db, mock, err := New()
@ -363,3 +496,177 @@ func TestEmptyRowSets(t *testing.T) {
t.Fatalf("expected rowset 3, to be empty, but it was not")
}
}
func queryRowBytesInvalidatedByNext(t *testing.T, rows *Rows, scan func(*sql.Rows) ([]byte, error), want []struct {
Initial []byte
Replaced []byte
}) {
db, mock, err := New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
mock.ExpectQuery("SELECT").WillReturnRows(rows)
rs, err := db.Query("SELECT")
if err != nil {
t.Fatalf("failed to query rows: %s", err)
}
if !rs.Next() || rs.Err() != nil {
t.Fatal("unexpected error on first row retrieval")
}
var count int
for i := 0; ; i++ {
count++
b, err := scan(rs)
if err != nil {
t.Fatalf("unexpected error scanning row: %s", err)
}
if exp := want[i].Initial; !bytes.Equal(b, exp) {
t.Fatalf("expected raw value to be '%s' (len:%d), but got [%T]:%s (len:%d)", exp, len(exp), b, b, len(b))
}
next := rs.Next()
if exp := want[i].Replaced; !bytes.Equal(b, exp) {
t.Fatalf("expected raw value to be replaced with '%s' (len:%d) after calling Next(), but got [%T]:%s (len:%d)", exp, len(exp), b, b, len(b))
}
if !next {
break
}
}
if err := rs.Err(); err != nil {
t.Fatalf("row iteration failed: %s", err)
}
if exp := len(want); count != exp {
t.Fatalf("incorrect number of rows exp: %d, but got %d", exp, count)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Fatal(err)
}
}
func queryRowBytesNotInvalidatedByNext(t *testing.T, rows *Rows, scan func(*sql.Rows) ([]byte, error), want [][]byte) {
db, mock, err := New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
mock.ExpectQuery("SELECT").WillReturnRows(rows)
rs, err := db.Query("SELECT")
if err != nil {
t.Fatalf("failed to query rows: %s", err)
}
if !rs.Next() || rs.Err() != nil {
t.Fatal("unexpected error on first row retrieval")
}
var count int
for i := 0; ; i++ {
count++
b, err := scan(rs)
if err != nil {
t.Fatalf("unexpected error scanning row: %s", err)
}
if exp := want[i]; !bytes.Equal(b, exp) {
t.Fatalf("expected raw value to be '%s' (len:%d), but got [%T]:%s (len:%d)", exp, len(exp), b, b, len(b))
}
next := rs.Next()
if exp := want[i]; !bytes.Equal(b, exp) {
t.Fatalf("expected raw value to be replaced with '%s' (len:%d) after calling Next(), but got [%T]:%s (len:%d)", exp, len(exp), b, b, len(b))
}
if !next {
break
}
}
if err := rs.Err(); err != nil {
t.Fatalf("row iteration failed: %s", err)
}
if exp := len(want); count != exp {
t.Fatalf("incorrect number of rows exp: %d, but got %d", exp, count)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Fatal(err)
}
}
func queryRowBytesInvalidatedByClose(t *testing.T, rows *Rows, scan func(*sql.Rows) ([]byte, error), want struct {
Initial []byte
Replaced []byte
}) {
db, mock, err := New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
mock.ExpectQuery("SELECT").WillReturnRows(rows)
rs, err := db.Query("SELECT")
if err != nil {
t.Fatalf("failed to query rows: %s", err)
}
if !rs.Next() || rs.Err() != nil {
t.Fatal("unexpected error on first row retrieval")
}
b, err := scan(rs)
if err != nil {
t.Fatalf("unexpected error scanning row: %s", err)
}
if !bytes.Equal(b, want.Initial) {
t.Fatalf("expected raw value to be '%s' (len:%d), but got [%T]:%s (len:%d)", want.Initial, len(want.Initial), b, b, len(b))
}
if err := rs.Close(); err != nil {
t.Fatalf("unexpected error closing rows: %s", err)
}
if !bytes.Equal(b, want.Replaced) {
t.Fatalf("expected raw value to be replaced with '%s' (len:%d) after calling Next(), but got [%T]:%s (len:%d)", want.Replaced, len(want.Replaced), b, b, len(b))
}
if err := rs.Err(); err != nil {
t.Fatalf("row iteration failed: %s", err)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Fatal(err)
}
}
func queryRowBytesNotInvalidatedByClose(t *testing.T, rows *Rows, scan func(*sql.Rows) ([]byte, error), want []byte) {
db, mock, err := New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
mock.ExpectQuery("SELECT").WillReturnRows(rows)
rs, err := db.Query("SELECT")
if err != nil {
t.Fatalf("failed to query rows: %s", err)
}
if !rs.Next() || rs.Err() != nil {
t.Fatal("unexpected error on first row retrieval")
}
b, err := scan(rs)
if err != nil {
t.Fatalf("unexpected error scanning row: %s", err)
}
if !bytes.Equal(b, want) {
t.Fatalf("expected raw value to be '%s' (len:%d), but got [%T]:%s (len:%d)", want, len(want), b, b, len(b))
}
if err := rs.Close(); err != nil {
t.Fatalf("unexpected error closing rows: %s", err)
}
if !bytes.Equal(b, want) {
t.Fatalf("expected raw value to be replaced with '%s' (len:%d) after calling Next(), but got [%T]:%s (len:%d)", want, len(want), b, b, len(b))
}
if err := rs.Err(); err != nil {
t.Fatalf("row iteration failed: %s", err)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Fatal(err)
}
}