From d5879ee4b77f10ee15c5b986e8a99124d031250b Mon Sep 17 00:00:00 2001 From: David Ackroyd Date: Fri, 21 Jun 2019 13:28:10 +1000 Subject: [PATCH] Invalidate memory scanned into sql.RawBytes The intention of sql.RawBytes is for it to hold memory owned by the database. When used, it's content is only valid until the `Next`, `Scan` or `Close` is called on the `Rows` To ensure that we meet this behaviour, when `[]byte` is used in a column, it's value is copied to a buffer that we keep track of for later invalidation. By doing this, incorrect use of `sql.RawBytes` values is exposed in tests that use go-sqlmock. Without this, when a real database is used and it's driver does share memory, then those issues would not be exposed until runtime (and in non-obvious ways) --- rows.go | 35 ++++++ rows_go13_test.go | 31 +++++ rows_go18_test.go | 113 +++++++++++++++++ rows_test.go | 307 ++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 486 insertions(+) create mode 100644 rows_go13_test.go diff --git a/rows.go b/rows.go index 4dcd65c..5f11c78 100644 --- a/rows.go +++ b/rows.go @@ -1,6 +1,7 @@ package sqlmock import ( + "bytes" "database/sql/driver" "encoding/csv" "fmt" @@ -8,6 +9,8 @@ import ( "strings" ) +const invalidate = "☠☠☠ MEMORY OVERWRITTEN ☠☠☠ " + // CSVColumnParser is a function which converts trimmed csv // column string to a []byte representation. Currently // transforms NULL to nil @@ -23,6 +26,7 @@ type rowSets struct { sets []*Rows pos int ex *ExpectedQuery + raw [][]byte } func (rs *rowSets) Columns() []string { @@ -30,6 +34,7 @@ func (rs *rowSets) Columns() []string { } func (rs *rowSets) Close() error { + rs.invalidateRaw() rs.ex.rowsWereClosed = true return rs.sets[rs.pos].closeErr } @@ -38,11 +43,17 @@ func (rs *rowSets) Close() error { func (rs *rowSets) Next(dest []driver.Value) error { r := rs.sets[rs.pos] r.pos++ + rs.invalidateRaw() if r.pos > len(r.rows) { return io.EOF // per interface spec } for i, col := range r.rows[r.pos-1] { + if b, ok := rawBytes(col); ok { + rs.raw = append(rs.raw, b) + dest[i] = b + continue + } dest[i] = col } @@ -80,6 +91,30 @@ func (rs *rowSets) empty() bool { return true } +func rawBytes(col driver.Value) (_ []byte, ok bool) { + val, ok := col.([]byte) + if !ok || len(val) == 0 { + return nil, false + } + // Copy the bytes from the mocked row into a shared raw buffer, which we'll replace the content of later + // This allows scanning into sql.RawBytes to correctly become invalid on subsequent calls to Next(), Scan() or Close() + b := make([]byte, len(val)) + copy(b, val) + return b, true +} + +// Bytes that could have been scanned as sql.RawBytes are only valid until the next call to Next, Scan or Close. +// If those occur, we must replace their content to simulate the shared memory to expose misuse of sql.RawBytes +func (rs *rowSets) invalidateRaw() { + // Replace the content of slices previously returned + b := []byte(invalidate) + for _, r := range rs.raw { + copy(r, bytes.Repeat(b, len(r)/len(b)+1)) + } + // Start with new slices for the next scan + rs.raw = nil +} + // Rows is a mocked collection of rows to // return for Query result type Rows struct { diff --git a/rows_go13_test.go b/rows_go13_test.go new file mode 100644 index 0000000..5c9038c --- /dev/null +++ b/rows_go13_test.go @@ -0,0 +1,31 @@ +// +build go1.3 + +package sqlmock + +import ( + "database/sql" + "testing" +) + +func TestQueryRowBytesNotInvalidatedByNext_stringIntoRawBytes(t *testing.T) { + t.Parallel() + rows := NewRows([]string{"raw"}). + AddRow(`one binary value with some text!`). + AddRow(`two binary value with even more text than the first one`) + scan := func(rs *sql.Rows) ([]byte, error) { + var raw sql.RawBytes + return raw, rs.Scan(&raw) + } + want := [][]byte{[]byte(`one binary value with some text!`), []byte(`two binary value with even more text than the first one`)} + queryRowBytesNotInvalidatedByNext(t, rows, scan, want) +} + +func TestQueryRowBytesNotInvalidatedByClose_stringIntoRawBytes(t *testing.T) { + t.Parallel() + rows := NewRows([]string{"raw"}).AddRow(`one binary value with some text!`) + scan := func(rs *sql.Rows) ([]byte, error) { + var raw sql.RawBytes + return raw, rs.Scan(&raw) + } + queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`one binary value with some text!`)) +} diff --git a/rows_go18_test.go b/rows_go18_test.go index c776def..b29a2c5 100644 --- a/rows_go18_test.go +++ b/rows_go18_test.go @@ -3,6 +3,8 @@ package sqlmock import ( + "database/sql" + "encoding/json" "fmt" "testing" ) @@ -90,3 +92,114 @@ func TestQueryMultiRows(t *testing.T) { t.Errorf("there were unfulfilled expectations: %s", err) } } + +func TestQueryRowBytesInvalidatedByNext_jsonRawMessageIntoRawBytes(t *testing.T) { + t.Parallel() + replace := []byte(invalid) + rows := NewRows([]string{"raw"}). + AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)). + AddRow(json.RawMessage(`{"that": "foo", "this": "bar"}`)) + scan := func(rs *sql.Rows) ([]byte, error) { + var raw sql.RawBytes + return raw, rs.Scan(&raw) + } + want := []struct { + Initial []byte + Replaced []byte + }{ + {Initial: []byte(`{"thing": "one", "thing2": "two"}`), Replaced: replace[:len(replace)-6]}, + {Initial: []byte(`{"that": "foo", "this": "bar"}`), Replaced: replace[:len(replace)-9]}, + } + queryRowBytesInvalidatedByNext(t, rows, scan, want) +} + +func TestQueryRowBytesNotInvalidatedByNext_jsonRawMessageIntoBytes(t *testing.T) { + t.Parallel() + rows := NewRows([]string{"raw"}). + AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)). + AddRow(json.RawMessage(`{"that": "foo", "this": "bar"}`)) + scan := func(rs *sql.Rows) ([]byte, error) { + var b []byte + return b, rs.Scan(&b) + } + want := [][]byte{[]byte(`{"thing": "one", "thing2": "two"}`), []byte(`{"that": "foo", "this": "bar"}`)} + queryRowBytesNotInvalidatedByNext(t, rows, scan, want) +} + +func TestQueryRowBytesNotInvalidatedByNext_bytesIntoCustomBytes(t *testing.T) { + t.Parallel() + rows := NewRows([]string{"raw"}). + AddRow([]byte(`one binary value with some text!`)). + AddRow([]byte(`two binary value with even more text than the first one`)) + scan := func(rs *sql.Rows) ([]byte, error) { + type customBytes []byte + var b customBytes + return b, rs.Scan(&b) + } + want := [][]byte{[]byte(`one binary value with some text!`), []byte(`two binary value with even more text than the first one`)} + queryRowBytesNotInvalidatedByNext(t, rows, scan, want) +} + +func TestQueryRowBytesNotInvalidatedByNext_jsonRawMessageIntoCustomBytes(t *testing.T) { + t.Parallel() + rows := NewRows([]string{"raw"}). + AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)). + AddRow(json.RawMessage(`{"that": "foo", "this": "bar"}`)) + scan := func(rs *sql.Rows) ([]byte, error) { + type customBytes []byte + var b customBytes + return b, rs.Scan(&b) + } + want := [][]byte{[]byte(`{"thing": "one", "thing2": "two"}`), []byte(`{"that": "foo", "this": "bar"}`)} + queryRowBytesNotInvalidatedByNext(t, rows, scan, want) +} + +func TestQueryRowBytesNotInvalidatedByClose_bytesIntoCustomBytes(t *testing.T) { + t.Parallel() + rows := NewRows([]string{"raw"}).AddRow([]byte(`one binary value with some text!`)) + scan := func(rs *sql.Rows) ([]byte, error) { + type customBytes []byte + var b customBytes + return b, rs.Scan(&b) + } + queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`one binary value with some text!`)) +} + +func TestQueryRowBytesInvalidatedByClose_jsonRawMessageIntoRawBytes(t *testing.T) { + t.Parallel() + replace := []byte(invalid) + rows := NewRows([]string{"raw"}).AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)) + scan := func(rs *sql.Rows) ([]byte, error) { + var raw sql.RawBytes + return raw, rs.Scan(&raw) + } + want := struct { + Initial []byte + Replaced []byte + }{ + Initial: []byte(`{"thing": "one", "thing2": "two"}`), + Replaced: replace[:len(replace)-6], + } + queryRowBytesInvalidatedByClose(t, rows, scan, want) +} + +func TestQueryRowBytesNotInvalidatedByClose_jsonRawMessageIntoBytes(t *testing.T) { + t.Parallel() + rows := NewRows([]string{"raw"}).AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)) + scan := func(rs *sql.Rows) ([]byte, error) { + var b []byte + return b, rs.Scan(&b) + } + queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`{"thing": "one", "thing2": "two"}`)) +} + +func TestQueryRowBytesNotInvalidatedByClose_jsonRawMessageIntoCustomBytes(t *testing.T) { + t.Parallel() + rows := NewRows([]string{"raw"}).AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)) + scan := func(rs *sql.Rows) ([]byte, error) { + type customBytes []byte + var b customBytes + return b, rs.Scan(&b) + } + queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`{"thing": "one", "thing2": "two"}`)) +} diff --git a/rows_test.go b/rows_test.go index ff80147..15cdbee 100644 --- a/rows_test.go +++ b/rows_test.go @@ -1,11 +1,14 @@ package sqlmock import ( + "bytes" "database/sql" "fmt" "testing" ) +const invalid = `☠☠☠ MEMORY OVERWRITTEN ☠☠☠ ` + func ExampleRows() { db, mock, err := New() if err != nil { @@ -88,6 +91,52 @@ func ExampleRows_closeError() { // Output: got error: close error } +func ExampleRows_rawBytes() { + db, mock, err := New() + if err != nil { + fmt.Println("failed to open sqlmock database:", err) + } + defer db.Close() + + rows := NewRows([]string{"id", "binary"}). + AddRow(1, []byte(`one binary value with some text!`)). + AddRow(2, []byte(`two binary value with even more text than the first one`)) + + mock.ExpectQuery("SELECT").WillReturnRows(rows) + + rs, _ := db.Query("SELECT") + defer rs.Close() + + type scanned struct { + id int + raw sql.RawBytes + } + fmt.Println("initial read...") + var ss []scanned + for rs.Next() { + var s scanned + rs.Scan(&s.id, &s.raw) + ss = append(ss, s) + fmt.Println("scanned id:", s.id, "and raw:", string(s.raw)) + } + + if rs.Err() != nil { + fmt.Println("got rows error:", rs.Err()) + } + + fmt.Println("after reading all...") + for _, s := range ss { + fmt.Println("scanned id:", s.id, "and raw:", string(s.raw)) + } + // Output: + // initial read... + // scanned id: 1 and raw: one binary value with some text! + // scanned id: 2 and raw: two binary value with even more text than the first one + // after reading all... + // scanned id: 1 and raw: ☠☠☠ MEMORY OVERWRITTEN ☠ + // scanned id: 2 and raw: ☠☠☠ MEMORY OVERWRITTEN ☠☠☠ ☠☠☠ MEMORY +} + func ExampleRows_expectToBeClosed() { db, mock, err := New() if err != nil { @@ -260,6 +309,90 @@ func TestQuerySingleRow(t *testing.T) { } } +func TestQueryRowBytesInvalidatedByNext_bytesIntoRawBytes(t *testing.T) { + t.Parallel() + replace := []byte(invalid) + rows := NewRows([]string{"raw"}). + AddRow([]byte(`one binary value with some text!`)). + AddRow([]byte(`two binary value with even more text than the first one`)) + scan := func(rs *sql.Rows) ([]byte, error) { + var raw sql.RawBytes + return raw, rs.Scan(&raw) + } + want := []struct { + Initial []byte + Replaced []byte + }{ + {Initial: []byte(`one binary value with some text!`), Replaced: replace[:len(replace)-7]}, + {Initial: []byte(`two binary value with even more text than the first one`), Replaced: bytes.Join([][]byte{replace, replace[:len(replace)-23]}, nil)}, + } + queryRowBytesInvalidatedByNext(t, rows, scan, want) +} + +func TestQueryRowBytesNotInvalidatedByNext_bytesIntoBytes(t *testing.T) { + t.Parallel() + rows := NewRows([]string{"raw"}). + AddRow([]byte(`one binary value with some text!`)). + AddRow([]byte(`two binary value with even more text than the first one`)) + scan := func(rs *sql.Rows) ([]byte, error) { + var b []byte + return b, rs.Scan(&b) + } + want := [][]byte{[]byte(`one binary value with some text!`), []byte(`two binary value with even more text than the first one`)} + queryRowBytesNotInvalidatedByNext(t, rows, scan, want) +} + +func TestQueryRowBytesNotInvalidatedByNext_stringIntoBytes(t *testing.T) { + t.Parallel() + rows := NewRows([]string{"raw"}). + AddRow(`one binary value with some text!`). + AddRow(`two binary value with even more text than the first one`) + scan := func(rs *sql.Rows) ([]byte, error) { + var b []byte + return b, rs.Scan(&b) + } + want := [][]byte{[]byte(`one binary value with some text!`), []byte(`two binary value with even more text than the first one`)} + queryRowBytesNotInvalidatedByNext(t, rows, scan, want) +} + +func TestQueryRowBytesInvalidatedByClose_bytesIntoRawBytes(t *testing.T) { + t.Parallel() + replace := []byte(invalid) + rows := NewRows([]string{"raw"}).AddRow([]byte(`one binary value with some text!`)) + scan := func(rs *sql.Rows) ([]byte, error) { + var raw sql.RawBytes + return raw, rs.Scan(&raw) + } + want := struct { + Initial []byte + Replaced []byte + }{ + Initial: []byte(`one binary value with some text!`), + Replaced: replace[:len(replace)-7], + } + queryRowBytesInvalidatedByClose(t, rows, scan, want) +} + +func TestQueryRowBytesNotInvalidatedByClose_bytesIntoBytes(t *testing.T) { + t.Parallel() + rows := NewRows([]string{"raw"}).AddRow([]byte(`one binary value with some text!`)) + scan := func(rs *sql.Rows) ([]byte, error) { + var b []byte + return b, rs.Scan(&b) + } + queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`one binary value with some text!`)) +} + +func TestQueryRowBytesNotInvalidatedByClose_stringIntoBytes(t *testing.T) { + t.Parallel() + rows := NewRows([]string{"raw"}).AddRow(`one binary value with some text!`) + scan := func(rs *sql.Rows) ([]byte, error) { + var b []byte + return b, rs.Scan(&b) + } + queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`one binary value with some text!`)) +} + func TestRowsScanError(t *testing.T) { t.Parallel() db, mock, err := New() @@ -363,3 +496,177 @@ func TestEmptyRowSets(t *testing.T) { t.Fatalf("expected rowset 3, to be empty, but it was not") } } + +func queryRowBytesInvalidatedByNext(t *testing.T, rows *Rows, scan func(*sql.Rows) ([]byte, error), want []struct { + Initial []byte + Replaced []byte +}) { + db, mock, err := New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + mock.ExpectQuery("SELECT").WillReturnRows(rows) + + rs, err := db.Query("SELECT") + if err != nil { + t.Fatalf("failed to query rows: %s", err) + } + + if !rs.Next() || rs.Err() != nil { + t.Fatal("unexpected error on first row retrieval") + } + var count int + for i := 0; ; i++ { + count++ + b, err := scan(rs) + if err != nil { + t.Fatalf("unexpected error scanning row: %s", err) + } + if exp := want[i].Initial; !bytes.Equal(b, exp) { + t.Fatalf("expected raw value to be '%s' (len:%d), but got [%T]:%s (len:%d)", exp, len(exp), b, b, len(b)) + } + next := rs.Next() + if exp := want[i].Replaced; !bytes.Equal(b, exp) { + t.Fatalf("expected raw value to be replaced with '%s' (len:%d) after calling Next(), but got [%T]:%s (len:%d)", exp, len(exp), b, b, len(b)) + } + if !next { + break + } + } + if err := rs.Err(); err != nil { + t.Fatalf("row iteration failed: %s", err) + } + if exp := len(want); count != exp { + t.Fatalf("incorrect number of rows exp: %d, but got %d", exp, count) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatal(err) + } +} + +func queryRowBytesNotInvalidatedByNext(t *testing.T, rows *Rows, scan func(*sql.Rows) ([]byte, error), want [][]byte) { + db, mock, err := New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + mock.ExpectQuery("SELECT").WillReturnRows(rows) + + rs, err := db.Query("SELECT") + if err != nil { + t.Fatalf("failed to query rows: %s", err) + } + + if !rs.Next() || rs.Err() != nil { + t.Fatal("unexpected error on first row retrieval") + } + var count int + for i := 0; ; i++ { + count++ + b, err := scan(rs) + if err != nil { + t.Fatalf("unexpected error scanning row: %s", err) + } + if exp := want[i]; !bytes.Equal(b, exp) { + t.Fatalf("expected raw value to be '%s' (len:%d), but got [%T]:%s (len:%d)", exp, len(exp), b, b, len(b)) + } + next := rs.Next() + if exp := want[i]; !bytes.Equal(b, exp) { + t.Fatalf("expected raw value to be replaced with '%s' (len:%d) after calling Next(), but got [%T]:%s (len:%d)", exp, len(exp), b, b, len(b)) + } + if !next { + break + } + } + if err := rs.Err(); err != nil { + t.Fatalf("row iteration failed: %s", err) + } + if exp := len(want); count != exp { + t.Fatalf("incorrect number of rows exp: %d, but got %d", exp, count) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatal(err) + } +} + +func queryRowBytesInvalidatedByClose(t *testing.T, rows *Rows, scan func(*sql.Rows) ([]byte, error), want struct { + Initial []byte + Replaced []byte +}) { + db, mock, err := New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + mock.ExpectQuery("SELECT").WillReturnRows(rows) + + rs, err := db.Query("SELECT") + if err != nil { + t.Fatalf("failed to query rows: %s", err) + } + + if !rs.Next() || rs.Err() != nil { + t.Fatal("unexpected error on first row retrieval") + } + b, err := scan(rs) + if err != nil { + t.Fatalf("unexpected error scanning row: %s", err) + } + if !bytes.Equal(b, want.Initial) { + t.Fatalf("expected raw value to be '%s' (len:%d), but got [%T]:%s (len:%d)", want.Initial, len(want.Initial), b, b, len(b)) + } + if err := rs.Close(); err != nil { + t.Fatalf("unexpected error closing rows: %s", err) + } + if !bytes.Equal(b, want.Replaced) { + t.Fatalf("expected raw value to be replaced with '%s' (len:%d) after calling Next(), but got [%T]:%s (len:%d)", want.Replaced, len(want.Replaced), b, b, len(b)) + } + if err := rs.Err(); err != nil { + t.Fatalf("row iteration failed: %s", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatal(err) + } +} + +func queryRowBytesNotInvalidatedByClose(t *testing.T, rows *Rows, scan func(*sql.Rows) ([]byte, error), want []byte) { + db, mock, err := New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + mock.ExpectQuery("SELECT").WillReturnRows(rows) + + rs, err := db.Query("SELECT") + if err != nil { + t.Fatalf("failed to query rows: %s", err) + } + + if !rs.Next() || rs.Err() != nil { + t.Fatal("unexpected error on first row retrieval") + } + b, err := scan(rs) + if err != nil { + t.Fatalf("unexpected error scanning row: %s", err) + } + if !bytes.Equal(b, want) { + t.Fatalf("expected raw value to be '%s' (len:%d), but got [%T]:%s (len:%d)", want, len(want), b, b, len(b)) + } + if err := rs.Close(); err != nil { + t.Fatalf("unexpected error closing rows: %s", err) + } + if !bytes.Equal(b, want) { + t.Fatalf("expected raw value to be replaced with '%s' (len:%d) after calling Next(), but got [%T]:%s (len:%d)", want, len(want), b, b, len(b)) + } + if err := rs.Err(); err != nil { + t.Fatalf("row iteration failed: %s", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatal(err) + } +}