You've already forked go-sqlxmock
							
							
				mirror of
				https://github.com/zhashkevych/go-sqlxmock.git
				synced 2025-10-30 23:27:38 +02:00 
			
		
		
		
	Invalidate memory scanned into sql.RawBytes
The intention of sql.RawBytes is for it to hold memory owned by the database. When used, it's content is only valid until the `Next`, `Scan` or `Close` is called on the `Rows` To ensure that we meet this behaviour, when `[]byte` is used in a column, it's value is copied to a buffer that we keep track of for later invalidation. By doing this, incorrect use of `sql.RawBytes` values is exposed in tests that use go-sqlmock. Without this, when a real database is used and it's driver does share memory, then those issues would not be exposed until runtime (and in non-obvious ways)
This commit is contained in:
		
							
								
								
									
										35
									
								
								rows.go
									
									
									
									
									
								
							
							
						
						
									
										35
									
								
								rows.go
									
									
									
									
									
								
							| @@ -1,6 +1,7 @@ | ||||
| package sqlmock | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"database/sql/driver" | ||||
| 	"encoding/csv" | ||||
| 	"fmt" | ||||
| @@ -8,6 +9,8 @@ import ( | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| const invalidate = "☠☠☠ MEMORY OVERWRITTEN ☠☠☠ " | ||||
|  | ||||
| // CSVColumnParser is a function which converts trimmed csv | ||||
| // column string to a []byte representation. Currently | ||||
| // transforms NULL to nil | ||||
| @@ -23,6 +26,7 @@ type rowSets struct { | ||||
| 	sets []*Rows | ||||
| 	pos  int | ||||
| 	ex   *ExpectedQuery | ||||
| 	raw  [][]byte | ||||
| } | ||||
|  | ||||
| func (rs *rowSets) Columns() []string { | ||||
| @@ -30,6 +34,7 @@ func (rs *rowSets) Columns() []string { | ||||
| } | ||||
|  | ||||
| func (rs *rowSets) Close() error { | ||||
| 	rs.invalidateRaw() | ||||
| 	rs.ex.rowsWereClosed = true | ||||
| 	return rs.sets[rs.pos].closeErr | ||||
| } | ||||
| @@ -38,11 +43,17 @@ func (rs *rowSets) Close() error { | ||||
| func (rs *rowSets) Next(dest []driver.Value) error { | ||||
| 	r := rs.sets[rs.pos] | ||||
| 	r.pos++ | ||||
| 	rs.invalidateRaw() | ||||
| 	if r.pos > len(r.rows) { | ||||
| 		return io.EOF // per interface spec | ||||
| 	} | ||||
|  | ||||
| 	for i, col := range r.rows[r.pos-1] { | ||||
| 		if b, ok := rawBytes(col); ok { | ||||
| 			rs.raw = append(rs.raw, b) | ||||
| 			dest[i] = b | ||||
| 			continue | ||||
| 		} | ||||
| 		dest[i] = col | ||||
| 	} | ||||
|  | ||||
| @@ -80,6 +91,30 @@ func (rs *rowSets) empty() bool { | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| func rawBytes(col driver.Value) (_ []byte, ok bool) { | ||||
| 	val, ok := col.([]byte) | ||||
| 	if !ok || len(val) == 0 { | ||||
| 		return nil, false | ||||
| 	} | ||||
| 	// Copy the bytes from the mocked row into a shared raw buffer, which we'll replace the content of later | ||||
| 	// This allows scanning into sql.RawBytes to correctly become invalid on subsequent calls to Next(), Scan() or Close() | ||||
| 	b := make([]byte, len(val)) | ||||
| 	copy(b, val) | ||||
| 	return b, true | ||||
| } | ||||
|  | ||||
| // Bytes that could have been scanned as sql.RawBytes are only valid until the next call to Next, Scan or Close. | ||||
| // If those occur, we must replace their content to simulate the shared memory to expose misuse of sql.RawBytes | ||||
| func (rs *rowSets) invalidateRaw() { | ||||
| 	// Replace the content of slices previously returned | ||||
| 	b := []byte(invalidate) | ||||
| 	for _, r := range rs.raw { | ||||
| 		copy(r, bytes.Repeat(b, len(r)/len(b)+1)) | ||||
| 	} | ||||
| 	// Start with new slices for the next scan | ||||
| 	rs.raw = nil | ||||
| } | ||||
|  | ||||
| // Rows is a mocked collection of rows to | ||||
| // return for Query result | ||||
| type Rows struct { | ||||
|   | ||||
							
								
								
									
										31
									
								
								rows_go13_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								rows_go13_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,31 @@ | ||||
| // +build go1.3 | ||||
|  | ||||
| package sqlmock | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func TestQueryRowBytesNotInvalidatedByNext_stringIntoRawBytes(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	rows := NewRows([]string{"raw"}). | ||||
| 		AddRow(`one binary value with some text!`). | ||||
| 		AddRow(`two binary value with even more text than the first one`) | ||||
| 	scan := func(rs *sql.Rows) ([]byte, error) { | ||||
| 		var raw sql.RawBytes | ||||
| 		return raw, rs.Scan(&raw) | ||||
| 	} | ||||
| 	want := [][]byte{[]byte(`one binary value with some text!`), []byte(`two binary value with even more text than the first one`)} | ||||
| 	queryRowBytesNotInvalidatedByNext(t, rows, scan, want) | ||||
| } | ||||
|  | ||||
| func TestQueryRowBytesNotInvalidatedByClose_stringIntoRawBytes(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	rows := NewRows([]string{"raw"}).AddRow(`one binary value with some text!`) | ||||
| 	scan := func(rs *sql.Rows) ([]byte, error) { | ||||
| 		var raw sql.RawBytes | ||||
| 		return raw, rs.Scan(&raw) | ||||
| 	} | ||||
| 	queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`one binary value with some text!`)) | ||||
| } | ||||
| @@ -3,6 +3,8 @@ | ||||
| package sqlmock | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"testing" | ||||
| ) | ||||
| @@ -90,3 +92,114 @@ func TestQueryMultiRows(t *testing.T) { | ||||
| 		t.Errorf("there were unfulfilled expectations: %s", err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestQueryRowBytesInvalidatedByNext_jsonRawMessageIntoRawBytes(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	replace := []byte(invalid) | ||||
| 	rows := NewRows([]string{"raw"}). | ||||
| 		AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)). | ||||
| 		AddRow(json.RawMessage(`{"that": "foo", "this": "bar"}`)) | ||||
| 	scan := func(rs *sql.Rows) ([]byte, error) { | ||||
| 		var raw sql.RawBytes | ||||
| 		return raw, rs.Scan(&raw) | ||||
| 	} | ||||
| 	want := []struct { | ||||
| 		Initial  []byte | ||||
| 		Replaced []byte | ||||
| 	}{ | ||||
| 		{Initial: []byte(`{"thing": "one", "thing2": "two"}`), Replaced: replace[:len(replace)-6]}, | ||||
| 		{Initial: []byte(`{"that": "foo", "this": "bar"}`), Replaced: replace[:len(replace)-9]}, | ||||
| 	} | ||||
| 	queryRowBytesInvalidatedByNext(t, rows, scan, want) | ||||
| } | ||||
|  | ||||
| func TestQueryRowBytesNotInvalidatedByNext_jsonRawMessageIntoBytes(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	rows := NewRows([]string{"raw"}). | ||||
| 		AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)). | ||||
| 		AddRow(json.RawMessage(`{"that": "foo", "this": "bar"}`)) | ||||
| 	scan := func(rs *sql.Rows) ([]byte, error) { | ||||
| 		var b []byte | ||||
| 		return b, rs.Scan(&b) | ||||
| 	} | ||||
| 	want := [][]byte{[]byte(`{"thing": "one", "thing2": "two"}`), []byte(`{"that": "foo", "this": "bar"}`)} | ||||
| 	queryRowBytesNotInvalidatedByNext(t, rows, scan, want) | ||||
| } | ||||
|  | ||||
| func TestQueryRowBytesNotInvalidatedByNext_bytesIntoCustomBytes(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	rows := NewRows([]string{"raw"}). | ||||
| 		AddRow([]byte(`one binary value with some text!`)). | ||||
| 		AddRow([]byte(`two binary value with even more text than the first one`)) | ||||
| 	scan := func(rs *sql.Rows) ([]byte, error) { | ||||
| 		type customBytes []byte | ||||
| 		var b customBytes | ||||
| 		return b, rs.Scan(&b) | ||||
| 	} | ||||
| 	want := [][]byte{[]byte(`one binary value with some text!`), []byte(`two binary value with even more text than the first one`)} | ||||
| 	queryRowBytesNotInvalidatedByNext(t, rows, scan, want) | ||||
| } | ||||
|  | ||||
| func TestQueryRowBytesNotInvalidatedByNext_jsonRawMessageIntoCustomBytes(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	rows := NewRows([]string{"raw"}). | ||||
| 		AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)). | ||||
| 		AddRow(json.RawMessage(`{"that": "foo", "this": "bar"}`)) | ||||
| 	scan := func(rs *sql.Rows) ([]byte, error) { | ||||
| 		type customBytes []byte | ||||
| 		var b customBytes | ||||
| 		return b, rs.Scan(&b) | ||||
| 	} | ||||
| 	want := [][]byte{[]byte(`{"thing": "one", "thing2": "two"}`), []byte(`{"that": "foo", "this": "bar"}`)} | ||||
| 	queryRowBytesNotInvalidatedByNext(t, rows, scan, want) | ||||
| } | ||||
|  | ||||
| func TestQueryRowBytesNotInvalidatedByClose_bytesIntoCustomBytes(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	rows := NewRows([]string{"raw"}).AddRow([]byte(`one binary value with some text!`)) | ||||
| 	scan := func(rs *sql.Rows) ([]byte, error) { | ||||
| 		type customBytes []byte | ||||
| 		var b customBytes | ||||
| 		return b, rs.Scan(&b) | ||||
| 	} | ||||
| 	queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`one binary value with some text!`)) | ||||
| } | ||||
|  | ||||
| func TestQueryRowBytesInvalidatedByClose_jsonRawMessageIntoRawBytes(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	replace := []byte(invalid) | ||||
| 	rows := NewRows([]string{"raw"}).AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)) | ||||
| 	scan := func(rs *sql.Rows) ([]byte, error) { | ||||
| 		var raw sql.RawBytes | ||||
| 		return raw, rs.Scan(&raw) | ||||
| 	} | ||||
| 	want := struct { | ||||
| 		Initial  []byte | ||||
| 		Replaced []byte | ||||
| 	}{ | ||||
| 		Initial:  []byte(`{"thing": "one", "thing2": "two"}`), | ||||
| 		Replaced: replace[:len(replace)-6], | ||||
| 	} | ||||
| 	queryRowBytesInvalidatedByClose(t, rows, scan, want) | ||||
| } | ||||
|  | ||||
| func TestQueryRowBytesNotInvalidatedByClose_jsonRawMessageIntoBytes(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	rows := NewRows([]string{"raw"}).AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)) | ||||
| 	scan := func(rs *sql.Rows) ([]byte, error) { | ||||
| 		var b []byte | ||||
| 		return b, rs.Scan(&b) | ||||
| 	} | ||||
| 	queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`{"thing": "one", "thing2": "two"}`)) | ||||
| } | ||||
|  | ||||
| func TestQueryRowBytesNotInvalidatedByClose_jsonRawMessageIntoCustomBytes(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	rows := NewRows([]string{"raw"}).AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)) | ||||
| 	scan := func(rs *sql.Rows) ([]byte, error) { | ||||
| 		type customBytes []byte | ||||
| 		var b customBytes | ||||
| 		return b, rs.Scan(&b) | ||||
| 	} | ||||
| 	queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`{"thing": "one", "thing2": "two"}`)) | ||||
| } | ||||
|   | ||||
							
								
								
									
										307
									
								
								rows_test.go
									
									
									
									
									
								
							
							
						
						
									
										307
									
								
								rows_test.go
									
									
									
									
									
								
							| @@ -1,11 +1,14 @@ | ||||
| package sqlmock | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| const invalid = `☠☠☠ MEMORY OVERWRITTEN ☠☠☠ ` | ||||
|  | ||||
| func ExampleRows() { | ||||
| 	db, mock, err := New() | ||||
| 	if err != nil { | ||||
| @@ -88,6 +91,52 @@ func ExampleRows_closeError() { | ||||
| 	// Output: got error: close error | ||||
| } | ||||
|  | ||||
| func ExampleRows_rawBytes() { | ||||
| 	db, mock, err := New() | ||||
| 	if err != nil { | ||||
| 		fmt.Println("failed to open sqlmock database:", err) | ||||
| 	} | ||||
| 	defer db.Close() | ||||
|  | ||||
| 	rows := NewRows([]string{"id", "binary"}). | ||||
| 		AddRow(1, []byte(`one binary value with some text!`)). | ||||
| 		AddRow(2, []byte(`two binary value with even more text than the first one`)) | ||||
|  | ||||
| 	mock.ExpectQuery("SELECT").WillReturnRows(rows) | ||||
|  | ||||
| 	rs, _ := db.Query("SELECT") | ||||
| 	defer rs.Close() | ||||
|  | ||||
| 	type scanned struct { | ||||
| 		id  int | ||||
| 		raw sql.RawBytes | ||||
| 	} | ||||
| 	fmt.Println("initial read...") | ||||
| 	var ss []scanned | ||||
| 	for rs.Next() { | ||||
| 		var s scanned | ||||
| 		rs.Scan(&s.id, &s.raw) | ||||
| 		ss = append(ss, s) | ||||
| 		fmt.Println("scanned id:", s.id, "and raw:", string(s.raw)) | ||||
| 	} | ||||
|  | ||||
| 	if rs.Err() != nil { | ||||
| 		fmt.Println("got rows error:", rs.Err()) | ||||
| 	} | ||||
|  | ||||
| 	fmt.Println("after reading all...") | ||||
| 	for _, s := range ss { | ||||
| 		fmt.Println("scanned id:", s.id, "and raw:", string(s.raw)) | ||||
| 	} | ||||
| 	// Output: | ||||
| 	// initial read... | ||||
| 	// scanned id: 1 and raw: one binary value with some text! | ||||
| 	// scanned id: 2 and raw: two binary value with even more text than the first one | ||||
| 	// after reading all... | ||||
| 	// scanned id: 1 and raw: ☠☠☠ MEMORY OVERWRITTEN ☠ | ||||
| 	// scanned id: 2 and raw: ☠☠☠ MEMORY OVERWRITTEN ☠☠☠ ☠☠☠ MEMORY | ||||
| } | ||||
|  | ||||
| func ExampleRows_expectToBeClosed() { | ||||
| 	db, mock, err := New() | ||||
| 	if err != nil { | ||||
| @@ -260,6 +309,90 @@ func TestQuerySingleRow(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestQueryRowBytesInvalidatedByNext_bytesIntoRawBytes(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	replace := []byte(invalid) | ||||
| 	rows := NewRows([]string{"raw"}). | ||||
| 		AddRow([]byte(`one binary value with some text!`)). | ||||
| 		AddRow([]byte(`two binary value with even more text than the first one`)) | ||||
| 	scan := func(rs *sql.Rows) ([]byte, error) { | ||||
| 		var raw sql.RawBytes | ||||
| 		return raw, rs.Scan(&raw) | ||||
| 	} | ||||
| 	want := []struct { | ||||
| 		Initial  []byte | ||||
| 		Replaced []byte | ||||
| 	}{ | ||||
| 		{Initial: []byte(`one binary value with some text!`), Replaced: replace[:len(replace)-7]}, | ||||
| 		{Initial: []byte(`two binary value with even more text than the first one`), Replaced: bytes.Join([][]byte{replace, replace[:len(replace)-23]}, nil)}, | ||||
| 	} | ||||
| 	queryRowBytesInvalidatedByNext(t, rows, scan, want) | ||||
| } | ||||
|  | ||||
| func TestQueryRowBytesNotInvalidatedByNext_bytesIntoBytes(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	rows := NewRows([]string{"raw"}). | ||||
| 		AddRow([]byte(`one binary value with some text!`)). | ||||
| 		AddRow([]byte(`two binary value with even more text than the first one`)) | ||||
| 	scan := func(rs *sql.Rows) ([]byte, error) { | ||||
| 		var b []byte | ||||
| 		return b, rs.Scan(&b) | ||||
| 	} | ||||
| 	want := [][]byte{[]byte(`one binary value with some text!`), []byte(`two binary value with even more text than the first one`)} | ||||
| 	queryRowBytesNotInvalidatedByNext(t, rows, scan, want) | ||||
| } | ||||
|  | ||||
| func TestQueryRowBytesNotInvalidatedByNext_stringIntoBytes(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	rows := NewRows([]string{"raw"}). | ||||
| 		AddRow(`one binary value with some text!`). | ||||
| 		AddRow(`two binary value with even more text than the first one`) | ||||
| 	scan := func(rs *sql.Rows) ([]byte, error) { | ||||
| 		var b []byte | ||||
| 		return b, rs.Scan(&b) | ||||
| 	} | ||||
| 	want := [][]byte{[]byte(`one binary value with some text!`), []byte(`two binary value with even more text than the first one`)} | ||||
| 	queryRowBytesNotInvalidatedByNext(t, rows, scan, want) | ||||
| } | ||||
|  | ||||
| func TestQueryRowBytesInvalidatedByClose_bytesIntoRawBytes(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	replace := []byte(invalid) | ||||
| 	rows := NewRows([]string{"raw"}).AddRow([]byte(`one binary value with some text!`)) | ||||
| 	scan := func(rs *sql.Rows) ([]byte, error) { | ||||
| 		var raw sql.RawBytes | ||||
| 		return raw, rs.Scan(&raw) | ||||
| 	} | ||||
| 	want := struct { | ||||
| 		Initial  []byte | ||||
| 		Replaced []byte | ||||
| 	}{ | ||||
| 		Initial:  []byte(`one binary value with some text!`), | ||||
| 		Replaced: replace[:len(replace)-7], | ||||
| 	} | ||||
| 	queryRowBytesInvalidatedByClose(t, rows, scan, want) | ||||
| } | ||||
|  | ||||
| func TestQueryRowBytesNotInvalidatedByClose_bytesIntoBytes(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	rows := NewRows([]string{"raw"}).AddRow([]byte(`one binary value with some text!`)) | ||||
| 	scan := func(rs *sql.Rows) ([]byte, error) { | ||||
| 		var b []byte | ||||
| 		return b, rs.Scan(&b) | ||||
| 	} | ||||
| 	queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`one binary value with some text!`)) | ||||
| } | ||||
|  | ||||
| func TestQueryRowBytesNotInvalidatedByClose_stringIntoBytes(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	rows := NewRows([]string{"raw"}).AddRow(`one binary value with some text!`) | ||||
| 	scan := func(rs *sql.Rows) ([]byte, error) { | ||||
| 		var b []byte | ||||
| 		return b, rs.Scan(&b) | ||||
| 	} | ||||
| 	queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`one binary value with some text!`)) | ||||
| } | ||||
|  | ||||
| func TestRowsScanError(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	db, mock, err := New() | ||||
| @@ -363,3 +496,177 @@ func TestEmptyRowSets(t *testing.T) { | ||||
| 		t.Fatalf("expected rowset 3, to be empty, but it was not") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func queryRowBytesInvalidatedByNext(t *testing.T, rows *Rows, scan func(*sql.Rows) ([]byte, error), want []struct { | ||||
| 	Initial  []byte | ||||
| 	Replaced []byte | ||||
| }) { | ||||
| 	db, mock, err := New() | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) | ||||
| 	} | ||||
| 	defer db.Close() | ||||
| 	mock.ExpectQuery("SELECT").WillReturnRows(rows) | ||||
|  | ||||
| 	rs, err := db.Query("SELECT") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("failed to query rows: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	if !rs.Next() || rs.Err() != nil { | ||||
| 		t.Fatal("unexpected error on first row retrieval") | ||||
| 	} | ||||
| 	var count int | ||||
| 	for i := 0; ; i++ { | ||||
| 		count++ | ||||
| 		b, err := scan(rs) | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("unexpected error scanning row: %s", err) | ||||
| 		} | ||||
| 		if exp := want[i].Initial; !bytes.Equal(b, exp) { | ||||
| 			t.Fatalf("expected raw value to be '%s' (len:%d), but got [%T]:%s (len:%d)", exp, len(exp), b, b, len(b)) | ||||
| 		} | ||||
| 		next := rs.Next() | ||||
| 		if exp := want[i].Replaced; !bytes.Equal(b, exp) { | ||||
| 			t.Fatalf("expected raw value to be replaced with '%s' (len:%d) after calling Next(), but got [%T]:%s (len:%d)", exp, len(exp), b, b, len(b)) | ||||
| 		} | ||||
| 		if !next { | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
| 	if err := rs.Err(); err != nil { | ||||
| 		t.Fatalf("row iteration failed: %s", err) | ||||
| 	} | ||||
| 	if exp := len(want); count != exp { | ||||
| 		t.Fatalf("incorrect number of rows exp: %d, but got %d", exp, count) | ||||
| 	} | ||||
|  | ||||
| 	if err := mock.ExpectationsWereMet(); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func queryRowBytesNotInvalidatedByNext(t *testing.T, rows *Rows, scan func(*sql.Rows) ([]byte, error), want [][]byte) { | ||||
| 	db, mock, err := New() | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) | ||||
| 	} | ||||
| 	defer db.Close() | ||||
| 	mock.ExpectQuery("SELECT").WillReturnRows(rows) | ||||
|  | ||||
| 	rs, err := db.Query("SELECT") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("failed to query rows: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	if !rs.Next() || rs.Err() != nil { | ||||
| 		t.Fatal("unexpected error on first row retrieval") | ||||
| 	} | ||||
| 	var count int | ||||
| 	for i := 0; ; i++ { | ||||
| 		count++ | ||||
| 		b, err := scan(rs) | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("unexpected error scanning row: %s", err) | ||||
| 		} | ||||
| 		if exp := want[i]; !bytes.Equal(b, exp) { | ||||
| 			t.Fatalf("expected raw value to be '%s' (len:%d), but got [%T]:%s (len:%d)", exp, len(exp), b, b, len(b)) | ||||
| 		} | ||||
| 		next := rs.Next() | ||||
| 		if exp := want[i]; !bytes.Equal(b, exp) { | ||||
| 			t.Fatalf("expected raw value to be replaced with '%s' (len:%d) after calling Next(), but got [%T]:%s (len:%d)", exp, len(exp), b, b, len(b)) | ||||
| 		} | ||||
| 		if !next { | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
| 	if err := rs.Err(); err != nil { | ||||
| 		t.Fatalf("row iteration failed: %s", err) | ||||
| 	} | ||||
| 	if exp := len(want); count != exp { | ||||
| 		t.Fatalf("incorrect number of rows exp: %d, but got %d", exp, count) | ||||
| 	} | ||||
|  | ||||
| 	if err := mock.ExpectationsWereMet(); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func queryRowBytesInvalidatedByClose(t *testing.T, rows *Rows, scan func(*sql.Rows) ([]byte, error), want struct { | ||||
| 	Initial  []byte | ||||
| 	Replaced []byte | ||||
| }) { | ||||
| 	db, mock, err := New() | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) | ||||
| 	} | ||||
| 	defer db.Close() | ||||
| 	mock.ExpectQuery("SELECT").WillReturnRows(rows) | ||||
|  | ||||
| 	rs, err := db.Query("SELECT") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("failed to query rows: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	if !rs.Next() || rs.Err() != nil { | ||||
| 		t.Fatal("unexpected error on first row retrieval") | ||||
| 	} | ||||
| 	b, err := scan(rs) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("unexpected error scanning row: %s", err) | ||||
| 	} | ||||
| 	if !bytes.Equal(b, want.Initial) { | ||||
| 		t.Fatalf("expected raw value to be '%s' (len:%d), but got [%T]:%s (len:%d)", want.Initial, len(want.Initial), b, b, len(b)) | ||||
| 	} | ||||
| 	if err := rs.Close(); err != nil { | ||||
| 		t.Fatalf("unexpected error closing rows: %s", err) | ||||
| 	} | ||||
| 	if !bytes.Equal(b, want.Replaced) { | ||||
| 		t.Fatalf("expected raw value to be replaced with '%s' (len:%d) after calling Next(), but got [%T]:%s (len:%d)", want.Replaced, len(want.Replaced), b, b, len(b)) | ||||
| 	} | ||||
| 	if err := rs.Err(); err != nil { | ||||
| 		t.Fatalf("row iteration failed: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	if err := mock.ExpectationsWereMet(); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func queryRowBytesNotInvalidatedByClose(t *testing.T, rows *Rows, scan func(*sql.Rows) ([]byte, error), want []byte) { | ||||
| 	db, mock, err := New() | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) | ||||
| 	} | ||||
| 	defer db.Close() | ||||
| 	mock.ExpectQuery("SELECT").WillReturnRows(rows) | ||||
|  | ||||
| 	rs, err := db.Query("SELECT") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("failed to query rows: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	if !rs.Next() || rs.Err() != nil { | ||||
| 		t.Fatal("unexpected error on first row retrieval") | ||||
| 	} | ||||
| 	b, err := scan(rs) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("unexpected error scanning row: %s", err) | ||||
| 	} | ||||
| 	if !bytes.Equal(b, want) { | ||||
| 		t.Fatalf("expected raw value to be '%s' (len:%d), but got [%T]:%s (len:%d)", want, len(want), b, b, len(b)) | ||||
| 	} | ||||
| 	if err := rs.Close(); err != nil { | ||||
| 		t.Fatalf("unexpected error closing rows: %s", err) | ||||
| 	} | ||||
| 	if !bytes.Equal(b, want) { | ||||
| 		t.Fatalf("expected raw value to be replaced with '%s' (len:%d) after calling Next(), but got [%T]:%s (len:%d)", want, len(want), b, b, len(b)) | ||||
| 	} | ||||
| 	if err := rs.Err(); err != nil { | ||||
| 		t.Fatalf("row iteration failed: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	if err := mock.ExpectationsWereMet(); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user