You've already forked go-sqlxmock
mirror of
https://github.com/zhashkevych/go-sqlxmock.git
synced 2025-06-12 21:47:29 +02:00
Invalidate memory scanned into sql.RawBytes
The intention of sql.RawBytes is for it to hold memory owned by the database. When used, it's content is only valid until the `Next`, `Scan` or `Close` is called on the `Rows` To ensure that we meet this behaviour, when `[]byte` is used in a column, it's value is copied to a buffer that we keep track of for later invalidation. By doing this, incorrect use of `sql.RawBytes` values is exposed in tests that use go-sqlmock. Without this, when a real database is used and it's driver does share memory, then those issues would not be exposed until runtime (and in non-obvious ways)
This commit is contained in:
307
rows_test.go
307
rows_test.go
@ -1,11 +1,14 @@
|
||||
package sqlmock
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const invalid = `☠☠☠ MEMORY OVERWRITTEN ☠☠☠ `
|
||||
|
||||
func ExampleRows() {
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
@ -88,6 +91,52 @@ func ExampleRows_closeError() {
|
||||
// Output: got error: close error
|
||||
}
|
||||
|
||||
func ExampleRows_rawBytes() {
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
fmt.Println("failed to open sqlmock database:", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
rows := NewRows([]string{"id", "binary"}).
|
||||
AddRow(1, []byte(`one binary value with some text!`)).
|
||||
AddRow(2, []byte(`two binary value with even more text than the first one`))
|
||||
|
||||
mock.ExpectQuery("SELECT").WillReturnRows(rows)
|
||||
|
||||
rs, _ := db.Query("SELECT")
|
||||
defer rs.Close()
|
||||
|
||||
type scanned struct {
|
||||
id int
|
||||
raw sql.RawBytes
|
||||
}
|
||||
fmt.Println("initial read...")
|
||||
var ss []scanned
|
||||
for rs.Next() {
|
||||
var s scanned
|
||||
rs.Scan(&s.id, &s.raw)
|
||||
ss = append(ss, s)
|
||||
fmt.Println("scanned id:", s.id, "and raw:", string(s.raw))
|
||||
}
|
||||
|
||||
if rs.Err() != nil {
|
||||
fmt.Println("got rows error:", rs.Err())
|
||||
}
|
||||
|
||||
fmt.Println("after reading all...")
|
||||
for _, s := range ss {
|
||||
fmt.Println("scanned id:", s.id, "and raw:", string(s.raw))
|
||||
}
|
||||
// Output:
|
||||
// initial read...
|
||||
// scanned id: 1 and raw: one binary value with some text!
|
||||
// scanned id: 2 and raw: two binary value with even more text than the first one
|
||||
// after reading all...
|
||||
// scanned id: 1 and raw: ☠☠☠ MEMORY OVERWRITTEN ☠
|
||||
// scanned id: 2 and raw: ☠☠☠ MEMORY OVERWRITTEN ☠☠☠ ☠☠☠ MEMORY
|
||||
}
|
||||
|
||||
func ExampleRows_expectToBeClosed() {
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
@ -260,6 +309,90 @@ func TestQuerySingleRow(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryRowBytesInvalidatedByNext_bytesIntoRawBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
replace := []byte(invalid)
|
||||
rows := NewRows([]string{"raw"}).
|
||||
AddRow([]byte(`one binary value with some text!`)).
|
||||
AddRow([]byte(`two binary value with even more text than the first one`))
|
||||
scan := func(rs *sql.Rows) ([]byte, error) {
|
||||
var raw sql.RawBytes
|
||||
return raw, rs.Scan(&raw)
|
||||
}
|
||||
want := []struct {
|
||||
Initial []byte
|
||||
Replaced []byte
|
||||
}{
|
||||
{Initial: []byte(`one binary value with some text!`), Replaced: replace[:len(replace)-7]},
|
||||
{Initial: []byte(`two binary value with even more text than the first one`), Replaced: bytes.Join([][]byte{replace, replace[:len(replace)-23]}, nil)},
|
||||
}
|
||||
queryRowBytesInvalidatedByNext(t, rows, scan, want)
|
||||
}
|
||||
|
||||
func TestQueryRowBytesNotInvalidatedByNext_bytesIntoBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
rows := NewRows([]string{"raw"}).
|
||||
AddRow([]byte(`one binary value with some text!`)).
|
||||
AddRow([]byte(`two binary value with even more text than the first one`))
|
||||
scan := func(rs *sql.Rows) ([]byte, error) {
|
||||
var b []byte
|
||||
return b, rs.Scan(&b)
|
||||
}
|
||||
want := [][]byte{[]byte(`one binary value with some text!`), []byte(`two binary value with even more text than the first one`)}
|
||||
queryRowBytesNotInvalidatedByNext(t, rows, scan, want)
|
||||
}
|
||||
|
||||
func TestQueryRowBytesNotInvalidatedByNext_stringIntoBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
rows := NewRows([]string{"raw"}).
|
||||
AddRow(`one binary value with some text!`).
|
||||
AddRow(`two binary value with even more text than the first one`)
|
||||
scan := func(rs *sql.Rows) ([]byte, error) {
|
||||
var b []byte
|
||||
return b, rs.Scan(&b)
|
||||
}
|
||||
want := [][]byte{[]byte(`one binary value with some text!`), []byte(`two binary value with even more text than the first one`)}
|
||||
queryRowBytesNotInvalidatedByNext(t, rows, scan, want)
|
||||
}
|
||||
|
||||
func TestQueryRowBytesInvalidatedByClose_bytesIntoRawBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
replace := []byte(invalid)
|
||||
rows := NewRows([]string{"raw"}).AddRow([]byte(`one binary value with some text!`))
|
||||
scan := func(rs *sql.Rows) ([]byte, error) {
|
||||
var raw sql.RawBytes
|
||||
return raw, rs.Scan(&raw)
|
||||
}
|
||||
want := struct {
|
||||
Initial []byte
|
||||
Replaced []byte
|
||||
}{
|
||||
Initial: []byte(`one binary value with some text!`),
|
||||
Replaced: replace[:len(replace)-7],
|
||||
}
|
||||
queryRowBytesInvalidatedByClose(t, rows, scan, want)
|
||||
}
|
||||
|
||||
func TestQueryRowBytesNotInvalidatedByClose_bytesIntoBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
rows := NewRows([]string{"raw"}).AddRow([]byte(`one binary value with some text!`))
|
||||
scan := func(rs *sql.Rows) ([]byte, error) {
|
||||
var b []byte
|
||||
return b, rs.Scan(&b)
|
||||
}
|
||||
queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`one binary value with some text!`))
|
||||
}
|
||||
|
||||
func TestQueryRowBytesNotInvalidatedByClose_stringIntoBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
rows := NewRows([]string{"raw"}).AddRow(`one binary value with some text!`)
|
||||
scan := func(rs *sql.Rows) ([]byte, error) {
|
||||
var b []byte
|
||||
return b, rs.Scan(&b)
|
||||
}
|
||||
queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`one binary value with some text!`))
|
||||
}
|
||||
|
||||
func TestRowsScanError(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, mock, err := New()
|
||||
@ -363,3 +496,177 @@ func TestEmptyRowSets(t *testing.T) {
|
||||
t.Fatalf("expected rowset 3, to be empty, but it was not")
|
||||
}
|
||||
}
|
||||
|
||||
func queryRowBytesInvalidatedByNext(t *testing.T, rows *Rows, scan func(*sql.Rows) ([]byte, error), want []struct {
|
||||
Initial []byte
|
||||
Replaced []byte
|
||||
}) {
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
mock.ExpectQuery("SELECT").WillReturnRows(rows)
|
||||
|
||||
rs, err := db.Query("SELECT")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to query rows: %s", err)
|
||||
}
|
||||
|
||||
if !rs.Next() || rs.Err() != nil {
|
||||
t.Fatal("unexpected error on first row retrieval")
|
||||
}
|
||||
var count int
|
||||
for i := 0; ; i++ {
|
||||
count++
|
||||
b, err := scan(rs)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error scanning row: %s", err)
|
||||
}
|
||||
if exp := want[i].Initial; !bytes.Equal(b, exp) {
|
||||
t.Fatalf("expected raw value to be '%s' (len:%d), but got [%T]:%s (len:%d)", exp, len(exp), b, b, len(b))
|
||||
}
|
||||
next := rs.Next()
|
||||
if exp := want[i].Replaced; !bytes.Equal(b, exp) {
|
||||
t.Fatalf("expected raw value to be replaced with '%s' (len:%d) after calling Next(), but got [%T]:%s (len:%d)", exp, len(exp), b, b, len(b))
|
||||
}
|
||||
if !next {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err := rs.Err(); err != nil {
|
||||
t.Fatalf("row iteration failed: %s", err)
|
||||
}
|
||||
if exp := len(want); count != exp {
|
||||
t.Fatalf("incorrect number of rows exp: %d, but got %d", exp, count)
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func queryRowBytesNotInvalidatedByNext(t *testing.T, rows *Rows, scan func(*sql.Rows) ([]byte, error), want [][]byte) {
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
mock.ExpectQuery("SELECT").WillReturnRows(rows)
|
||||
|
||||
rs, err := db.Query("SELECT")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to query rows: %s", err)
|
||||
}
|
||||
|
||||
if !rs.Next() || rs.Err() != nil {
|
||||
t.Fatal("unexpected error on first row retrieval")
|
||||
}
|
||||
var count int
|
||||
for i := 0; ; i++ {
|
||||
count++
|
||||
b, err := scan(rs)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error scanning row: %s", err)
|
||||
}
|
||||
if exp := want[i]; !bytes.Equal(b, exp) {
|
||||
t.Fatalf("expected raw value to be '%s' (len:%d), but got [%T]:%s (len:%d)", exp, len(exp), b, b, len(b))
|
||||
}
|
||||
next := rs.Next()
|
||||
if exp := want[i]; !bytes.Equal(b, exp) {
|
||||
t.Fatalf("expected raw value to be replaced with '%s' (len:%d) after calling Next(), but got [%T]:%s (len:%d)", exp, len(exp), b, b, len(b))
|
||||
}
|
||||
if !next {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err := rs.Err(); err != nil {
|
||||
t.Fatalf("row iteration failed: %s", err)
|
||||
}
|
||||
if exp := len(want); count != exp {
|
||||
t.Fatalf("incorrect number of rows exp: %d, but got %d", exp, count)
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func queryRowBytesInvalidatedByClose(t *testing.T, rows *Rows, scan func(*sql.Rows) ([]byte, error), want struct {
|
||||
Initial []byte
|
||||
Replaced []byte
|
||||
}) {
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
mock.ExpectQuery("SELECT").WillReturnRows(rows)
|
||||
|
||||
rs, err := db.Query("SELECT")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to query rows: %s", err)
|
||||
}
|
||||
|
||||
if !rs.Next() || rs.Err() != nil {
|
||||
t.Fatal("unexpected error on first row retrieval")
|
||||
}
|
||||
b, err := scan(rs)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error scanning row: %s", err)
|
||||
}
|
||||
if !bytes.Equal(b, want.Initial) {
|
||||
t.Fatalf("expected raw value to be '%s' (len:%d), but got [%T]:%s (len:%d)", want.Initial, len(want.Initial), b, b, len(b))
|
||||
}
|
||||
if err := rs.Close(); err != nil {
|
||||
t.Fatalf("unexpected error closing rows: %s", err)
|
||||
}
|
||||
if !bytes.Equal(b, want.Replaced) {
|
||||
t.Fatalf("expected raw value to be replaced with '%s' (len:%d) after calling Next(), but got [%T]:%s (len:%d)", want.Replaced, len(want.Replaced), b, b, len(b))
|
||||
}
|
||||
if err := rs.Err(); err != nil {
|
||||
t.Fatalf("row iteration failed: %s", err)
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func queryRowBytesNotInvalidatedByClose(t *testing.T, rows *Rows, scan func(*sql.Rows) ([]byte, error), want []byte) {
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
mock.ExpectQuery("SELECT").WillReturnRows(rows)
|
||||
|
||||
rs, err := db.Query("SELECT")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to query rows: %s", err)
|
||||
}
|
||||
|
||||
if !rs.Next() || rs.Err() != nil {
|
||||
t.Fatal("unexpected error on first row retrieval")
|
||||
}
|
||||
b, err := scan(rs)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error scanning row: %s", err)
|
||||
}
|
||||
if !bytes.Equal(b, want) {
|
||||
t.Fatalf("expected raw value to be '%s' (len:%d), but got [%T]:%s (len:%d)", want, len(want), b, b, len(b))
|
||||
}
|
||||
if err := rs.Close(); err != nil {
|
||||
t.Fatalf("unexpected error closing rows: %s", err)
|
||||
}
|
||||
if !bytes.Equal(b, want) {
|
||||
t.Fatalf("expected raw value to be replaced with '%s' (len:%d) after calling Next(), but got [%T]:%s (len:%d)", want, len(want), b, b, len(b))
|
||||
}
|
||||
if err := rs.Err(); err != nil {
|
||||
t.Fatalf("row iteration failed: %s", err)
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user