1
0
mirror of https://github.com/DATA-DOG/go-sqlmock.git synced 2025-05-13 21:56:39 +02:00

Merge pull request from DATA-DOG/go1.8

Support for go 1.8 SQL features
This commit is contained in:
Gediminas Morkevicius 2017-02-21 17:59:18 +02:00 committed by GitHub
commit b983233bc0
14 changed files with 1030 additions and 125 deletions

@ -7,6 +7,7 @@ go:
- 1.5 - 1.5
- 1.6 - 1.6
- 1.7 - 1.7
- 1.8
- tip - tip
script: go test -race -coverprofile=coverage.txt -covermode=atomic script: go test -race -coverprofile=coverage.txt -covermode=atomic

@ -1,6 +1,6 @@
The three clause BSD license (http://en.wikipedia.org/wiki/BSD_licenses) The three clause BSD license (http://en.wikipedia.org/wiki/BSD_licenses)
Copyright (c) 2013-2016, DATA-DOG team Copyright (c) 2013-2017, DATA-DOG team
All rights reserved. All rights reserved.
Redistribution and use in source and binary forms, with or without Redistribution and use in source and binary forms, with or without

@ -10,19 +10,16 @@ maintain correct **TDD** workflow.
- this library is now complete and stable. (you may not find new changes for this reason) - this library is now complete and stable. (you may not find new changes for this reason)
- supports concurrency and multiple connections. - supports concurrency and multiple connections.
- supports **go1.8** Context related feature mocking and Named sql parameters.
- does not require any modifications to your source code. - does not require any modifications to your source code.
- the driver allows to mock any sql driver method behavior. - the driver allows to mock any sql driver method behavior.
- has strict by default expectation order matching. - has strict by default expectation order matching.
- has no vendor dependencies. - has no third party dependencies.
## Install ## Install
go get gopkg.in/DATA-DOG/go-sqlmock.v1 go get gopkg.in/DATA-DOG/go-sqlmock.v1
If you need an old version, checkout **go-sqlmock** at gopkg.in:
go get gopkg.in/DATA-DOG/go-sqlmock.v0
## Documentation and Examples ## Documentation and Examples
Visit [godoc](http://godoc.org/github.com/DATA-DOG/go-sqlmock) for general examples and public api reference. Visit [godoc](http://godoc.org/github.com/DATA-DOG/go-sqlmock) for general examples and public api reference.
@ -187,8 +184,11 @@ It only asserts that argument is of `time.Time` type.
go test -race go test -race
## Changes ## Change Log
- **2017-02-09** - implemented support for **go1.8** features. **Rows** interface was changed to struct
but contains all methods as before and should maintain backwards compatibility. **ExpectedQuery.WillReturnRows** may now
accept multiple row sets.
- **2016-11-02** - `db.Prepare()` was not validating expected prepare SQL - **2016-11-02** - `db.Prepare()` was not validating expected prepare SQL
query. It should still be validated even if Exec or Query is not query. It should still be validated even if Exec or Query is not
executed on that prepared statement. executed on that prepared statement.

@ -3,10 +3,10 @@ package sqlmock
import ( import (
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"reflect"
"regexp" "regexp"
"strings" "strings"
"sync" "sync"
"time"
) )
// an expectation interface // an expectation interface
@ -54,6 +54,7 @@ func (e *ExpectedClose) String() string {
// returned by *Sqlmock.ExpectBegin. // returned by *Sqlmock.ExpectBegin.
type ExpectedBegin struct { type ExpectedBegin struct {
commonExpectation commonExpectation
delay time.Duration
} }
// WillReturnError allows to set an error for *sql.DB.Begin action // WillReturnError allows to set an error for *sql.DB.Begin action
@ -71,6 +72,13 @@ func (e *ExpectedBegin) String() string {
return msg return msg
} }
// WillDelayFor allows to specify duration for which it will delay
// result. May be used together with Context
func (e *ExpectedBegin) WillDelayFor(duration time.Duration) *ExpectedBegin {
e.delay = duration
return e
}
// ExpectedCommit is used to manage *sql.Tx.Commit expectation // ExpectedCommit is used to manage *sql.Tx.Commit expectation
// returned by *Sqlmock.ExpectCommit. // returned by *Sqlmock.ExpectCommit.
type ExpectedCommit struct { type ExpectedCommit struct {
@ -118,7 +126,8 @@ func (e *ExpectedRollback) String() string {
// Returned by *Sqlmock.ExpectQuery. // Returned by *Sqlmock.ExpectQuery.
type ExpectedQuery struct { type ExpectedQuery struct {
queryBasedExpectation queryBasedExpectation
rows driver.Rows rows driver.Rows
delay time.Duration
} }
// WithArgs will match given expected args to actual database query arguments. // WithArgs will match given expected args to actual database query arguments.
@ -135,10 +144,10 @@ func (e *ExpectedQuery) WillReturnError(err error) *ExpectedQuery {
return e return e
} }
// WillReturnRows specifies the set of resulting rows that will be returned // WillDelayFor allows to specify duration for which it will delay
// by the triggered query // result. May be used together with Context
func (e *ExpectedQuery) WillReturnRows(rows driver.Rows) *ExpectedQuery { func (e *ExpectedQuery) WillDelayFor(duration time.Duration) *ExpectedQuery {
e.rows = rows e.delay = duration
return e return e
} }
@ -158,12 +167,7 @@ func (e *ExpectedQuery) String() string {
} }
if e.rows != nil { if e.rows != nil {
msg += "\n - should return rows:\n" msg += fmt.Sprintf("\n - %s", e.rows)
rs, _ := e.rows.(*rows)
for i, row := range rs.rows {
msg += fmt.Sprintf(" %d - %+v\n", i, row)
}
msg = strings.TrimSpace(msg)
} }
if e.err != nil { if e.err != nil {
@ -178,6 +182,7 @@ func (e *ExpectedQuery) String() string {
type ExpectedExec struct { type ExpectedExec struct {
queryBasedExpectation queryBasedExpectation
result driver.Result result driver.Result
delay time.Duration
} }
// WithArgs will match given expected args to actual database exec operation arguments. // WithArgs will match given expected args to actual database exec operation arguments.
@ -194,6 +199,13 @@ func (e *ExpectedExec) WillReturnError(err error) *ExpectedExec {
return e return e
} }
// WillDelayFor allows to specify duration for which it will delay
// result. May be used together with Context
func (e *ExpectedExec) WillDelayFor(duration time.Duration) *ExpectedExec {
e.delay = duration
return e
}
// String returns string representation // String returns string representation
func (e *ExpectedExec) String() string { func (e *ExpectedExec) String() string {
msg := "ExpectedExec => expecting Exec which:" msg := "ExpectedExec => expecting Exec which:"
@ -244,6 +256,7 @@ type ExpectedPrepare struct {
sqlRegex *regexp.Regexp sqlRegex *regexp.Regexp
statement driver.Stmt statement driver.Stmt
closeErr error closeErr error
delay time.Duration
} }
// WillReturnError allows to set an error for the expected *sql.DB.Prepare or *sql.Tx.Prepare action. // WillReturnError allows to set an error for the expected *sql.DB.Prepare or *sql.Tx.Prepare action.
@ -258,6 +271,13 @@ func (e *ExpectedPrepare) WillReturnCloseError(err error) *ExpectedPrepare {
return e return e
} }
// WillDelayFor allows to specify duration for which it will delay
// result. May be used together with Context
func (e *ExpectedPrepare) WillDelayFor(duration time.Duration) *ExpectedPrepare {
e.delay = duration
return e
}
// ExpectQuery allows to expect Query() or QueryRow() on this prepared statement. // ExpectQuery allows to expect Query() or QueryRow() on this prepared statement.
// this method is convenient in order to prevent duplicating sql query string matching. // this method is convenient in order to prevent duplicating sql query string matching.
func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery { func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery {
@ -300,7 +320,7 @@ type queryBasedExpectation struct {
args []driver.Value args []driver.Value
} }
func (e *queryBasedExpectation) attemptMatch(sql string, args []driver.Value) (err error) { func (e *queryBasedExpectation) attemptMatch(sql string, args []namedValue) (err error) {
if !e.queryMatches(sql) { if !e.queryMatches(sql) {
return fmt.Errorf(`could not match sql: "%s" with expected regexp "%s"`, sql, e.sqlRegex.String()) return fmt.Errorf(`could not match sql: "%s" with expected regexp "%s"`, sql, e.sqlRegex.String())
} }
@ -322,37 +342,3 @@ func (e *queryBasedExpectation) attemptMatch(sql string, args []driver.Value) (e
func (e *queryBasedExpectation) queryMatches(sql string) bool { func (e *queryBasedExpectation) queryMatches(sql string) bool {
return e.sqlRegex.MatchString(sql) return e.sqlRegex.MatchString(sql)
} }
func (e *queryBasedExpectation) argsMatches(args []driver.Value) error {
if nil == e.args {
return nil
}
if len(args) != len(e.args) {
return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args))
}
for k, v := range args {
// custom argument matcher
matcher, ok := e.args[k].(Argument)
if ok {
if !matcher.Match(v) {
return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k])
}
continue
}
// convert to driver converter
darg, err := driver.DefaultParameterConverter.ConvertValue(e.args[k])
if err != nil {
return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err)
}
if !driver.IsValue(darg) {
return fmt.Errorf("argument %d: non-subset type %T returned from Value", k, darg)
}
if !reflect.DeepEqual(darg, args[k]) {
return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, args[k], args[k])
}
}
return nil
}

@ -0,0 +1,52 @@
// +build !go1.8
package sqlmock
import (
"database/sql/driver"
"fmt"
"reflect"
)
// WillReturnRows specifies the set of resulting rows that will be returned
// by the triggered query
func (e *ExpectedQuery) WillReturnRows(rows *Rows) *ExpectedQuery {
e.rows = &rowSets{sets: []*Rows{rows}}
return e
}
func (e *queryBasedExpectation) argsMatches(args []namedValue) error {
if nil == e.args {
return nil
}
if len(args) != len(e.args) {
return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args))
}
for k, v := range args {
// custom argument matcher
matcher, ok := e.args[k].(Argument)
if ok {
// @TODO: does it make sense to pass value instead of named value?
if !matcher.Match(v.Value) {
return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k])
}
continue
}
dval := e.args[k]
// convert to driver converter
darg, err := driver.DefaultParameterConverter.ConvertValue(dval)
if err != nil {
return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err)
}
if !driver.IsValue(darg) {
return fmt.Errorf("argument %d: non-subset type %T returned from Value", k, darg)
}
if !reflect.DeepEqual(darg, v.Value) {
return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v.Value, v.Value)
}
}
return nil
}

66
expectations_go18.go Normal file

@ -0,0 +1,66 @@
// +build go1.8
package sqlmock
import (
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
)
// WillReturnRows specifies the set of resulting rows that will be returned
// by the triggered query
func (e *ExpectedQuery) WillReturnRows(rows ...*Rows) *ExpectedQuery {
sets := make([]*Rows, len(rows))
for i, r := range rows {
sets[i] = r
}
e.rows = &rowSets{sets: sets}
return e
}
func (e *queryBasedExpectation) argsMatches(args []namedValue) error {
if nil == e.args {
return nil
}
if len(args) != len(e.args) {
return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args))
}
// @TODO should we assert either all args are named or ordinal?
for k, v := range args {
// custom argument matcher
matcher, ok := e.args[k].(Argument)
if ok {
if !matcher.Match(v.Value) {
return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k])
}
continue
}
dval := e.args[k]
if named, isNamed := dval.(sql.NamedArg); isNamed {
dval = named.Value
if v.Name != named.Name {
return fmt.Errorf("named argument %d: name: \"%s\" does not match expected: \"%s\"", k, v.Name, named.Name)
}
} else if k+1 != v.Ordinal {
return fmt.Errorf("argument %d: ordinal position: %d does not match expected: %d", k, k+1, v.Ordinal)
}
// convert to driver converter
darg, err := driver.DefaultParameterConverter.ConvertValue(dval)
if err != nil {
return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err)
}
if !driver.IsValue(darg) {
return fmt.Errorf("argument %d: non-subset type %T returned from Value", k, darg)
}
if !reflect.DeepEqual(darg, v.Value) {
return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v.Value, v.Value)
}
}
return nil
}

64
expectations_go18_test.go Normal file

@ -0,0 +1,64 @@
// +build go1.8
package sqlmock
import (
"database/sql"
"database/sql/driver"
"testing"
)
func TestQueryExpectationNamedArgComparison(t *testing.T) {
e := &queryBasedExpectation{}
against := []namedValue{{Value: int64(5), Name: "id"}}
if err := e.argsMatches(against); err != nil {
t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err)
}
e.args = []driver.Value{
sql.Named("id", 5),
sql.Named("s", "str"),
}
if err := e.argsMatches(against); err == nil {
t.Error("arguments should not match, since the size is not the same")
}
against = []namedValue{
{Value: int64(5), Name: "id"},
{Value: "str", Name: "s"},
}
if err := e.argsMatches(against); err != nil {
t.Errorf("arguments should have matched, but it did not: %v", err)
}
against = []namedValue{
{Value: int64(5), Name: "id"},
{Value: "str", Name: "username"},
}
if err := e.argsMatches(against); err == nil {
t.Error("arguments matched, but it should have not due to Name")
}
e.args = []driver.Value{int64(5), "str"}
against = []namedValue{
{Value: int64(5), Ordinal: 0},
{Value: "str", Ordinal: 1},
}
if err := e.argsMatches(against); err == nil {
t.Error("arguments matched, but it should have not due to wrong Ordinal position")
}
against = []namedValue{
{Value: int64(5), Ordinal: 1},
{Value: "str", Ordinal: 2},
}
if err := e.argsMatches(against); err != nil {
t.Errorf("arguments should have matched, but it did not: %v", err)
}
}

@ -10,29 +10,38 @@ import (
func TestQueryExpectationArgComparison(t *testing.T) { func TestQueryExpectationArgComparison(t *testing.T) {
e := &queryBasedExpectation{} e := &queryBasedExpectation{}
against := []driver.Value{int64(5)} against := []namedValue{{Value: int64(5), Ordinal: 1}}
if err := e.argsMatches(against); err != nil { if err := e.argsMatches(against); err != nil {
t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err) t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err)
} }
e.args = []driver.Value{5, "str"} e.args = []driver.Value{5, "str"}
against = []driver.Value{int64(5)} against = []namedValue{{Value: int64(5), Ordinal: 1}}
if err := e.argsMatches(against); err == nil { if err := e.argsMatches(against); err == nil {
t.Error("arguments should not match, since the size is not the same") t.Error("arguments should not match, since the size is not the same")
} }
against = []driver.Value{int64(3), "str"} against = []namedValue{
{Value: int64(3), Ordinal: 1},
{Value: "str", Ordinal: 2},
}
if err := e.argsMatches(against); err == nil { if err := e.argsMatches(against); err == nil {
t.Error("arguments should not match, since the first argument (int value) is different") t.Error("arguments should not match, since the first argument (int value) is different")
} }
against = []driver.Value{int64(5), "st"} against = []namedValue{
{Value: int64(5), Ordinal: 1},
{Value: "st", Ordinal: 2},
}
if err := e.argsMatches(against); err == nil { if err := e.argsMatches(against); err == nil {
t.Error("arguments should not match, since the second argument (string value) is different") t.Error("arguments should not match, since the second argument (string value) is different")
} }
against = []driver.Value{int64(5), "str"} against = []namedValue{
{Value: int64(5), Ordinal: 1},
{Value: "str", Ordinal: 2},
}
if err := e.argsMatches(against); err != nil { if err := e.argsMatches(against); err != nil {
t.Errorf("arguments should match, but it did not: %s", err) t.Errorf("arguments should match, but it did not: %s", err)
} }
@ -41,7 +50,10 @@ func TestQueryExpectationArgComparison(t *testing.T) {
tm, _ := time.Parse(longForm, "Feb 3, 2013 at 7:54pm (PST)") tm, _ := time.Parse(longForm, "Feb 3, 2013 at 7:54pm (PST)")
e.args = []driver.Value{5, tm} e.args = []driver.Value{5, tm}
against = []driver.Value{int64(5), tm} against = []namedValue{
{Value: int64(5), Ordinal: 1},
{Value: tm, Ordinal: 2},
}
if err := e.argsMatches(against); err != nil { if err := e.argsMatches(against); err != nil {
t.Error("arguments should match, but it did not") t.Error("arguments should match, but it did not")
} }
@ -56,25 +68,33 @@ func TestQueryExpectationArgComparisonBool(t *testing.T) {
var e *queryBasedExpectation var e *queryBasedExpectation
e = &queryBasedExpectation{args: []driver.Value{true}} e = &queryBasedExpectation{args: []driver.Value{true}}
against := []driver.Value{true} against := []namedValue{
{Value: true, Ordinal: 1},
}
if err := e.argsMatches(against); err != nil { if err := e.argsMatches(against); err != nil {
t.Error("arguments should match, since arguments are the same") t.Error("arguments should match, since arguments are the same")
} }
e = &queryBasedExpectation{args: []driver.Value{false}} e = &queryBasedExpectation{args: []driver.Value{false}}
against = []driver.Value{false} against = []namedValue{
{Value: false, Ordinal: 1},
}
if err := e.argsMatches(against); err != nil { if err := e.argsMatches(against); err != nil {
t.Error("arguments should match, since argument are the same") t.Error("arguments should match, since argument are the same")
} }
e = &queryBasedExpectation{args: []driver.Value{true}} e = &queryBasedExpectation{args: []driver.Value{true}}
against = []driver.Value{false} against = []namedValue{
{Value: false, Ordinal: 1},
}
if err := e.argsMatches(against); err == nil { if err := e.argsMatches(against); err == nil {
t.Error("arguments should not match, since argument is different") t.Error("arguments should not match, since argument is different")
} }
e = &queryBasedExpectation{args: []driver.Value{false}} e = &queryBasedExpectation{args: []driver.Value{false}}
against = []driver.Value{true} against = []namedValue{
{Value: true, Ordinal: 1},
}
if err := e.argsMatches(against); err == nil { if err := e.argsMatches(against); err == nil {
t.Error("arguments should not match, since argument is different") t.Error("arguments should not match, since argument is different")
} }
@ -117,7 +137,7 @@ func TestBuildQuery(t *testing.T) {
name = 'John' name = 'John'
and and
address = 'Jakarta' address = 'Jakarta'
` `
mock.ExpectQuery(query) mock.ExpectQuery(query)

112
rows.go

@ -3,6 +3,7 @@ package sqlmock
import ( import (
"database/sql/driver" "database/sql/driver"
"encoding/csv" "encoding/csv"
"fmt"
"io" "io"
"strings" "strings"
) )
@ -18,57 +19,22 @@ var CSVColumnParser = func(s string) []byte {
return []byte(s) return []byte(s)
} }
// Rows interface allows to construct rows type rowSets struct {
// which also satisfies database/sql/driver.Rows interface sets []*Rows
type Rows interface { pos int
// composed interface, supports sql driver.Rows
driver.Rows
// AddRow composed from database driver.Value slice
// return the same instance to perform subsequent actions.
// Note that the number of values must match the number
// of columns
AddRow(columns ...driver.Value) Rows
// FromCSVString build rows from csv string.
// return the same instance to perform subsequent actions.
// Note that the number of values must match the number
// of columns
FromCSVString(s string) Rows
// RowError allows to set an error
// which will be returned when a given
// row number is read
RowError(row int, err error) Rows
// CloseError allows to set an error
// which will be returned by rows.Close
// function.
//
// The close error will be triggered only in cases
// when rows.Next() EOF was not yet reached, that is
// a default sql library behavior
CloseError(err error) Rows
} }
type rows struct { func (rs *rowSets) Columns() []string {
cols []string return rs.sets[rs.pos].cols
rows [][]driver.Value
pos int
nextErr map[int]error
closeErr error
} }
func (r *rows) Columns() []string { func (rs *rowSets) Close() error {
return r.cols return rs.sets[rs.pos].closeErr
}
func (r *rows) Close() error {
return r.closeErr
} }
// advances to next row // advances to next row
func (r *rows) Next(dest []driver.Value) error { func (rs *rowSets) Next(dest []driver.Value) error {
r := rs.sets[rs.pos]
r.pos++ r.pos++
if r.pos > len(r.rows) { if r.pos > len(r.rows) {
return io.EOF // per interface spec return io.EOF // per interface spec
@ -81,24 +47,66 @@ func (r *rows) Next(dest []driver.Value) error {
return r.nextErr[r.pos-1] return r.nextErr[r.pos-1]
} }
// transforms to debuggable printable string
func (rs *rowSets) String() string {
msg := "should return rows:\n"
if len(rs.sets) == 1 {
for n, row := range rs.sets[0].rows {
msg += fmt.Sprintf(" row %d - %+v\n", n, row)
}
return strings.TrimSpace(msg)
}
for i, set := range rs.sets {
msg += fmt.Sprintf(" result set: %d\n", i)
for n, row := range set.rows {
msg += fmt.Sprintf(" row %d - %+v\n", n, row)
}
}
return strings.TrimSpace(msg)
}
// Rows is a mocked collection of rows to
// return for Query result
type Rows struct {
cols []string
rows [][]driver.Value
pos int
nextErr map[int]error
closeErr error
}
// NewRows allows Rows to be created from a // NewRows allows Rows to be created from a
// sql driver.Value slice or from the CSV string and // sql driver.Value slice or from the CSV string and
// to be used as sql driver.Rows // to be used as sql driver.Rows
func NewRows(columns []string) Rows { func NewRows(columns []string) *Rows {
return &rows{cols: columns, nextErr: make(map[int]error)} return &Rows{cols: columns, nextErr: make(map[int]error)}
} }
func (r *rows) CloseError(err error) Rows { // CloseError allows to set an error
// which will be returned by rows.Close
// function.
//
// The close error will be triggered only in cases
// when rows.Next() EOF was not yet reached, that is
// a default sql library behavior
func (r *Rows) CloseError(err error) *Rows {
r.closeErr = err r.closeErr = err
return r return r
} }
func (r *rows) RowError(row int, err error) Rows { // RowError allows to set an error
// which will be returned when a given
// row number is read
func (r *Rows) RowError(row int, err error) *Rows {
r.nextErr[row] = err r.nextErr[row] = err
return r return r
} }
func (r *rows) AddRow(values ...driver.Value) Rows { // AddRow composed from database driver.Value slice
// return the same instance to perform subsequent actions.
// Note that the number of values must match the number
// of columns
func (r *Rows) AddRow(values ...driver.Value) *Rows {
if len(values) != len(r.cols) { if len(values) != len(r.cols) {
panic("Expected number of values to match number of columns") panic("Expected number of values to match number of columns")
} }
@ -112,7 +120,11 @@ func (r *rows) AddRow(values ...driver.Value) Rows {
return r return r
} }
func (r *rows) FromCSVString(s string) Rows { // FromCSVString build rows from csv string.
// return the same instance to perform subsequent actions.
// Note that the number of values must match the number
// of columns
func (r *Rows) FromCSVString(s string) *Rows {
res := strings.NewReader(strings.TrimSpace(s)) res := strings.NewReader(strings.TrimSpace(s))
csvReader := csv.NewReader(res) csvReader := csv.NewReader(res)

20
rows_go18.go Normal file

@ -0,0 +1,20 @@
// +build go1.8
package sqlmock
import "io"
// Implement the "RowsNextResultSet" interface
func (rs *rowSets) HasNextResultSet() bool {
return rs.pos+1 < len(rs.sets)
}
// Implement the "RowsNextResultSet" interface
func (rs *rowSets) NextResultSet() error {
if !rs.HasNextResultSet() {
return io.EOF
}
rs.pos++
return nil
}

92
rows_go18_test.go Normal file

@ -0,0 +1,92 @@
// +build go1.8
package sqlmock
import (
"fmt"
"testing"
)
func TestQueryMultiRows(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
rs1 := NewRows([]string{"id", "title"}).AddRow(5, "hello world")
rs2 := NewRows([]string{"name"}).AddRow("gopher").AddRow("john").AddRow("jane").RowError(2, fmt.Errorf("error"))
mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = \\?;SELECT name FROM users").
WithArgs(5).
WillReturnRows(rs1, rs2)
rows, err := db.Query("SELECT id, title FROM articles WHERE id = ?;SELECT name FROM users", 5)
if err != nil {
t.Errorf("error was not expected, but got: %v", err)
}
defer rows.Close()
if !rows.Next() {
t.Error("expected a row to be available in first result set")
}
var id int
var name string
err = rows.Scan(&id, &name)
if err != nil {
t.Errorf("error was not expected, but got: %v", err)
}
if id != 5 || name != "hello world" {
t.Errorf("unexpected row values id: %v name: %v", id, name)
}
if rows.Next() {
t.Error("was not expecting next row in first result set")
}
if !rows.NextResultSet() {
t.Error("had to have next result set")
}
if !rows.Next() {
t.Error("expected a row to be available in second result set")
}
err = rows.Scan(&name)
if err != nil {
t.Errorf("error was not expected, but got: %v", err)
}
if name != "gopher" {
t.Errorf("unexpected row name: %v", name)
}
if !rows.Next() {
t.Error("expected a row to be available in second result set")
}
err = rows.Scan(&name)
if err != nil {
t.Errorf("error was not expected, but got: %v", err)
}
if name != "john" {
t.Errorf("unexpected row name: %v", name)
}
if rows.Next() {
t.Error("expected next row to produce error")
}
if rows.Err() == nil {
t.Error("expected an error, but there was none")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}

@ -15,6 +15,7 @@ import (
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"regexp" "regexp"
"time"
) )
// Sqlmock interface serves to create expectations // Sqlmock interface serves to create expectations
@ -66,6 +67,11 @@ type Sqlmock interface {
// By default it is set to - true. But if you use goroutines // By default it is set to - true. But if you use goroutines
// to parallelize your query executation, that option may // to parallelize your query executation, that option may
// be handy. // be handy.
//
// This option may be turned on anytime during tests. As soon
// as it is switched to false, expectations will be matched
// in any order. Or otherwise if switched to true, any unmatched
// expectations will be expected in order
MatchExpectationsInOrder(bool) MatchExpectationsInOrder(bool)
} }
@ -154,6 +160,16 @@ func (c *sqlmock) ExpectationsWereMet() error {
// Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface // Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface
func (c *sqlmock) Begin() (driver.Tx, error) { func (c *sqlmock) Begin() (driver.Tx, error) {
ex, err := c.begin()
if err != nil {
return nil, err
}
time.Sleep(ex.delay)
return c, nil
}
func (c *sqlmock) begin() (*ExpectedBegin, error) {
var expected *ExpectedBegin var expected *ExpectedBegin
var ok bool var ok bool
var fulfilled int var fulfilled int
@ -184,7 +200,8 @@ func (c *sqlmock) Begin() (driver.Tx, error) {
expected.triggered = true expected.triggered = true
expected.Unlock() expected.Unlock()
return c, expected.err
return expected, expected.err
} }
func (c *sqlmock) ExpectBegin() *ExpectedBegin { func (c *sqlmock) ExpectBegin() *ExpectedBegin {
@ -194,7 +211,25 @@ func (c *sqlmock) ExpectBegin() *ExpectedBegin {
} }
// Exec meets http://golang.org/pkg/database/sql/driver/#Execer // Exec meets http://golang.org/pkg/database/sql/driver/#Execer
func (c *sqlmock) Exec(query string, args []driver.Value) (res driver.Result, err error) { func (c *sqlmock) Exec(query string, args []driver.Value) (driver.Result, error) {
namedArgs := make([]namedValue, len(args))
for i, v := range args {
namedArgs[i] = namedValue{
Ordinal: i + 1,
Value: v,
}
}
ex, err := c.exec(query, namedArgs)
if err != nil {
return nil, err
}
time.Sleep(ex.delay)
return ex.result, nil
}
func (c *sqlmock) exec(query string, args []namedValue) (*ExpectedExec, error) {
query = stripQuery(query) query = stripQuery(query)
var expected *ExpectedExec var expected *ExpectedExec
var fulfilled int var fulfilled int
@ -229,7 +264,6 @@ func (c *sqlmock) Exec(query string, args []driver.Value) (res driver.Result, er
} }
return nil, fmt.Errorf(msg, query, args) return nil, fmt.Errorf(msg, query, args)
} }
defer expected.Unlock() defer expected.Unlock()
if !expected.queryMatches(query) { if !expected.queryMatches(query) {
@ -241,7 +275,6 @@ func (c *sqlmock) Exec(query string, args []driver.Value) (res driver.Result, er
} }
expected.triggered = true expected.triggered = true
if expected.err != nil { if expected.err != nil {
return nil, expected.err // mocked to return error return nil, expected.err // mocked to return error
} }
@ -250,7 +283,7 @@ func (c *sqlmock) Exec(query string, args []driver.Value) (res driver.Result, er
return nil, fmt.Errorf("exec query '%s' with args %+v, must return a database/sql/driver.result, but it was not set for expectation %T as %+v", query, args, expected, expected) return nil, fmt.Errorf("exec query '%s' with args %+v, must return a database/sql/driver.result, but it was not set for expectation %T as %+v", query, args, expected, expected)
} }
return expected.result, err return expected, nil
} }
func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec { func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec {
@ -263,6 +296,16 @@ func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec {
// Prepare meets http://golang.org/pkg/database/sql/driver/#Conn interface // Prepare meets http://golang.org/pkg/database/sql/driver/#Conn interface
func (c *sqlmock) Prepare(query string) (driver.Stmt, error) { func (c *sqlmock) Prepare(query string) (driver.Stmt, error) {
ex, err := c.prepare(query)
if err != nil {
return nil, err
}
time.Sleep(ex.delay)
return &statement{c, query, ex.closeErr}, nil
}
func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) {
var expected *ExpectedPrepare var expected *ExpectedPrepare
var fulfilled int var fulfilled int
var ok bool var ok bool
@ -298,7 +341,7 @@ func (c *sqlmock) Prepare(query string) (driver.Stmt, error) {
} }
expected.triggered = true expected.triggered = true
return &statement{c, query, expected.closeErr}, expected.err return expected, expected.err
} }
func (c *sqlmock) ExpectPrepare(sqlRegexStr string) *ExpectedPrepare { func (c *sqlmock) ExpectPrepare(sqlRegexStr string) *ExpectedPrepare {
@ -308,8 +351,32 @@ func (c *sqlmock) ExpectPrepare(sqlRegexStr string) *ExpectedPrepare {
return e return e
} }
type namedValue struct {
Name string
Ordinal int
Value driver.Value
}
// Query meets http://golang.org/pkg/database/sql/driver/#Queryer // Query meets http://golang.org/pkg/database/sql/driver/#Queryer
func (c *sqlmock) Query(query string, args []driver.Value) (rw driver.Rows, err error) { func (c *sqlmock) Query(query string, args []driver.Value) (driver.Rows, error) {
namedArgs := make([]namedValue, len(args))
for i, v := range args {
namedArgs[i] = namedValue{
Ordinal: i + 1,
Value: v,
}
}
ex, err := c.query(query, namedArgs)
if err != nil {
return nil, err
}
time.Sleep(ex.delay)
return ex.rows, nil
}
func (c *sqlmock) query(query string, args []namedValue) (*ExpectedQuery, error) {
query = stripQuery(query) query = stripQuery(query)
var expected *ExpectedQuery var expected *ExpectedQuery
var fulfilled int var fulfilled int
@ -357,7 +424,6 @@ func (c *sqlmock) Query(query string, args []driver.Value) (rw driver.Rows, err
} }
expected.triggered = true expected.triggered = true
if expected.err != nil { if expected.err != nil {
return nil, expected.err // mocked to return error return nil, expected.err // mocked to return error
} }
@ -365,8 +431,7 @@ func (c *sqlmock) Query(query string, args []driver.Value) (rw driver.Rows, err
if expected.rows == nil { if expected.rows == nil {
return nil, fmt.Errorf("query '%s' with args %+v, must return a database/sql/driver.rows, but it was not set for expectation %T as %+v", query, args, expected, expected) return nil, fmt.Errorf("query '%s' with args %+v, must return a database/sql/driver.rows, but it was not set for expectation %T as %+v", query, args, expected, expected)
} }
return expected, nil
return expected.rows, err
} }
func (c *sqlmock) ExpectQuery(sqlRegexStr string) *ExpectedQuery { func (c *sqlmock) ExpectQuery(sqlRegexStr string) *ExpectedQuery {

101
sqlmock_go18.go Normal file

@ -0,0 +1,101 @@
// +build go1.8
package sqlmock
import (
"context"
"database/sql/driver"
"errors"
"time"
)
var ErrCancelled = errors.New("canceling query due to user request")
// Implement the "QueryerContext" interface
func (c *sqlmock) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
namedArgs := make([]namedValue, len(args))
for i, nv := range args {
namedArgs[i] = namedValue(nv)
}
ex, err := c.query(query, namedArgs)
if err != nil {
return nil, err
}
select {
case <-time.After(ex.delay):
return ex.rows, nil
case <-ctx.Done():
return nil, ErrCancelled
}
}
// Implement the "ExecerContext" interface
func (c *sqlmock) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
namedArgs := make([]namedValue, len(args))
for i, nv := range args {
namedArgs[i] = namedValue(nv)
}
ex, err := c.exec(query, namedArgs)
if err != nil {
return nil, err
}
select {
case <-time.After(ex.delay):
return ex.result, nil
case <-ctx.Done():
return nil, ErrCancelled
}
}
// Implement the "ConnBeginTx" interface
func (c *sqlmock) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
ex, err := c.begin()
if err != nil {
return nil, err
}
select {
case <-time.After(ex.delay):
return c, nil
case <-ctx.Done():
return nil, ErrCancelled
}
}
// Implement the "ConnPrepareContext" interface
func (c *sqlmock) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
ex, err := c.prepare(query)
if err != nil {
return nil, err
}
select {
case <-time.After(ex.delay):
return &statement{c, query, ex.closeErr}, nil
case <-ctx.Done():
return nil, ErrCancelled
}
}
// Implement the "Pinger" interface
// for now we do not have a Ping expectation
// may be something for the future
func (c *sqlmock) Ping(ctx context.Context) error {
return nil
}
// Implement the "StmtExecContext" interface
func (stmt *statement) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
return stmt.conn.ExecContext(ctx, stmt.query, args)
}
// Implement the "StmtQueryContext" interface
func (stmt *statement) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
return stmt.conn.QueryContext(ctx, stmt.query, args)
}
// @TODO maybe add ExpectedBegin.WithOptions(driver.TxOptions)

426
sqlmock_go18_test.go Normal file

@ -0,0 +1,426 @@
// +build go1.8
package sqlmock
import (
"context"
"database/sql"
"testing"
"time"
)
func TestContextExecCancel(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
mock.ExpectExec("DELETE FROM users").
WillDelayFor(time.Second).
WillReturnResult(NewResult(1, 1))
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(time.Millisecond * 10)
cancel()
}()
_, err = db.ExecContext(ctx, "DELETE FROM users")
if err == nil {
t.Error("error was expected, but there was none")
}
if err != ErrCancelled {
t.Errorf("was expecting cancel error, but got: %v", err)
}
_, err = db.ExecContext(ctx, "DELETE FROM users")
if err != context.Canceled {
t.Error("error was expected since context was already done, but there was none")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}
func TestPreparedStatementContextExecCancel(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
mock.ExpectPrepare("DELETE FROM users").
ExpectExec().
WillDelayFor(time.Second).
WillReturnResult(NewResult(1, 1))
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(time.Millisecond * 10)
cancel()
}()
stmt, err := db.Prepare("DELETE FROM users")
if err != nil {
t.Errorf("error was not expected, but got: %v", err)
}
_, err = stmt.ExecContext(ctx)
if err == nil {
t.Error("error was expected, but there was none")
}
if err != ErrCancelled {
t.Errorf("was expecting cancel error, but got: %v", err)
}
_, err = stmt.ExecContext(ctx)
if err != context.Canceled {
t.Error("error was expected since context was already done, but there was none")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}
func TestContextExecWithNamedArg(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
mock.ExpectExec("DELETE FROM users").
WithArgs(sql.Named("id", 5)).
WillDelayFor(time.Second).
WillReturnResult(NewResult(1, 1))
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(time.Millisecond * 10)
cancel()
}()
_, err = db.ExecContext(ctx, "DELETE FROM users WHERE id = :id", sql.Named("id", 5))
if err == nil {
t.Error("error was expected, but there was none")
}
if err != ErrCancelled {
t.Errorf("was expecting cancel error, but got: %v", err)
}
_, err = db.ExecContext(ctx, "DELETE FROM users WHERE id = :id", sql.Named("id", 5))
if err != context.Canceled {
t.Error("error was expected since context was already done, but there was none")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}
func TestContextExec(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
mock.ExpectExec("DELETE FROM users").
WillReturnResult(NewResult(1, 1))
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(time.Millisecond * 10)
cancel()
}()
res, err := db.ExecContext(ctx, "DELETE FROM users")
if err != nil {
t.Errorf("error was not expected, but got: %v", err)
}
affected, err := res.RowsAffected()
if affected != 1 {
t.Errorf("expected affected rows 1, but got %v", affected)
}
if err != nil {
t.Errorf("error was not expected, but got: %v", err)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}
func TestContextQueryCancel(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
rs := NewRows([]string{"id", "title"}).AddRow(5, "hello world")
mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?").
WithArgs(5).
WillDelayFor(time.Second).
WillReturnRows(rs)
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(time.Millisecond * 10)
cancel()
}()
_, err = db.QueryContext(ctx, "SELECT id, title FROM articles WHERE id = ?", 5)
if err == nil {
t.Error("error was expected, but there was none")
}
if err != ErrCancelled {
t.Errorf("was expecting cancel error, but got: %v", err)
}
_, err = db.QueryContext(ctx, "SELECT id, title FROM articles WHERE id = ?", 5)
if err != context.Canceled {
t.Error("error was expected since context was already done, but there was none")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}
func TestPreparedStatementContextQueryCancel(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
rs := NewRows([]string{"id", "title"}).AddRow(5, "hello world")
mock.ExpectPrepare("SELECT (.+) FROM articles WHERE id = ?").
ExpectQuery().
WithArgs(5).
WillDelayFor(time.Second).
WillReturnRows(rs)
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(time.Millisecond * 10)
cancel()
}()
stmt, err := db.Prepare("SELECT id, title FROM articles WHERE id = ?")
if err != nil {
t.Errorf("error was not expected, but got: %v", err)
}
_, err = stmt.QueryContext(ctx, 5)
if err == nil {
t.Error("error was expected, but there was none")
}
if err != ErrCancelled {
t.Errorf("was expecting cancel error, but got: %v", err)
}
_, err = stmt.QueryContext(ctx, 5)
if err != context.Canceled {
t.Error("error was expected since context was already done, but there was none")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}
func TestContextQuery(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
rs := NewRows([]string{"id", "title"}).AddRow(5, "hello world")
mock.ExpectQuery("SELECT (.+) FROM articles WHERE id =").
WithArgs(sql.Named("id", 5)).
WillDelayFor(time.Millisecond * 3).
WillReturnRows(rs)
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(time.Millisecond * 10)
cancel()
}()
rows, err := db.QueryContext(ctx, "SELECT id, title FROM articles WHERE id = :id", sql.Named("id", 5))
if err != nil {
t.Errorf("error was not expected, but got: %v", err)
}
if !rows.Next() {
t.Error("expected one row, but there was none")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}
func TestContextBeginCancel(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
mock.ExpectBegin().WillDelayFor(time.Second)
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(time.Millisecond * 10)
cancel()
}()
_, err = db.BeginTx(ctx, nil)
if err == nil {
t.Error("error was expected, but there was none")
}
if err != ErrCancelled {
t.Errorf("was expecting cancel error, but got: %v", err)
}
_, err = db.BeginTx(ctx, nil)
if err != context.Canceled {
t.Error("error was expected since context was already done, but there was none")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}
func TestContextBegin(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
mock.ExpectBegin().WillDelayFor(time.Millisecond * 3)
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(time.Millisecond * 10)
cancel()
}()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
t.Errorf("error was not expected, but got: %v", err)
}
if tx == nil {
t.Error("expected tx, but there was nil")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}
func TestContextPrepareCancel(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
mock.ExpectPrepare("SELECT").WillDelayFor(time.Second)
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(time.Millisecond * 10)
cancel()
}()
_, err = db.PrepareContext(ctx, "SELECT")
if err == nil {
t.Error("error was expected, but there was none")
}
if err != ErrCancelled {
t.Errorf("was expecting cancel error, but got: %v", err)
}
_, err = db.PrepareContext(ctx, "SELECT")
if err != context.Canceled {
t.Error("error was expected since context was already done, but there was none")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}
func TestContextPrepare(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
mock.ExpectPrepare("SELECT").WillDelayFor(time.Millisecond * 3)
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(time.Millisecond * 10)
cancel()
}()
stmt, err := db.PrepareContext(ctx, "SELECT")
if err != nil {
t.Errorf("error was not expected, but got: %v", err)
}
if stmt == nil {
t.Error("expected stmt, but there was nil")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}