diff --git a/driver.go b/driver.go index 105a2af..050aeef 100644 --- a/driver.go +++ b/driver.go @@ -44,7 +44,7 @@ func New() (db *sql.DB, mock *Sqlmock, err error) { dsn := fmt.Sprintf("sqlmock_db_%d", pool.counter) pool.counter++ - mock = &Sqlmock{dsn: dsn, drv: pool} + mock = &Sqlmock{dsn: dsn, drv: pool, MatchExpectationsInOrder: true} pool.conns[dsn] = mock pool.Unlock() diff --git a/expectations.go b/expectations.go index 909c499..e4f2a2c 100644 --- a/expectations.go +++ b/expectations.go @@ -4,6 +4,7 @@ import ( "database/sql/driver" "reflect" "regexp" + "sync" ) // Argument interface allows to match @@ -16,11 +17,14 @@ type Argument interface { // an expectation interface type expectation interface { fulfilled() bool + Lock() + Unlock() } // common expectation struct // satisfies the expectation interface type commonExpectation struct { + sync.Mutex triggered bool err error } @@ -184,6 +188,19 @@ type queryBasedExpectation struct { args []driver.Value } +func (e *queryBasedExpectation) attemptMatch(sql string, args []driver.Value) (ret bool) { + if !e.queryMatches(sql) { + return + } + + defer recover() // ignore panic since we attempt a match + + if e.argsMatches(args) { + return true + } + return +} + func (e *queryBasedExpectation) queryMatches(sql string) bool { return e.sqlRegex.MatchString(sql) } diff --git a/sqlmock.go b/sqlmock.go index 4a44673..8490590 100644 --- a/sqlmock.go +++ b/sqlmock.go @@ -22,6 +22,15 @@ import ( // create expectations for any kind of database action // in order to mock and test real database behavior. type Sqlmock struct { + + // MatchExpectationsInOrder gives an option whether to match all + // expectations in the order they were set or not. + // + // By default it is set to - true. But if you use goroutines + // to parallelize your query executation, that option may + // be handy. + MatchExpectationsInOrder bool + dsn string opened int drv *mockDriver @@ -29,15 +38,6 @@ type Sqlmock struct { expected []expectation } -func (c *Sqlmock) next() (e expectation) { - for _, e = range c.expected { - if !e.fulfilled() { - return - } - } - return nil // all expectations were fulfilled -} - // ExpectClose queues an expectation for this database // action to be triggered. the *ExpectedClose allows // to mock database response @@ -59,17 +59,32 @@ func (c *Sqlmock) Close() error { if c.opened == 0 { delete(c.drv.conns, c.dsn) } - e := c.next() - if e == nil { + + var expected *ExpectedClose + var ok bool + for _, next := range c.expected { + next.Lock() + if next.fulfilled() { + next.Unlock() + continue + } + + if expected, ok = next.(*ExpectedClose); ok { + break + } + + next.Unlock() + if c.MatchExpectationsInOrder { + return fmt.Errorf("call to database Close, was not expected, next expectation is %T as %+v", next, next) + } + } + if expected == nil { return fmt.Errorf("all expectations were already fulfilled, call to database Close was not expected") } - t, ok := e.(*ExpectedClose) - if !ok { - return fmt.Errorf("call to database Close, was not expected, next expectation is %T as %+v", e, e) - } - t.triggered = true - return t.err + expected.triggered = true + expected.Unlock() + return expected.err } // ExpectationsWereMet checks whether all queued expectations @@ -85,17 +100,31 @@ func (c *Sqlmock) ExpectationsWereMet() error { // Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface func (c *Sqlmock) Begin() (driver.Tx, error) { - e := c.next() - if e == nil { + var expected *ExpectedBegin + var ok bool + for _, next := range c.expected { + next.Lock() + if next.fulfilled() { + next.Unlock() + continue + } + + if expected, ok = next.(*ExpectedBegin); ok { + break + } + + next.Unlock() + if c.MatchExpectationsInOrder { + return nil, fmt.Errorf("call to begin transaction, was not expected, next expectation is %T as %+v", next, next) + } + } + if expected == nil { return nil, fmt.Errorf("all expectations were already fulfilled, call to begin transaction was not expected") } - t, ok := e.(*ExpectedBegin) - if !ok { - return nil, fmt.Errorf("call to begin transaction, was not expected, next expectation is %T as %+v", e, e) - } - t.triggered = true - return c, t.err + expected.triggered = true + expected.Unlock() + return c, expected.err } // ExpectBegin expects *sql.DB.Begin to be called. @@ -108,37 +137,65 @@ func (c *Sqlmock) ExpectBegin() *ExpectedBegin { // Exec meets http://golang.org/pkg/database/sql/driver/#Execer func (c *Sqlmock) Exec(query string, args []driver.Value) (res driver.Result, err error) { - e := c.next() query = stripQuery(query) - if e == nil { + var expected *ExpectedExec + var ok bool + for _, next := range c.expected { + next.Lock() + if next.fulfilled() { + next.Unlock() + continue + } + + if c.MatchExpectationsInOrder { + if expected, ok = next.(*ExpectedExec); ok { + break + } + next.Unlock() + return nil, fmt.Errorf("call to exec query '%s' with args %+v, was not expected, next expectation is %T as %+v", query, args, next, next) + } + if exec, ok := next.(*ExpectedExec); ok { + if exec.attemptMatch(query, args) { + expected = exec + break + } + } + next.Unlock() + } + if expected == nil { return nil, fmt.Errorf("all expectations were already fulfilled, call to exec '%s' query with args %+v was not expected", query, args) } - t, ok := e.(*ExpectedExec) - if !ok { - return nil, fmt.Errorf("call to exec query '%s' with args %+v, was not expected, next expectation is %T as %+v", query, args, e, e) + defer expected.Unlock() + expected.triggered = true + // converts panic to error in case of reflect value type mismatch + defer func(errp *error, exp *ExpectedExec, q string, a []driver.Value) { + if e := recover(); e != nil { + if se, ok := e.(*reflect.ValueError); ok { // catch reflect error, failed type conversion + msg := "exec query \"%s\", args \"%+v\" failed to match expected arguments \"%+v\", reason %s" + *errp = fmt.Errorf(msg, q, a, exp.args, se) + } else { + panic(e) // overwise if unknown error panic + } + } + }(&err, expected, query, args) + + if !expected.queryMatches(query) { + return nil, fmt.Errorf("exec query '%s', does not match regex '%s'", query, expected.sqlRegex.String()) } - t.triggered = true - if t.err != nil { - return nil, t.err // mocked to return error + if !expected.argsMatches(args) { + return nil, fmt.Errorf("exec query '%s', args %+v does not match expected %+v", query, args, expected.args) } - if t.result == nil { - return nil, fmt.Errorf("exec query '%s' with args %+v, must return a database/sql/driver.result, but it was not set for expectation %T as %+v", query, args, t, t) + if expected.err != nil { + return nil, expected.err // mocked to return error } - defer argMatcherErrorHandler(&err) // converts panic to error in case of reflect value type mismatch - - if !t.queryMatches(query) { - return nil, fmt.Errorf("exec query '%s', does not match regex '%s'", query, t.sqlRegex.String()) + if expected.result == nil { + return nil, fmt.Errorf("exec query '%s' with args %+v, must return a database/sql/driver.result, but it was not set for expectation %T as %+v", query, args, expected, expected) } - - if !t.argsMatches(args) { - return nil, fmt.Errorf("exec query '%s', args %+v does not match expected %+v", query, args, t.args) - } - - return t.result, err + return expected.result, err } // ExpectExec expects Exec() to be called with sql query @@ -153,23 +210,33 @@ func (c *Sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec { // Prepare meets http://golang.org/pkg/database/sql/driver/#Conn interface func (c *Sqlmock) Prepare(query string) (driver.Stmt, error) { - e := c.next() + var expected *ExpectedPrepare + var ok bool + for _, next := range c.expected { + next.Lock() + if next.fulfilled() { + next.Unlock() + continue + } + + if expected, ok = next.(*ExpectedPrepare); ok { + break + } + + next.Unlock() + if c.MatchExpectationsInOrder { + return nil, fmt.Errorf("call to Prepare stetement with query '%s', was not expected, next expectation is %T as %+v", query, next, next) + } + } query = stripQuery(query) - if e == nil { + if expected == nil { return nil, fmt.Errorf("all expectations were already fulfilled, call to Prepare '%s' query was not expected", query) } - t, ok := e.(*ExpectedPrepare) - if !ok { - return nil, fmt.Errorf("call to Prepare stetement with query '%s', was not expected, next expectation is %T as %+v", query, e, e) - } - t.triggered = true - if t.err != nil { - return nil, t.err // mocked to return error - } - - return &statement{c, query, t.closeErr}, nil + expected.triggered = true + expected.Unlock() + return &statement{c, query, expected.closeErr}, expected.err } // ExpectPrepare expects Prepare() to be called with sql query @@ -185,37 +252,66 @@ func (c *Sqlmock) ExpectPrepare(sqlRegexStr string) *ExpectedPrepare { // Query meets http://golang.org/pkg/database/sql/driver/#Queryer func (c *Sqlmock) Query(query string, args []driver.Value) (rw driver.Rows, err error) { - e := c.next() query = stripQuery(query) - if e == nil { + var expected *ExpectedQuery + var ok bool + for _, next := range c.expected { + next.Lock() + if next.fulfilled() { + next.Unlock() + continue + } + + if c.MatchExpectationsInOrder { + if expected, ok = next.(*ExpectedQuery); ok { + break + } + next.Unlock() + return nil, fmt.Errorf("call to query '%s' with args %+v, was not expected, next expectation is %T as %+v", query, args, next, next) + } + if qr, ok := next.(*ExpectedQuery); ok { + if qr.attemptMatch(query, args) { + expected = qr + break + } + } + next.Unlock() + } + if expected == nil { return nil, fmt.Errorf("all expectations were already fulfilled, call to query '%s' with args %+v was not expected", query, args) } - t, ok := e.(*ExpectedQuery) - if !ok { - return nil, fmt.Errorf("call to query '%s' with args %+v, was not expected, next expectation is %T as %+v", query, args, e, e) + defer expected.Unlock() + expected.triggered = true + // converts panic to error in case of reflect value type mismatch + defer func(errp *error, exp *ExpectedQuery, q string, a []driver.Value) { + if e := recover(); e != nil { + if se, ok := e.(*reflect.ValueError); ok { // catch reflect error, failed type conversion + msg := "query \"%s\", args \"%+v\" failed to match expected arguments \"%+v\", reason %s" + *errp = fmt.Errorf(msg, q, a, exp.args, se) + } else { + panic(e) // overwise if unknown error panic + } + } + }(&err, expected, query, args) + + if !expected.queryMatches(query) { + return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, expected.sqlRegex.String()) } - t.triggered = true - if t.err != nil { - return nil, t.err // mocked to return error + if !expected.argsMatches(args) { + return nil, fmt.Errorf("query '%s', args %+v does not match expected %+v", query, args, expected.args) } - if t.rows == nil { - return nil, fmt.Errorf("query '%s' with args %+v, must return a database/sql/driver.rows, but it was not set for expectation %T as %+v", query, args, t, t) + if expected.err != nil { + return nil, expected.err // mocked to return error } - defer argMatcherErrorHandler(&err) // converts panic to error in case of reflect value type mismatch - - if !t.queryMatches(query) { - return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, t.sqlRegex.String()) + if expected.rows == nil { + return nil, fmt.Errorf("query '%s' with args %+v, must return a database/sql/driver.rows, but it was not set for expectation %T as %+v", query, args, expected, expected) } - if !t.argsMatches(args) { - return nil, fmt.Errorf("query '%s', args %+v does not match expected %+v", query, args, t.args) - } - - return t.rows, err + return expected.rows, err } // ExpectQuery expects Query() or QueryRow() to be called with sql query @@ -246,40 +342,58 @@ func (c *Sqlmock) ExpectRollback() *ExpectedRollback { // Commit meets http://golang.org/pkg/database/sql/driver/#Tx func (c *Sqlmock) Commit() error { - e := c.next() - if e == nil { + var expected *ExpectedCommit + var ok bool + for _, next := range c.expected { + next.Lock() + if next.fulfilled() { + next.Unlock() + continue + } + + if expected, ok = next.(*ExpectedCommit); ok { + break + } + + next.Unlock() + if c.MatchExpectationsInOrder { + return fmt.Errorf("call to commit transaction, was not expected, next expectation is %T as %+v", next, next) + } + } + if expected == nil { return fmt.Errorf("all expectations were already fulfilled, call to commit transaction was not expected") } - t, ok := e.(*ExpectedCommit) - if !ok { - return fmt.Errorf("call to commit transaction, was not expected, next expectation was %v", e) - } - t.triggered = true - return t.err + expected.triggered = true + expected.Unlock() + return expected.err } // Rollback meets http://golang.org/pkg/database/sql/driver/#Tx func (c *Sqlmock) Rollback() error { - e := c.next() - if e == nil { + var expected *ExpectedRollback + var ok bool + for _, next := range c.expected { + next.Lock() + if next.fulfilled() { + next.Unlock() + continue + } + + if expected, ok = next.(*ExpectedRollback); ok { + break + } + + next.Unlock() + if c.MatchExpectationsInOrder { + return fmt.Errorf("call to rollback transaction, was not expected, next expectation is %T as %+v", next, next) + } + } + if expected == nil { return fmt.Errorf("all expectations were already fulfilled, call to rollback transaction was not expected") } - t, ok := e.(*ExpectedRollback) - if !ok { - return fmt.Errorf("call to rollback transaction, was not expected, next expectation was %v", e) - } - t.triggered = true - return t.err -} - -func argMatcherErrorHandler(errp *error) { - if e := recover(); e != nil { - if se, ok := e.(*reflect.ValueError); ok { // catch reflect error, failed type conversion - *errp = fmt.Errorf("Failed to compare query arguments: %s", se) - } else { - panic(e) // overwise panic - } - } + expected.triggered = true + expected.Unlock() + return expected.err } diff --git a/sqlmock_test.go b/sqlmock_test.go index e36a225..deb37ff 100644 --- a/sqlmock_test.go +++ b/sqlmock_test.go @@ -3,6 +3,7 @@ package sqlmock import ( "database/sql" "fmt" + "sync" "testing" "time" ) @@ -575,3 +576,44 @@ func TestArgumentReflectValueTypeError(t *testing.T) { t.Error("Expected error, but got none") } } + +func TestGoroutineExecutionWithUnorderedExpectationMatching(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + // note this line is important for unordered expectation matching + mock.MatchExpectationsInOrder = false + + result := NewResult(1, 1) + + mock.ExpectExec("^UPDATE one").WithArgs("one").WillReturnResult(result) + mock.ExpectExec("^UPDATE two").WithArgs("one", "two").WillReturnResult(result) + mock.ExpectExec("^UPDATE three").WithArgs("one", "two", "three").WillReturnResult(result) + + var wg sync.WaitGroup + queries := map[string][]interface{}{ + "one": []interface{}{"one"}, + "two": []interface{}{"one", "two"}, + "three": []interface{}{"one", "two", "three"}, + } + + wg.Add(len(queries)) + for table, args := range queries { + go func(tbl string, a []interface{}) { + if _, err := db.Exec("UPDATE "+tbl, a...); err != nil { + t.Errorf("error was not expected: %s", err) + } + wg.Done() + }(table, args) + } + + wg.Wait() + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +}