mirror of
https://github.com/DATA-DOG/go-sqlmock.git
synced 2025-01-07 23:01:44 +02:00
allow unordered expectation matching, support for goroutines
* 1778939 take care of locks for goroutine based matching
This commit is contained in:
parent
5a740a6373
commit
566ca54083
@ -44,7 +44,7 @@ func New() (db *sql.DB, mock *Sqlmock, err error) {
|
||||
dsn := fmt.Sprintf("sqlmock_db_%d", pool.counter)
|
||||
pool.counter++
|
||||
|
||||
mock = &Sqlmock{dsn: dsn, drv: pool}
|
||||
mock = &Sqlmock{dsn: dsn, drv: pool, MatchExpectationsInOrder: true}
|
||||
pool.conns[dsn] = mock
|
||||
pool.Unlock()
|
||||
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"database/sql/driver"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Argument interface allows to match
|
||||
@ -16,11 +17,14 @@ type Argument interface {
|
||||
// an expectation interface
|
||||
type expectation interface {
|
||||
fulfilled() bool
|
||||
Lock()
|
||||
Unlock()
|
||||
}
|
||||
|
||||
// common expectation struct
|
||||
// satisfies the expectation interface
|
||||
type commonExpectation struct {
|
||||
sync.Mutex
|
||||
triggered bool
|
||||
err error
|
||||
}
|
||||
@ -184,6 +188,19 @@ type queryBasedExpectation struct {
|
||||
args []driver.Value
|
||||
}
|
||||
|
||||
func (e *queryBasedExpectation) attemptMatch(sql string, args []driver.Value) (ret bool) {
|
||||
if !e.queryMatches(sql) {
|
||||
return
|
||||
}
|
||||
|
||||
defer recover() // ignore panic since we attempt a match
|
||||
|
||||
if e.argsMatches(args) {
|
||||
return true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (e *queryBasedExpectation) queryMatches(sql string) bool {
|
||||
return e.sqlRegex.MatchString(sql)
|
||||
}
|
||||
|
318
sqlmock.go
318
sqlmock.go
@ -22,6 +22,15 @@ import (
|
||||
// create expectations for any kind of database action
|
||||
// in order to mock and test real database behavior.
|
||||
type Sqlmock struct {
|
||||
|
||||
// MatchExpectationsInOrder gives an option whether to match all
|
||||
// expectations in the order they were set or not.
|
||||
//
|
||||
// By default it is set to - true. But if you use goroutines
|
||||
// to parallelize your query executation, that option may
|
||||
// be handy.
|
||||
MatchExpectationsInOrder bool
|
||||
|
||||
dsn string
|
||||
opened int
|
||||
drv *mockDriver
|
||||
@ -29,15 +38,6 @@ type Sqlmock struct {
|
||||
expected []expectation
|
||||
}
|
||||
|
||||
func (c *Sqlmock) next() (e expectation) {
|
||||
for _, e = range c.expected {
|
||||
if !e.fulfilled() {
|
||||
return
|
||||
}
|
||||
}
|
||||
return nil // all expectations were fulfilled
|
||||
}
|
||||
|
||||
// ExpectClose queues an expectation for this database
|
||||
// action to be triggered. the *ExpectedClose allows
|
||||
// to mock database response
|
||||
@ -59,17 +59,32 @@ func (c *Sqlmock) Close() error {
|
||||
if c.opened == 0 {
|
||||
delete(c.drv.conns, c.dsn)
|
||||
}
|
||||
e := c.next()
|
||||
if e == nil {
|
||||
|
||||
var expected *ExpectedClose
|
||||
var ok bool
|
||||
for _, next := range c.expected {
|
||||
next.Lock()
|
||||
if next.fulfilled() {
|
||||
next.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
if expected, ok = next.(*ExpectedClose); ok {
|
||||
break
|
||||
}
|
||||
|
||||
next.Unlock()
|
||||
if c.MatchExpectationsInOrder {
|
||||
return fmt.Errorf("call to database Close, was not expected, next expectation is %T as %+v", next, next)
|
||||
}
|
||||
}
|
||||
if expected == nil {
|
||||
return fmt.Errorf("all expectations were already fulfilled, call to database Close was not expected")
|
||||
}
|
||||
|
||||
t, ok := e.(*ExpectedClose)
|
||||
if !ok {
|
||||
return fmt.Errorf("call to database Close, was not expected, next expectation is %T as %+v", e, e)
|
||||
}
|
||||
t.triggered = true
|
||||
return t.err
|
||||
expected.triggered = true
|
||||
expected.Unlock()
|
||||
return expected.err
|
||||
}
|
||||
|
||||
// ExpectationsWereMet checks whether all queued expectations
|
||||
@ -85,17 +100,31 @@ func (c *Sqlmock) ExpectationsWereMet() error {
|
||||
|
||||
// Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface
|
||||
func (c *Sqlmock) Begin() (driver.Tx, error) {
|
||||
e := c.next()
|
||||
if e == nil {
|
||||
var expected *ExpectedBegin
|
||||
var ok bool
|
||||
for _, next := range c.expected {
|
||||
next.Lock()
|
||||
if next.fulfilled() {
|
||||
next.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
if expected, ok = next.(*ExpectedBegin); ok {
|
||||
break
|
||||
}
|
||||
|
||||
next.Unlock()
|
||||
if c.MatchExpectationsInOrder {
|
||||
return nil, fmt.Errorf("call to begin transaction, was not expected, next expectation is %T as %+v", next, next)
|
||||
}
|
||||
}
|
||||
if expected == nil {
|
||||
return nil, fmt.Errorf("all expectations were already fulfilled, call to begin transaction was not expected")
|
||||
}
|
||||
|
||||
t, ok := e.(*ExpectedBegin)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("call to begin transaction, was not expected, next expectation is %T as %+v", e, e)
|
||||
}
|
||||
t.triggered = true
|
||||
return c, t.err
|
||||
expected.triggered = true
|
||||
expected.Unlock()
|
||||
return c, expected.err
|
||||
}
|
||||
|
||||
// ExpectBegin expects *sql.DB.Begin to be called.
|
||||
@ -108,37 +137,65 @@ func (c *Sqlmock) ExpectBegin() *ExpectedBegin {
|
||||
|
||||
// Exec meets http://golang.org/pkg/database/sql/driver/#Execer
|
||||
func (c *Sqlmock) Exec(query string, args []driver.Value) (res driver.Result, err error) {
|
||||
e := c.next()
|
||||
query = stripQuery(query)
|
||||
if e == nil {
|
||||
var expected *ExpectedExec
|
||||
var ok bool
|
||||
for _, next := range c.expected {
|
||||
next.Lock()
|
||||
if next.fulfilled() {
|
||||
next.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
if c.MatchExpectationsInOrder {
|
||||
if expected, ok = next.(*ExpectedExec); ok {
|
||||
break
|
||||
}
|
||||
next.Unlock()
|
||||
return nil, fmt.Errorf("call to exec query '%s' with args %+v, was not expected, next expectation is %T as %+v", query, args, next, next)
|
||||
}
|
||||
if exec, ok := next.(*ExpectedExec); ok {
|
||||
if exec.attemptMatch(query, args) {
|
||||
expected = exec
|
||||
break
|
||||
}
|
||||
}
|
||||
next.Unlock()
|
||||
}
|
||||
if expected == nil {
|
||||
return nil, fmt.Errorf("all expectations were already fulfilled, call to exec '%s' query with args %+v was not expected", query, args)
|
||||
}
|
||||
|
||||
t, ok := e.(*ExpectedExec)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("call to exec query '%s' with args %+v, was not expected, next expectation is %T as %+v", query, args, e, e)
|
||||
defer expected.Unlock()
|
||||
expected.triggered = true
|
||||
// converts panic to error in case of reflect value type mismatch
|
||||
defer func(errp *error, exp *ExpectedExec, q string, a []driver.Value) {
|
||||
if e := recover(); e != nil {
|
||||
if se, ok := e.(*reflect.ValueError); ok { // catch reflect error, failed type conversion
|
||||
msg := "exec query \"%s\", args \"%+v\" failed to match expected arguments \"%+v\", reason %s"
|
||||
*errp = fmt.Errorf(msg, q, a, exp.args, se)
|
||||
} else {
|
||||
panic(e) // overwise if unknown error panic
|
||||
}
|
||||
}
|
||||
}(&err, expected, query, args)
|
||||
|
||||
if !expected.queryMatches(query) {
|
||||
return nil, fmt.Errorf("exec query '%s', does not match regex '%s'", query, expected.sqlRegex.String())
|
||||
}
|
||||
|
||||
t.triggered = true
|
||||
if t.err != nil {
|
||||
return nil, t.err // mocked to return error
|
||||
if !expected.argsMatches(args) {
|
||||
return nil, fmt.Errorf("exec query '%s', args %+v does not match expected %+v", query, args, expected.args)
|
||||
}
|
||||
|
||||
if t.result == nil {
|
||||
return nil, fmt.Errorf("exec query '%s' with args %+v, must return a database/sql/driver.result, but it was not set for expectation %T as %+v", query, args, t, t)
|
||||
if expected.err != nil {
|
||||
return nil, expected.err // mocked to return error
|
||||
}
|
||||
|
||||
defer argMatcherErrorHandler(&err) // converts panic to error in case of reflect value type mismatch
|
||||
|
||||
if !t.queryMatches(query) {
|
||||
return nil, fmt.Errorf("exec query '%s', does not match regex '%s'", query, t.sqlRegex.String())
|
||||
if expected.result == nil {
|
||||
return nil, fmt.Errorf("exec query '%s' with args %+v, must return a database/sql/driver.result, but it was not set for expectation %T as %+v", query, args, expected, expected)
|
||||
}
|
||||
|
||||
if !t.argsMatches(args) {
|
||||
return nil, fmt.Errorf("exec query '%s', args %+v does not match expected %+v", query, args, t.args)
|
||||
}
|
||||
|
||||
return t.result, err
|
||||
return expected.result, err
|
||||
}
|
||||
|
||||
// ExpectExec expects Exec() to be called with sql query
|
||||
@ -153,23 +210,33 @@ func (c *Sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec {
|
||||
|
||||
// Prepare meets http://golang.org/pkg/database/sql/driver/#Conn interface
|
||||
func (c *Sqlmock) Prepare(query string) (driver.Stmt, error) {
|
||||
e := c.next()
|
||||
var expected *ExpectedPrepare
|
||||
var ok bool
|
||||
for _, next := range c.expected {
|
||||
next.Lock()
|
||||
if next.fulfilled() {
|
||||
next.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
if expected, ok = next.(*ExpectedPrepare); ok {
|
||||
break
|
||||
}
|
||||
|
||||
next.Unlock()
|
||||
if c.MatchExpectationsInOrder {
|
||||
return nil, fmt.Errorf("call to Prepare stetement with query '%s', was not expected, next expectation is %T as %+v", query, next, next)
|
||||
}
|
||||
}
|
||||
|
||||
query = stripQuery(query)
|
||||
if e == nil {
|
||||
if expected == nil {
|
||||
return nil, fmt.Errorf("all expectations were already fulfilled, call to Prepare '%s' query was not expected", query)
|
||||
}
|
||||
t, ok := e.(*ExpectedPrepare)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("call to Prepare stetement with query '%s', was not expected, next expectation is %T as %+v", query, e, e)
|
||||
}
|
||||
|
||||
t.triggered = true
|
||||
if t.err != nil {
|
||||
return nil, t.err // mocked to return error
|
||||
}
|
||||
|
||||
return &statement{c, query, t.closeErr}, nil
|
||||
expected.triggered = true
|
||||
expected.Unlock()
|
||||
return &statement{c, query, expected.closeErr}, expected.err
|
||||
}
|
||||
|
||||
// ExpectPrepare expects Prepare() to be called with sql query
|
||||
@ -185,37 +252,66 @@ func (c *Sqlmock) ExpectPrepare(sqlRegexStr string) *ExpectedPrepare {
|
||||
|
||||
// Query meets http://golang.org/pkg/database/sql/driver/#Queryer
|
||||
func (c *Sqlmock) Query(query string, args []driver.Value) (rw driver.Rows, err error) {
|
||||
e := c.next()
|
||||
query = stripQuery(query)
|
||||
if e == nil {
|
||||
var expected *ExpectedQuery
|
||||
var ok bool
|
||||
for _, next := range c.expected {
|
||||
next.Lock()
|
||||
if next.fulfilled() {
|
||||
next.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
if c.MatchExpectationsInOrder {
|
||||
if expected, ok = next.(*ExpectedQuery); ok {
|
||||
break
|
||||
}
|
||||
next.Unlock()
|
||||
return nil, fmt.Errorf("call to query '%s' with args %+v, was not expected, next expectation is %T as %+v", query, args, next, next)
|
||||
}
|
||||
if qr, ok := next.(*ExpectedQuery); ok {
|
||||
if qr.attemptMatch(query, args) {
|
||||
expected = qr
|
||||
break
|
||||
}
|
||||
}
|
||||
next.Unlock()
|
||||
}
|
||||
if expected == nil {
|
||||
return nil, fmt.Errorf("all expectations were already fulfilled, call to query '%s' with args %+v was not expected", query, args)
|
||||
}
|
||||
|
||||
t, ok := e.(*ExpectedQuery)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("call to query '%s' with args %+v, was not expected, next expectation is %T as %+v", query, args, e, e)
|
||||
defer expected.Unlock()
|
||||
expected.triggered = true
|
||||
// converts panic to error in case of reflect value type mismatch
|
||||
defer func(errp *error, exp *ExpectedQuery, q string, a []driver.Value) {
|
||||
if e := recover(); e != nil {
|
||||
if se, ok := e.(*reflect.ValueError); ok { // catch reflect error, failed type conversion
|
||||
msg := "query \"%s\", args \"%+v\" failed to match expected arguments \"%+v\", reason %s"
|
||||
*errp = fmt.Errorf(msg, q, a, exp.args, se)
|
||||
} else {
|
||||
panic(e) // overwise if unknown error panic
|
||||
}
|
||||
}
|
||||
}(&err, expected, query, args)
|
||||
|
||||
if !expected.queryMatches(query) {
|
||||
return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, expected.sqlRegex.String())
|
||||
}
|
||||
|
||||
t.triggered = true
|
||||
if t.err != nil {
|
||||
return nil, t.err // mocked to return error
|
||||
if !expected.argsMatches(args) {
|
||||
return nil, fmt.Errorf("query '%s', args %+v does not match expected %+v", query, args, expected.args)
|
||||
}
|
||||
|
||||
if t.rows == nil {
|
||||
return nil, fmt.Errorf("query '%s' with args %+v, must return a database/sql/driver.rows, but it was not set for expectation %T as %+v", query, args, t, t)
|
||||
if expected.err != nil {
|
||||
return nil, expected.err // mocked to return error
|
||||
}
|
||||
|
||||
defer argMatcherErrorHandler(&err) // converts panic to error in case of reflect value type mismatch
|
||||
|
||||
if !t.queryMatches(query) {
|
||||
return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, t.sqlRegex.String())
|
||||
if expected.rows == nil {
|
||||
return nil, fmt.Errorf("query '%s' with args %+v, must return a database/sql/driver.rows, but it was not set for expectation %T as %+v", query, args, expected, expected)
|
||||
}
|
||||
|
||||
if !t.argsMatches(args) {
|
||||
return nil, fmt.Errorf("query '%s', args %+v does not match expected %+v", query, args, t.args)
|
||||
}
|
||||
|
||||
return t.rows, err
|
||||
return expected.rows, err
|
||||
}
|
||||
|
||||
// ExpectQuery expects Query() or QueryRow() to be called with sql query
|
||||
@ -246,40 +342,58 @@ func (c *Sqlmock) ExpectRollback() *ExpectedRollback {
|
||||
|
||||
// Commit meets http://golang.org/pkg/database/sql/driver/#Tx
|
||||
func (c *Sqlmock) Commit() error {
|
||||
e := c.next()
|
||||
if e == nil {
|
||||
var expected *ExpectedCommit
|
||||
var ok bool
|
||||
for _, next := range c.expected {
|
||||
next.Lock()
|
||||
if next.fulfilled() {
|
||||
next.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
if expected, ok = next.(*ExpectedCommit); ok {
|
||||
break
|
||||
}
|
||||
|
||||
next.Unlock()
|
||||
if c.MatchExpectationsInOrder {
|
||||
return fmt.Errorf("call to commit transaction, was not expected, next expectation is %T as %+v", next, next)
|
||||
}
|
||||
}
|
||||
if expected == nil {
|
||||
return fmt.Errorf("all expectations were already fulfilled, call to commit transaction was not expected")
|
||||
}
|
||||
|
||||
t, ok := e.(*ExpectedCommit)
|
||||
if !ok {
|
||||
return fmt.Errorf("call to commit transaction, was not expected, next expectation was %v", e)
|
||||
}
|
||||
t.triggered = true
|
||||
return t.err
|
||||
expected.triggered = true
|
||||
expected.Unlock()
|
||||
return expected.err
|
||||
}
|
||||
|
||||
// Rollback meets http://golang.org/pkg/database/sql/driver/#Tx
|
||||
func (c *Sqlmock) Rollback() error {
|
||||
e := c.next()
|
||||
if e == nil {
|
||||
var expected *ExpectedRollback
|
||||
var ok bool
|
||||
for _, next := range c.expected {
|
||||
next.Lock()
|
||||
if next.fulfilled() {
|
||||
next.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
if expected, ok = next.(*ExpectedRollback); ok {
|
||||
break
|
||||
}
|
||||
|
||||
next.Unlock()
|
||||
if c.MatchExpectationsInOrder {
|
||||
return fmt.Errorf("call to rollback transaction, was not expected, next expectation is %T as %+v", next, next)
|
||||
}
|
||||
}
|
||||
if expected == nil {
|
||||
return fmt.Errorf("all expectations were already fulfilled, call to rollback transaction was not expected")
|
||||
}
|
||||
|
||||
t, ok := e.(*ExpectedRollback)
|
||||
if !ok {
|
||||
return fmt.Errorf("call to rollback transaction, was not expected, next expectation was %v", e)
|
||||
}
|
||||
t.triggered = true
|
||||
return t.err
|
||||
}
|
||||
|
||||
func argMatcherErrorHandler(errp *error) {
|
||||
if e := recover(); e != nil {
|
||||
if se, ok := e.(*reflect.ValueError); ok { // catch reflect error, failed type conversion
|
||||
*errp = fmt.Errorf("Failed to compare query arguments: %s", se)
|
||||
} else {
|
||||
panic(e) // overwise panic
|
||||
}
|
||||
}
|
||||
expected.triggered = true
|
||||
expected.Unlock()
|
||||
return expected.err
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package sqlmock
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@ -575,3 +576,44 @@ func TestArgumentReflectValueTypeError(t *testing.T) {
|
||||
t.Error("Expected error, but got none")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoroutineExecutionWithUnorderedExpectationMatching(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// note this line is important for unordered expectation matching
|
||||
mock.MatchExpectationsInOrder = false
|
||||
|
||||
result := NewResult(1, 1)
|
||||
|
||||
mock.ExpectExec("^UPDATE one").WithArgs("one").WillReturnResult(result)
|
||||
mock.ExpectExec("^UPDATE two").WithArgs("one", "two").WillReturnResult(result)
|
||||
mock.ExpectExec("^UPDATE three").WithArgs("one", "two", "three").WillReturnResult(result)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
queries := map[string][]interface{}{
|
||||
"one": []interface{}{"one"},
|
||||
"two": []interface{}{"one", "two"},
|
||||
"three": []interface{}{"one", "two", "three"},
|
||||
}
|
||||
|
||||
wg.Add(len(queries))
|
||||
for table, args := range queries {
|
||||
go func(tbl string, a []interface{}) {
|
||||
if _, err := db.Exec("UPDATE "+tbl, a...); err != nil {
|
||||
t.Errorf("error was not expected: %s", err)
|
||||
}
|
||||
wg.Done()
|
||||
}(table, args)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("there were unfulfilled expections: %s", err)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user