mirror of
https://github.com/DATA-DOG/go-sqlmock.git
synced 2025-03-27 21:28:51 +02:00
implements next rows result set support
This commit is contained in:
parent
42ab7c33d0
commit
128bf5c539
@ -144,13 +144,6 @@ func (e *ExpectedQuery) WillReturnError(err error) *ExpectedQuery {
|
||||
return e
|
||||
}
|
||||
|
||||
// WillReturnRows specifies the set of resulting rows that will be returned
|
||||
// by the triggered query
|
||||
func (e *ExpectedQuery) WillReturnRows(rows driver.Rows) *ExpectedQuery {
|
||||
e.rows = rows
|
||||
return e
|
||||
}
|
||||
|
||||
// WillDelayFor allows to specify duration for which it will delay
|
||||
// result. May be used together with Context
|
||||
func (e *ExpectedQuery) WillDelayFor(duration time.Duration) *ExpectedQuery {
|
||||
@ -175,9 +168,11 @@ func (e *ExpectedQuery) String() string {
|
||||
|
||||
if e.rows != nil {
|
||||
msg += "\n - should return rows:\n"
|
||||
rs, _ := e.rows.(*rows)
|
||||
for i, row := range rs.rows {
|
||||
msg += fmt.Sprintf(" %d - %+v\n", i, row)
|
||||
rs, _ := e.rows.(*rowSets)
|
||||
for _, set := range rs.sets {
|
||||
for i, row := range set.rows {
|
||||
msg += fmt.Sprintf(" %d - %+v\n", i, row)
|
||||
}
|
||||
}
|
||||
msg = strings.TrimSpace(msg)
|
||||
}
|
||||
|
@ -8,6 +8,13 @@ import (
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// WillReturnRows specifies the set of resulting rows that will be returned
|
||||
// by the triggered query
|
||||
func (e *ExpectedQuery) WillReturnRows(rows *Rows) *ExpectedQuery {
|
||||
e.rows = &rowSets{sets: []*Rows{rows}}
|
||||
return e
|
||||
}
|
||||
|
||||
func (e *queryBasedExpectation) argsMatches(args []namedValue) error {
|
||||
if nil == e.args {
|
||||
return nil
|
@ -9,6 +9,17 @@ import (
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// WillReturnRows specifies the set of resulting rows that will be returned
|
||||
// by the triggered query
|
||||
func (e *ExpectedQuery) WillReturnRows(rows ...*Rows) *ExpectedQuery {
|
||||
sets := make([]*Rows, len(rows))
|
||||
for i, r := range rows {
|
||||
sets[i] = r
|
||||
}
|
||||
e.rows = &rowSets{sets: sets}
|
||||
return e
|
||||
}
|
||||
|
||||
func (e *queryBasedExpectation) argsMatches(args []namedValue) error {
|
||||
if nil == e.args {
|
||||
return nil
|
93
rows.go
93
rows.go
@ -18,57 +18,22 @@ var CSVColumnParser = func(s string) []byte {
|
||||
return []byte(s)
|
||||
}
|
||||
|
||||
// Rows interface allows to construct rows
|
||||
// which also satisfies database/sql/driver.Rows interface
|
||||
type Rows interface {
|
||||
// composed interface, supports sql driver.Rows
|
||||
driver.Rows
|
||||
|
||||
// AddRow composed from database driver.Value slice
|
||||
// return the same instance to perform subsequent actions.
|
||||
// Note that the number of values must match the number
|
||||
// of columns
|
||||
AddRow(columns ...driver.Value) Rows
|
||||
|
||||
// FromCSVString build rows from csv string.
|
||||
// return the same instance to perform subsequent actions.
|
||||
// Note that the number of values must match the number
|
||||
// of columns
|
||||
FromCSVString(s string) Rows
|
||||
|
||||
// RowError allows to set an error
|
||||
// which will be returned when a given
|
||||
// row number is read
|
||||
RowError(row int, err error) Rows
|
||||
|
||||
// CloseError allows to set an error
|
||||
// which will be returned by rows.Close
|
||||
// function.
|
||||
//
|
||||
// The close error will be triggered only in cases
|
||||
// when rows.Next() EOF was not yet reached, that is
|
||||
// a default sql library behavior
|
||||
CloseError(err error) Rows
|
||||
type rowSets struct {
|
||||
sets []*Rows
|
||||
pos int
|
||||
}
|
||||
|
||||
type rows struct {
|
||||
cols []string
|
||||
rows [][]driver.Value
|
||||
pos int
|
||||
nextErr map[int]error
|
||||
closeErr error
|
||||
func (rs *rowSets) Columns() []string {
|
||||
return rs.sets[rs.pos].cols
|
||||
}
|
||||
|
||||
func (r *rows) Columns() []string {
|
||||
return r.cols
|
||||
}
|
||||
|
||||
func (r *rows) Close() error {
|
||||
return r.closeErr
|
||||
func (rs *rowSets) Close() error {
|
||||
return rs.sets[rs.pos].closeErr
|
||||
}
|
||||
|
||||
// advances to next row
|
||||
func (r *rows) Next(dest []driver.Value) error {
|
||||
func (rs *rowSets) Next(dest []driver.Value) error {
|
||||
r := rs.sets[rs.pos]
|
||||
r.pos++
|
||||
if r.pos > len(r.rows) {
|
||||
return io.EOF // per interface spec
|
||||
@ -81,24 +46,48 @@ func (r *rows) Next(dest []driver.Value) error {
|
||||
return r.nextErr[r.pos-1]
|
||||
}
|
||||
|
||||
// Rows is a mocked collection of rows to
|
||||
// return for Query result
|
||||
type Rows struct {
|
||||
cols []string
|
||||
rows [][]driver.Value
|
||||
pos int
|
||||
nextErr map[int]error
|
||||
closeErr error
|
||||
}
|
||||
|
||||
// NewRows allows Rows to be created from a
|
||||
// sql driver.Value slice or from the CSV string and
|
||||
// to be used as sql driver.Rows
|
||||
func NewRows(columns []string) Rows {
|
||||
return &rows{cols: columns, nextErr: make(map[int]error)}
|
||||
func NewRows(columns []string) *Rows {
|
||||
return &Rows{cols: columns, nextErr: make(map[int]error)}
|
||||
}
|
||||
|
||||
func (r *rows) CloseError(err error) Rows {
|
||||
// CloseError allows to set an error
|
||||
// which will be returned by rows.Close
|
||||
// function.
|
||||
//
|
||||
// The close error will be triggered only in cases
|
||||
// when rows.Next() EOF was not yet reached, that is
|
||||
// a default sql library behavior
|
||||
func (r *Rows) CloseError(err error) *Rows {
|
||||
r.closeErr = err
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *rows) RowError(row int, err error) Rows {
|
||||
// RowError allows to set an error
|
||||
// which will be returned when a given
|
||||
// row number is read
|
||||
func (r *Rows) RowError(row int, err error) *Rows {
|
||||
r.nextErr[row] = err
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *rows) AddRow(values ...driver.Value) Rows {
|
||||
// AddRow composed from database driver.Value slice
|
||||
// return the same instance to perform subsequent actions.
|
||||
// Note that the number of values must match the number
|
||||
// of columns
|
||||
func (r *Rows) AddRow(values ...driver.Value) *Rows {
|
||||
if len(values) != len(r.cols) {
|
||||
panic("Expected number of values to match number of columns")
|
||||
}
|
||||
@ -112,7 +101,11 @@ func (r *rows) AddRow(values ...driver.Value) Rows {
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *rows) FromCSVString(s string) Rows {
|
||||
// FromCSVString build rows from csv string.
|
||||
// return the same instance to perform subsequent actions.
|
||||
// Note that the number of values must match the number
|
||||
// of columns
|
||||
func (r *Rows) FromCSVString(s string) *Rows {
|
||||
res := strings.NewReader(strings.TrimSpace(s))
|
||||
csvReader := csv.NewReader(res)
|
||||
|
||||
|
20
rows_go18.go
Normal file
20
rows_go18.go
Normal file
@ -0,0 +1,20 @@
|
||||
// +build go1.8
|
||||
|
||||
package sqlmock
|
||||
|
||||
import "io"
|
||||
|
||||
// Implement the "RowsNextResultSet" interface
|
||||
func (rs *rowSets) HasNextResultSet() bool {
|
||||
return rs.pos+1 < len(rs.sets)
|
||||
}
|
||||
|
||||
// Implement the "RowsNextResultSet" interface
|
||||
func (rs *rowSets) NextResultSet() error {
|
||||
if !rs.HasNextResultSet() {
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
rs.pos++
|
||||
return nil
|
||||
}
|
92
rows_go18_test.go
Normal file
92
rows_go18_test.go
Normal file
@ -0,0 +1,92 @@
|
||||
// +build go1.8
|
||||
|
||||
package sqlmock
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestQueryMultiRows(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
rs1 := NewRows([]string{"id", "title"}).AddRow(5, "hello world")
|
||||
rs2 := NewRows([]string{"name"}).AddRow("gopher").AddRow("john").AddRow("jane").RowError(2, fmt.Errorf("error"))
|
||||
|
||||
mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = \\?;SELECT name FROM users").
|
||||
WithArgs(5).
|
||||
WillReturnRows(rs1, rs2)
|
||||
|
||||
rows, err := db.Query("SELECT id, title FROM articles WHERE id = ?;SELECT name FROM users", 5)
|
||||
if err != nil {
|
||||
t.Errorf("error was not expected, but got: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if !rows.Next() {
|
||||
t.Error("expected a row to be available in first result set")
|
||||
}
|
||||
|
||||
var id int
|
||||
var name string
|
||||
|
||||
err = rows.Scan(&id, &name)
|
||||
if err != nil {
|
||||
t.Errorf("error was not expected, but got: %v", err)
|
||||
}
|
||||
|
||||
if id != 5 || name != "hello world" {
|
||||
t.Errorf("unexpected row values id: %v name: %v", id, name)
|
||||
}
|
||||
|
||||
if rows.Next() {
|
||||
t.Error("was not expecting next row in first result set")
|
||||
}
|
||||
|
||||
if !rows.NextResultSet() {
|
||||
t.Error("had to have next result set")
|
||||
}
|
||||
|
||||
if !rows.Next() {
|
||||
t.Error("expected a row to be available in second result set")
|
||||
}
|
||||
|
||||
err = rows.Scan(&name)
|
||||
if err != nil {
|
||||
t.Errorf("error was not expected, but got: %v", err)
|
||||
}
|
||||
|
||||
if name != "gopher" {
|
||||
t.Errorf("unexpected row name: %v", name)
|
||||
}
|
||||
|
||||
if !rows.Next() {
|
||||
t.Error("expected a row to be available in second result set")
|
||||
}
|
||||
|
||||
err = rows.Scan(&name)
|
||||
if err != nil {
|
||||
t.Errorf("error was not expected, but got: %v", err)
|
||||
}
|
||||
|
||||
if name != "john" {
|
||||
t.Errorf("unexpected row name: %v", name)
|
||||
}
|
||||
|
||||
if rows.Next() {
|
||||
t.Error("expected next row to produce error")
|
||||
}
|
||||
|
||||
if rows.Err() == nil {
|
||||
t.Error("expected an error, but there was none")
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("there were unfulfilled expections: %s", err)
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user