1
0
mirror of https://github.com/DATA-DOG/go-sqlmock.git synced 2025-01-24 03:16:17 +02:00

allow unordered expectation matching, support for goroutines

* 1778939 take care of locks for goroutine based matching
This commit is contained in:
gedi 2015-08-26 14:28:01 +03:00
parent 5a740a6373
commit 566ca54083
4 changed files with 276 additions and 103 deletions

View File

@ -44,7 +44,7 @@ func New() (db *sql.DB, mock *Sqlmock, err error) {
dsn := fmt.Sprintf("sqlmock_db_%d", pool.counter) dsn := fmt.Sprintf("sqlmock_db_%d", pool.counter)
pool.counter++ pool.counter++
mock = &Sqlmock{dsn: dsn, drv: pool} mock = &Sqlmock{dsn: dsn, drv: pool, MatchExpectationsInOrder: true}
pool.conns[dsn] = mock pool.conns[dsn] = mock
pool.Unlock() pool.Unlock()

View File

@ -4,6 +4,7 @@ import (
"database/sql/driver" "database/sql/driver"
"reflect" "reflect"
"regexp" "regexp"
"sync"
) )
// Argument interface allows to match // Argument interface allows to match
@ -16,11 +17,14 @@ type Argument interface {
// an expectation interface // an expectation interface
type expectation interface { type expectation interface {
fulfilled() bool fulfilled() bool
Lock()
Unlock()
} }
// common expectation struct // common expectation struct
// satisfies the expectation interface // satisfies the expectation interface
type commonExpectation struct { type commonExpectation struct {
sync.Mutex
triggered bool triggered bool
err error err error
} }
@ -184,6 +188,19 @@ type queryBasedExpectation struct {
args []driver.Value args []driver.Value
} }
func (e *queryBasedExpectation) attemptMatch(sql string, args []driver.Value) (ret bool) {
if !e.queryMatches(sql) {
return
}
defer recover() // ignore panic since we attempt a match
if e.argsMatches(args) {
return true
}
return
}
func (e *queryBasedExpectation) queryMatches(sql string) bool { func (e *queryBasedExpectation) queryMatches(sql string) bool {
return e.sqlRegex.MatchString(sql) return e.sqlRegex.MatchString(sql)
} }

View File

@ -22,6 +22,15 @@ import (
// create expectations for any kind of database action // create expectations for any kind of database action
// in order to mock and test real database behavior. // in order to mock and test real database behavior.
type Sqlmock struct { type Sqlmock struct {
// MatchExpectationsInOrder gives an option whether to match all
// expectations in the order they were set or not.
//
// By default it is set to - true. But if you use goroutines
// to parallelize your query executation, that option may
// be handy.
MatchExpectationsInOrder bool
dsn string dsn string
opened int opened int
drv *mockDriver drv *mockDriver
@ -29,15 +38,6 @@ type Sqlmock struct {
expected []expectation expected []expectation
} }
func (c *Sqlmock) next() (e expectation) {
for _, e = range c.expected {
if !e.fulfilled() {
return
}
}
return nil // all expectations were fulfilled
}
// ExpectClose queues an expectation for this database // ExpectClose queues an expectation for this database
// action to be triggered. the *ExpectedClose allows // action to be triggered. the *ExpectedClose allows
// to mock database response // to mock database response
@ -59,17 +59,32 @@ func (c *Sqlmock) Close() error {
if c.opened == 0 { if c.opened == 0 {
delete(c.drv.conns, c.dsn) delete(c.drv.conns, c.dsn)
} }
e := c.next()
if e == nil { var expected *ExpectedClose
var ok bool
for _, next := range c.expected {
next.Lock()
if next.fulfilled() {
next.Unlock()
continue
}
if expected, ok = next.(*ExpectedClose); ok {
break
}
next.Unlock()
if c.MatchExpectationsInOrder {
return fmt.Errorf("call to database Close, was not expected, next expectation is %T as %+v", next, next)
}
}
if expected == nil {
return fmt.Errorf("all expectations were already fulfilled, call to database Close was not expected") return fmt.Errorf("all expectations were already fulfilled, call to database Close was not expected")
} }
t, ok := e.(*ExpectedClose) expected.triggered = true
if !ok { expected.Unlock()
return fmt.Errorf("call to database Close, was not expected, next expectation is %T as %+v", e, e) return expected.err
}
t.triggered = true
return t.err
} }
// ExpectationsWereMet checks whether all queued expectations // ExpectationsWereMet checks whether all queued expectations
@ -85,17 +100,31 @@ func (c *Sqlmock) ExpectationsWereMet() error {
// Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface // Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface
func (c *Sqlmock) Begin() (driver.Tx, error) { func (c *Sqlmock) Begin() (driver.Tx, error) {
e := c.next() var expected *ExpectedBegin
if e == nil { var ok bool
for _, next := range c.expected {
next.Lock()
if next.fulfilled() {
next.Unlock()
continue
}
if expected, ok = next.(*ExpectedBegin); ok {
break
}
next.Unlock()
if c.MatchExpectationsInOrder {
return nil, fmt.Errorf("call to begin transaction, was not expected, next expectation is %T as %+v", next, next)
}
}
if expected == nil {
return nil, fmt.Errorf("all expectations were already fulfilled, call to begin transaction was not expected") return nil, fmt.Errorf("all expectations were already fulfilled, call to begin transaction was not expected")
} }
t, ok := e.(*ExpectedBegin) expected.triggered = true
if !ok { expected.Unlock()
return nil, fmt.Errorf("call to begin transaction, was not expected, next expectation is %T as %+v", e, e) return c, expected.err
}
t.triggered = true
return c, t.err
} }
// ExpectBegin expects *sql.DB.Begin to be called. // ExpectBegin expects *sql.DB.Begin to be called.
@ -108,37 +137,65 @@ func (c *Sqlmock) ExpectBegin() *ExpectedBegin {
// Exec meets http://golang.org/pkg/database/sql/driver/#Execer // Exec meets http://golang.org/pkg/database/sql/driver/#Execer
func (c *Sqlmock) Exec(query string, args []driver.Value) (res driver.Result, err error) { func (c *Sqlmock) Exec(query string, args []driver.Value) (res driver.Result, err error) {
e := c.next()
query = stripQuery(query) query = stripQuery(query)
if e == nil { var expected *ExpectedExec
var ok bool
for _, next := range c.expected {
next.Lock()
if next.fulfilled() {
next.Unlock()
continue
}
if c.MatchExpectationsInOrder {
if expected, ok = next.(*ExpectedExec); ok {
break
}
next.Unlock()
return nil, fmt.Errorf("call to exec query '%s' with args %+v, was not expected, next expectation is %T as %+v", query, args, next, next)
}
if exec, ok := next.(*ExpectedExec); ok {
if exec.attemptMatch(query, args) {
expected = exec
break
}
}
next.Unlock()
}
if expected == nil {
return nil, fmt.Errorf("all expectations were already fulfilled, call to exec '%s' query with args %+v was not expected", query, args) return nil, fmt.Errorf("all expectations were already fulfilled, call to exec '%s' query with args %+v was not expected", query, args)
} }
t, ok := e.(*ExpectedExec) defer expected.Unlock()
if !ok { expected.triggered = true
return nil, fmt.Errorf("call to exec query '%s' with args %+v, was not expected, next expectation is %T as %+v", query, args, e, e) // converts panic to error in case of reflect value type mismatch
defer func(errp *error, exp *ExpectedExec, q string, a []driver.Value) {
if e := recover(); e != nil {
if se, ok := e.(*reflect.ValueError); ok { // catch reflect error, failed type conversion
msg := "exec query \"%s\", args \"%+v\" failed to match expected arguments \"%+v\", reason %s"
*errp = fmt.Errorf(msg, q, a, exp.args, se)
} else {
panic(e) // overwise if unknown error panic
}
}
}(&err, expected, query, args)
if !expected.queryMatches(query) {
return nil, fmt.Errorf("exec query '%s', does not match regex '%s'", query, expected.sqlRegex.String())
} }
t.triggered = true if !expected.argsMatches(args) {
if t.err != nil { return nil, fmt.Errorf("exec query '%s', args %+v does not match expected %+v", query, args, expected.args)
return nil, t.err // mocked to return error
} }
if t.result == nil { if expected.err != nil {
return nil, fmt.Errorf("exec query '%s' with args %+v, must return a database/sql/driver.result, but it was not set for expectation %T as %+v", query, args, t, t) return nil, expected.err // mocked to return error
} }
defer argMatcherErrorHandler(&err) // converts panic to error in case of reflect value type mismatch if expected.result == nil {
return nil, fmt.Errorf("exec query '%s' with args %+v, must return a database/sql/driver.result, but it was not set for expectation %T as %+v", query, args, expected, expected)
if !t.queryMatches(query) {
return nil, fmt.Errorf("exec query '%s', does not match regex '%s'", query, t.sqlRegex.String())
} }
return expected.result, err
if !t.argsMatches(args) {
return nil, fmt.Errorf("exec query '%s', args %+v does not match expected %+v", query, args, t.args)
}
return t.result, err
} }
// ExpectExec expects Exec() to be called with sql query // ExpectExec expects Exec() to be called with sql query
@ -153,23 +210,33 @@ func (c *Sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec {
// Prepare meets http://golang.org/pkg/database/sql/driver/#Conn interface // Prepare meets http://golang.org/pkg/database/sql/driver/#Conn interface
func (c *Sqlmock) Prepare(query string) (driver.Stmt, error) { func (c *Sqlmock) Prepare(query string) (driver.Stmt, error) {
e := c.next() var expected *ExpectedPrepare
var ok bool
for _, next := range c.expected {
next.Lock()
if next.fulfilled() {
next.Unlock()
continue
}
if expected, ok = next.(*ExpectedPrepare); ok {
break
}
next.Unlock()
if c.MatchExpectationsInOrder {
return nil, fmt.Errorf("call to Prepare stetement with query '%s', was not expected, next expectation is %T as %+v", query, next, next)
}
}
query = stripQuery(query) query = stripQuery(query)
if e == nil { if expected == nil {
return nil, fmt.Errorf("all expectations were already fulfilled, call to Prepare '%s' query was not expected", query) return nil, fmt.Errorf("all expectations were already fulfilled, call to Prepare '%s' query was not expected", query)
} }
t, ok := e.(*ExpectedPrepare)
if !ok {
return nil, fmt.Errorf("call to Prepare stetement with query '%s', was not expected, next expectation is %T as %+v", query, e, e)
}
t.triggered = true expected.triggered = true
if t.err != nil { expected.Unlock()
return nil, t.err // mocked to return error return &statement{c, query, expected.closeErr}, expected.err
}
return &statement{c, query, t.closeErr}, nil
} }
// ExpectPrepare expects Prepare() to be called with sql query // ExpectPrepare expects Prepare() to be called with sql query
@ -185,37 +252,66 @@ func (c *Sqlmock) ExpectPrepare(sqlRegexStr string) *ExpectedPrepare {
// Query meets http://golang.org/pkg/database/sql/driver/#Queryer // Query meets http://golang.org/pkg/database/sql/driver/#Queryer
func (c *Sqlmock) Query(query string, args []driver.Value) (rw driver.Rows, err error) { func (c *Sqlmock) Query(query string, args []driver.Value) (rw driver.Rows, err error) {
e := c.next()
query = stripQuery(query) query = stripQuery(query)
if e == nil { var expected *ExpectedQuery
var ok bool
for _, next := range c.expected {
next.Lock()
if next.fulfilled() {
next.Unlock()
continue
}
if c.MatchExpectationsInOrder {
if expected, ok = next.(*ExpectedQuery); ok {
break
}
next.Unlock()
return nil, fmt.Errorf("call to query '%s' with args %+v, was not expected, next expectation is %T as %+v", query, args, next, next)
}
if qr, ok := next.(*ExpectedQuery); ok {
if qr.attemptMatch(query, args) {
expected = qr
break
}
}
next.Unlock()
}
if expected == nil {
return nil, fmt.Errorf("all expectations were already fulfilled, call to query '%s' with args %+v was not expected", query, args) return nil, fmt.Errorf("all expectations were already fulfilled, call to query '%s' with args %+v was not expected", query, args)
} }
t, ok := e.(*ExpectedQuery) defer expected.Unlock()
if !ok { expected.triggered = true
return nil, fmt.Errorf("call to query '%s' with args %+v, was not expected, next expectation is %T as %+v", query, args, e, e) // converts panic to error in case of reflect value type mismatch
defer func(errp *error, exp *ExpectedQuery, q string, a []driver.Value) {
if e := recover(); e != nil {
if se, ok := e.(*reflect.ValueError); ok { // catch reflect error, failed type conversion
msg := "query \"%s\", args \"%+v\" failed to match expected arguments \"%+v\", reason %s"
*errp = fmt.Errorf(msg, q, a, exp.args, se)
} else {
panic(e) // overwise if unknown error panic
}
}
}(&err, expected, query, args)
if !expected.queryMatches(query) {
return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, expected.sqlRegex.String())
} }
t.triggered = true if !expected.argsMatches(args) {
if t.err != nil { return nil, fmt.Errorf("query '%s', args %+v does not match expected %+v", query, args, expected.args)
return nil, t.err // mocked to return error
} }
if t.rows == nil { if expected.err != nil {
return nil, fmt.Errorf("query '%s' with args %+v, must return a database/sql/driver.rows, but it was not set for expectation %T as %+v", query, args, t, t) return nil, expected.err // mocked to return error
} }
defer argMatcherErrorHandler(&err) // converts panic to error in case of reflect value type mismatch if expected.rows == nil {
return nil, fmt.Errorf("query '%s' with args %+v, must return a database/sql/driver.rows, but it was not set for expectation %T as %+v", query, args, expected, expected)
if !t.queryMatches(query) {
return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, t.sqlRegex.String())
} }
if !t.argsMatches(args) { return expected.rows, err
return nil, fmt.Errorf("query '%s', args %+v does not match expected %+v", query, args, t.args)
}
return t.rows, err
} }
// ExpectQuery expects Query() or QueryRow() to be called with sql query // ExpectQuery expects Query() or QueryRow() to be called with sql query
@ -246,40 +342,58 @@ func (c *Sqlmock) ExpectRollback() *ExpectedRollback {
// Commit meets http://golang.org/pkg/database/sql/driver/#Tx // Commit meets http://golang.org/pkg/database/sql/driver/#Tx
func (c *Sqlmock) Commit() error { func (c *Sqlmock) Commit() error {
e := c.next() var expected *ExpectedCommit
if e == nil { var ok bool
for _, next := range c.expected {
next.Lock()
if next.fulfilled() {
next.Unlock()
continue
}
if expected, ok = next.(*ExpectedCommit); ok {
break
}
next.Unlock()
if c.MatchExpectationsInOrder {
return fmt.Errorf("call to commit transaction, was not expected, next expectation is %T as %+v", next, next)
}
}
if expected == nil {
return fmt.Errorf("all expectations were already fulfilled, call to commit transaction was not expected") return fmt.Errorf("all expectations were already fulfilled, call to commit transaction was not expected")
} }
t, ok := e.(*ExpectedCommit) expected.triggered = true
if !ok { expected.Unlock()
return fmt.Errorf("call to commit transaction, was not expected, next expectation was %v", e) return expected.err
}
t.triggered = true
return t.err
} }
// Rollback meets http://golang.org/pkg/database/sql/driver/#Tx // Rollback meets http://golang.org/pkg/database/sql/driver/#Tx
func (c *Sqlmock) Rollback() error { func (c *Sqlmock) Rollback() error {
e := c.next() var expected *ExpectedRollback
if e == nil { var ok bool
for _, next := range c.expected {
next.Lock()
if next.fulfilled() {
next.Unlock()
continue
}
if expected, ok = next.(*ExpectedRollback); ok {
break
}
next.Unlock()
if c.MatchExpectationsInOrder {
return fmt.Errorf("call to rollback transaction, was not expected, next expectation is %T as %+v", next, next)
}
}
if expected == nil {
return fmt.Errorf("all expectations were already fulfilled, call to rollback transaction was not expected") return fmt.Errorf("all expectations were already fulfilled, call to rollback transaction was not expected")
} }
t, ok := e.(*ExpectedRollback) expected.triggered = true
if !ok { expected.Unlock()
return fmt.Errorf("call to rollback transaction, was not expected, next expectation was %v", e) return expected.err
}
t.triggered = true
return t.err
}
func argMatcherErrorHandler(errp *error) {
if e := recover(); e != nil {
if se, ok := e.(*reflect.ValueError); ok { // catch reflect error, failed type conversion
*errp = fmt.Errorf("Failed to compare query arguments: %s", se)
} else {
panic(e) // overwise panic
}
}
} }

View File

@ -3,6 +3,7 @@ package sqlmock
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"sync"
"testing" "testing"
"time" "time"
) )
@ -575,3 +576,44 @@ func TestArgumentReflectValueTypeError(t *testing.T) {
t.Error("Expected error, but got none") t.Error("Expected error, but got none")
} }
} }
func TestGoroutineExecutionWithUnorderedExpectationMatching(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
// note this line is important for unordered expectation matching
mock.MatchExpectationsInOrder = false
result := NewResult(1, 1)
mock.ExpectExec("^UPDATE one").WithArgs("one").WillReturnResult(result)
mock.ExpectExec("^UPDATE two").WithArgs("one", "two").WillReturnResult(result)
mock.ExpectExec("^UPDATE three").WithArgs("one", "two", "three").WillReturnResult(result)
var wg sync.WaitGroup
queries := map[string][]interface{}{
"one": []interface{}{"one"},
"two": []interface{}{"one", "two"},
"three": []interface{}{"one", "two", "three"},
}
wg.Add(len(queries))
for table, args := range queries {
go func(tbl string, a []interface{}) {
if _, err := db.Exec("UPDATE "+tbl, a...); err != nil {
t.Errorf("error was not expected: %s", err)
}
wg.Done()
}(table, args)
}
wg.Wait()
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}