diff --git a/sqlmock.go b/sqlmock.go index e82a2a3..e3ed3e4 100644 --- a/sqlmock.go +++ b/sqlmock.go @@ -309,6 +309,9 @@ func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) { var expected *ExpectedPrepare var fulfilled int var ok bool + + query = stripQuery(query) + for _, next := range c.expected { next.Lock() if next.fulfilled() { @@ -317,17 +320,24 @@ func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) { continue } - if expected, ok = next.(*ExpectedPrepare); ok { - break - } - - next.Unlock() if c.ordered { + if expected, ok = next.(*ExpectedPrepare); ok { + break + } + + next.Unlock() return nil, fmt.Errorf("call to Prepare statement with query '%s', was not expected, next expectation is: %s", query, next) } + + if pr, ok := next.(*ExpectedPrepare); ok { + if pr.sqlRegex.MatchString(query) { + expected = pr + break + } + } + next.Unlock() } - query = stripQuery(query) if expected == nil { msg := "call to Prepare '%s' query was not expected" if fulfilled == len(c.expected) { diff --git a/sqlmock_test.go b/sqlmock_test.go index fe4924b..ffaa824 100644 --- a/sqlmock_test.go +++ b/sqlmock_test.go @@ -354,6 +354,48 @@ func TestPreparedQueryExecutions(t *testing.T) { } } + +func TestUnorderedPreparedQueryExecutions(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.MatchExpectationsInOrder(false) + + mock.ExpectPrepare("SELECT (.+) FROM articles WHERE id = ?"). + ExpectQuery(). + WithArgs(5). + WillReturnRows( + NewRows([]string{"id", "title"}).FromCSVString("5,The quick brown fox"), + ) + mock.ExpectPrepare("SELECT (.+) FROM authors WHERE id = ?"). + ExpectQuery(). + WithArgs(1). + WillReturnRows( + NewRows([]string{"id", "title"}).FromCSVString("1,Betty B."), + ) + + var id int + var name string + + stmt, err := db.Prepare("SELECT id, name FROM authors WHERE id = ?") + if err != nil { + t.Errorf("error '%s' was not expected while creating a prepared statement", err) + } + + err = stmt.QueryRow(1).Scan(&id, &name) + if err != nil { + t.Errorf("error '%s' was not expected querying row from statement and scanning", err) + } + + if name != "Betty B." { + t.Errorf("expected mocked name to be 'Betty B.', but got '%s' instead", name) + } +} + func TestUnexpectedOperations(t *testing.T) { t.Parallel() db, mock, err := New()