mirror of
https://github.com/DATA-DOG/go-sqlmock.git
synced 2025-04-04 21:54:20 +02:00
implements next rows result set support
This commit is contained in:
parent
42ab7c33d0
commit
128bf5c539
@ -144,13 +144,6 @@ func (e *ExpectedQuery) WillReturnError(err error) *ExpectedQuery {
|
|||||||
return e
|
return e
|
||||||
}
|
}
|
||||||
|
|
||||||
// WillReturnRows specifies the set of resulting rows that will be returned
|
|
||||||
// by the triggered query
|
|
||||||
func (e *ExpectedQuery) WillReturnRows(rows driver.Rows) *ExpectedQuery {
|
|
||||||
e.rows = rows
|
|
||||||
return e
|
|
||||||
}
|
|
||||||
|
|
||||||
// WillDelayFor allows to specify duration for which it will delay
|
// WillDelayFor allows to specify duration for which it will delay
|
||||||
// result. May be used together with Context
|
// result. May be used together with Context
|
||||||
func (e *ExpectedQuery) WillDelayFor(duration time.Duration) *ExpectedQuery {
|
func (e *ExpectedQuery) WillDelayFor(duration time.Duration) *ExpectedQuery {
|
||||||
@ -175,9 +168,11 @@ func (e *ExpectedQuery) String() string {
|
|||||||
|
|
||||||
if e.rows != nil {
|
if e.rows != nil {
|
||||||
msg += "\n - should return rows:\n"
|
msg += "\n - should return rows:\n"
|
||||||
rs, _ := e.rows.(*rows)
|
rs, _ := e.rows.(*rowSets)
|
||||||
for i, row := range rs.rows {
|
for _, set := range rs.sets {
|
||||||
msg += fmt.Sprintf(" %d - %+v\n", i, row)
|
for i, row := range set.rows {
|
||||||
|
msg += fmt.Sprintf(" %d - %+v\n", i, row)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
msg = strings.TrimSpace(msg)
|
msg = strings.TrimSpace(msg)
|
||||||
}
|
}
|
||||||
|
@ -8,6 +8,13 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// WillReturnRows specifies the set of resulting rows that will be returned
|
||||||
|
// by the triggered query
|
||||||
|
func (e *ExpectedQuery) WillReturnRows(rows *Rows) *ExpectedQuery {
|
||||||
|
e.rows = &rowSets{sets: []*Rows{rows}}
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
func (e *queryBasedExpectation) argsMatches(args []namedValue) error {
|
func (e *queryBasedExpectation) argsMatches(args []namedValue) error {
|
||||||
if nil == e.args {
|
if nil == e.args {
|
||||||
return nil
|
return nil
|
@ -9,6 +9,17 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// WillReturnRows specifies the set of resulting rows that will be returned
|
||||||
|
// by the triggered query
|
||||||
|
func (e *ExpectedQuery) WillReturnRows(rows ...*Rows) *ExpectedQuery {
|
||||||
|
sets := make([]*Rows, len(rows))
|
||||||
|
for i, r := range rows {
|
||||||
|
sets[i] = r
|
||||||
|
}
|
||||||
|
e.rows = &rowSets{sets: sets}
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
func (e *queryBasedExpectation) argsMatches(args []namedValue) error {
|
func (e *queryBasedExpectation) argsMatches(args []namedValue) error {
|
||||||
if nil == e.args {
|
if nil == e.args {
|
||||||
return nil
|
return nil
|
93
rows.go
93
rows.go
@ -18,57 +18,22 @@ var CSVColumnParser = func(s string) []byte {
|
|||||||
return []byte(s)
|
return []byte(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rows interface allows to construct rows
|
type rowSets struct {
|
||||||
// which also satisfies database/sql/driver.Rows interface
|
sets []*Rows
|
||||||
type Rows interface {
|
pos int
|
||||||
// composed interface, supports sql driver.Rows
|
|
||||||
driver.Rows
|
|
||||||
|
|
||||||
// AddRow composed from database driver.Value slice
|
|
||||||
// return the same instance to perform subsequent actions.
|
|
||||||
// Note that the number of values must match the number
|
|
||||||
// of columns
|
|
||||||
AddRow(columns ...driver.Value) Rows
|
|
||||||
|
|
||||||
// FromCSVString build rows from csv string.
|
|
||||||
// return the same instance to perform subsequent actions.
|
|
||||||
// Note that the number of values must match the number
|
|
||||||
// of columns
|
|
||||||
FromCSVString(s string) Rows
|
|
||||||
|
|
||||||
// RowError allows to set an error
|
|
||||||
// which will be returned when a given
|
|
||||||
// row number is read
|
|
||||||
RowError(row int, err error) Rows
|
|
||||||
|
|
||||||
// CloseError allows to set an error
|
|
||||||
// which will be returned by rows.Close
|
|
||||||
// function.
|
|
||||||
//
|
|
||||||
// The close error will be triggered only in cases
|
|
||||||
// when rows.Next() EOF was not yet reached, that is
|
|
||||||
// a default sql library behavior
|
|
||||||
CloseError(err error) Rows
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type rows struct {
|
func (rs *rowSets) Columns() []string {
|
||||||
cols []string
|
return rs.sets[rs.pos].cols
|
||||||
rows [][]driver.Value
|
|
||||||
pos int
|
|
||||||
nextErr map[int]error
|
|
||||||
closeErr error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *rows) Columns() []string {
|
func (rs *rowSets) Close() error {
|
||||||
return r.cols
|
return rs.sets[rs.pos].closeErr
|
||||||
}
|
|
||||||
|
|
||||||
func (r *rows) Close() error {
|
|
||||||
return r.closeErr
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// advances to next row
|
// advances to next row
|
||||||
func (r *rows) Next(dest []driver.Value) error {
|
func (rs *rowSets) Next(dest []driver.Value) error {
|
||||||
|
r := rs.sets[rs.pos]
|
||||||
r.pos++
|
r.pos++
|
||||||
if r.pos > len(r.rows) {
|
if r.pos > len(r.rows) {
|
||||||
return io.EOF // per interface spec
|
return io.EOF // per interface spec
|
||||||
@ -81,24 +46,48 @@ func (r *rows) Next(dest []driver.Value) error {
|
|||||||
return r.nextErr[r.pos-1]
|
return r.nextErr[r.pos-1]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Rows is a mocked collection of rows to
|
||||||
|
// return for Query result
|
||||||
|
type Rows struct {
|
||||||
|
cols []string
|
||||||
|
rows [][]driver.Value
|
||||||
|
pos int
|
||||||
|
nextErr map[int]error
|
||||||
|
closeErr error
|
||||||
|
}
|
||||||
|
|
||||||
// NewRows allows Rows to be created from a
|
// NewRows allows Rows to be created from a
|
||||||
// sql driver.Value slice or from the CSV string and
|
// sql driver.Value slice or from the CSV string and
|
||||||
// to be used as sql driver.Rows
|
// to be used as sql driver.Rows
|
||||||
func NewRows(columns []string) Rows {
|
func NewRows(columns []string) *Rows {
|
||||||
return &rows{cols: columns, nextErr: make(map[int]error)}
|
return &Rows{cols: columns, nextErr: make(map[int]error)}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *rows) CloseError(err error) Rows {
|
// CloseError allows to set an error
|
||||||
|
// which will be returned by rows.Close
|
||||||
|
// function.
|
||||||
|
//
|
||||||
|
// The close error will be triggered only in cases
|
||||||
|
// when rows.Next() EOF was not yet reached, that is
|
||||||
|
// a default sql library behavior
|
||||||
|
func (r *Rows) CloseError(err error) *Rows {
|
||||||
r.closeErr = err
|
r.closeErr = err
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *rows) RowError(row int, err error) Rows {
|
// RowError allows to set an error
|
||||||
|
// which will be returned when a given
|
||||||
|
// row number is read
|
||||||
|
func (r *Rows) RowError(row int, err error) *Rows {
|
||||||
r.nextErr[row] = err
|
r.nextErr[row] = err
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *rows) AddRow(values ...driver.Value) Rows {
|
// AddRow composed from database driver.Value slice
|
||||||
|
// return the same instance to perform subsequent actions.
|
||||||
|
// Note that the number of values must match the number
|
||||||
|
// of columns
|
||||||
|
func (r *Rows) AddRow(values ...driver.Value) *Rows {
|
||||||
if len(values) != len(r.cols) {
|
if len(values) != len(r.cols) {
|
||||||
panic("Expected number of values to match number of columns")
|
panic("Expected number of values to match number of columns")
|
||||||
}
|
}
|
||||||
@ -112,7 +101,11 @@ func (r *rows) AddRow(values ...driver.Value) Rows {
|
|||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *rows) FromCSVString(s string) Rows {
|
// FromCSVString build rows from csv string.
|
||||||
|
// return the same instance to perform subsequent actions.
|
||||||
|
// Note that the number of values must match the number
|
||||||
|
// of columns
|
||||||
|
func (r *Rows) FromCSVString(s string) *Rows {
|
||||||
res := strings.NewReader(strings.TrimSpace(s))
|
res := strings.NewReader(strings.TrimSpace(s))
|
||||||
csvReader := csv.NewReader(res)
|
csvReader := csv.NewReader(res)
|
||||||
|
|
||||||
|
20
rows_go18.go
Normal file
20
rows_go18.go
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
// +build go1.8
|
||||||
|
|
||||||
|
package sqlmock
|
||||||
|
|
||||||
|
import "io"
|
||||||
|
|
||||||
|
// Implement the "RowsNextResultSet" interface
|
||||||
|
func (rs *rowSets) HasNextResultSet() bool {
|
||||||
|
return rs.pos+1 < len(rs.sets)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement the "RowsNextResultSet" interface
|
||||||
|
func (rs *rowSets) NextResultSet() error {
|
||||||
|
if !rs.HasNextResultSet() {
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
rs.pos++
|
||||||
|
return nil
|
||||||
|
}
|
92
rows_go18_test.go
Normal file
92
rows_go18_test.go
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
// +build go1.8
|
||||||
|
|
||||||
|
package sqlmock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestQueryMultiRows(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
db, mock, err := New()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
rs1 := NewRows([]string{"id", "title"}).AddRow(5, "hello world")
|
||||||
|
rs2 := NewRows([]string{"name"}).AddRow("gopher").AddRow("john").AddRow("jane").RowError(2, fmt.Errorf("error"))
|
||||||
|
|
||||||
|
mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = \\?;SELECT name FROM users").
|
||||||
|
WithArgs(5).
|
||||||
|
WillReturnRows(rs1, rs2)
|
||||||
|
|
||||||
|
rows, err := db.Query("SELECT id, title FROM articles WHERE id = ?;SELECT name FROM users", 5)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error was not expected, but got: %v", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
if !rows.Next() {
|
||||||
|
t.Error("expected a row to be available in first result set")
|
||||||
|
}
|
||||||
|
|
||||||
|
var id int
|
||||||
|
var name string
|
||||||
|
|
||||||
|
err = rows.Scan(&id, &name)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error was not expected, but got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if id != 5 || name != "hello world" {
|
||||||
|
t.Errorf("unexpected row values id: %v name: %v", id, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows.Next() {
|
||||||
|
t.Error("was not expecting next row in first result set")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !rows.NextResultSet() {
|
||||||
|
t.Error("had to have next result set")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !rows.Next() {
|
||||||
|
t.Error("expected a row to be available in second result set")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = rows.Scan(&name)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error was not expected, but got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if name != "gopher" {
|
||||||
|
t.Errorf("unexpected row name: %v", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !rows.Next() {
|
||||||
|
t.Error("expected a row to be available in second result set")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = rows.Scan(&name)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error was not expected, but got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if name != "john" {
|
||||||
|
t.Errorf("unexpected row name: %v", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows.Next() {
|
||||||
|
t.Error("expected next row to produce error")
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows.Err() == nil {
|
||||||
|
t.Error("expected an error, but there was none")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Errorf("there were unfulfilled expections: %s", err)
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user