From fbf1c7c325dc032c8ad35a69cf32159b0ec9f96b Mon Sep 17 00:00:00 2001 From: Michael MacDonald Date: Tue, 6 Jun 2017 11:17:58 -0400 Subject: [PATCH] Adding feature to allow repeatable expectations --- sqlmock.go | 68 +++++++++++++++++++++++++++++++++++++++---------- sqlmock_test.go | 51 +++++++++++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+), 14 deletions(-) diff --git a/sqlmock.go b/sqlmock.go index fa7f624..e013164 100644 --- a/sqlmock.go +++ b/sqlmock.go @@ -73,13 +73,26 @@ type Sqlmock interface { // in any order. Or otherwise if switched to true, any unmatched // expectations will be expected in order MatchExpectationsInOrder(bool) + + // AllowRepeatedExpectationMatching gives an option whether or not to + // allow expectations to be matched more than once. + // + // By default it is set to - false. + // + // This option may be turned on anytime during tests. As soon + // as it is switched to true, expectations will be allowed to match + // regardless of it has been previously matched against. + // + // When setting this true, consider if you will need to set MatchExpectationsInOrder(false) + AllowRepeatedExpectationMatching(bool) } type sqlmock struct { - ordered bool - dsn string - opened int - drv *mockDriver + ordered bool + repeatable bool + dsn string + opened int + drv *mockDriver expected []expectation } @@ -102,6 +115,10 @@ func (c *sqlmock) MatchExpectationsInOrder(b bool) { c.ordered = b } +func (c *sqlmock) AllowRepeatedExpectationMatching(b bool) { + c.repeatable = b +} + // Close a mock database driver connection. It may or may not // be called depending on the sircumstances, but if it is called // there must be an *ExpectedClose expectation satisfied. @@ -121,9 +138,11 @@ func (c *sqlmock) Close() error { for _, next := range c.expected { next.Lock() if next.fulfilled() { - next.Unlock() fulfilled++ - continue + if !c.repeatable { + next.Unlock() + continue + } } if expected, ok = next.(*ExpectedClose); ok { @@ -185,7 +204,9 @@ func (c *sqlmock) begin() (*ExpectedBegin, error) { if next.fulfilled() { next.Unlock() fulfilled++ - continue + if !c.repeatable { + continue + } } if expected, ok = next.(*ExpectedBegin); ok { @@ -246,7 +267,9 @@ func (c *sqlmock) exec(query string, args []namedValue) (*ExpectedExec, error) { if next.fulfilled() { next.Unlock() fulfilled++ - continue + if !c.repeatable { + continue + } } if c.ordered { @@ -322,9 +345,12 @@ func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) { for _, next := range c.expected { next.Lock() if next.fulfilled() { - next.Unlock() fulfilled++ - continue + + if !c.repeatable { + next.Unlock() + continue + } } if c.ordered { @@ -401,9 +427,11 @@ func (c *sqlmock) query(query string, args []namedValue) (*ExpectedQuery, error) for _, next := range c.expected { next.Lock() if next.fulfilled() { - next.Unlock() fulfilled++ - continue + if !c.repeatable { + next.Unlock() + continue + } } if c.ordered { @@ -448,6 +476,14 @@ func (c *sqlmock) query(query string, args []namedValue) (*ExpectedQuery, error) if expected.rows == nil { return nil, fmt.Errorf("Query '%s' with args %+v, must return a database/sql/driver.Rows, but it was not set for expectation %T as %+v", query, args, expected, expected) } + + // reset rows for next use if allowed + if rs, ok := expected.rows.(*rowSets); ok { + rs.pos = 0 + for _, set := range rs.sets { + set.pos = 0 + } + } return expected, nil } @@ -481,7 +517,9 @@ func (c *sqlmock) Commit() error { if next.fulfilled() { next.Unlock() fulfilled++ - continue + if !c.repeatable { + continue + } } if expected, ok = next.(*ExpectedCommit); ok { @@ -516,7 +554,9 @@ func (c *sqlmock) Rollback() error { if next.fulfilled() { next.Unlock() fulfilled++ - continue + if !c.repeatable { + continue + } } if expected, ok = next.(*ExpectedRollback); ok { diff --git a/sqlmock_test.go b/sqlmock_test.go index 9c48d3d..51ce90f 100644 --- a/sqlmock_test.go +++ b/sqlmock_test.go @@ -354,6 +354,57 @@ func TestPreparedQueryExecutions(t *testing.T) { } } +func TestPreparedQueryMultipleExecutions(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.MatchExpectationsInOrder(false) + mock.AllowRepeatedExpectationMatching(true) + + rs := NewRows([]string{"id", "title"}).FromCSVString("5,hello world") + mock.ExpectPrepare("SELECT (.+) FROM articles WHERE id = ?").ExpectQuery(). + WithArgs(5). + WillReturnRows(rs) + + stmt, err := db.Prepare("SELECT id, title FROM articles WHERE id = ?") + if err != nil { + t.Errorf("error '%s' was not expected while creating a prepared statement", err) + } + + stmt2, err := db.Prepare("SELECT id, title FROM articles WHERE id = ?") + if err != nil { + t.Errorf("error '%s' was not expected while creating a prepared statement", err) + } + + var id1, id2 int + var title1, title2 string + err = stmt.QueryRow(5).Scan(&id1, &title1) + if err != nil { + t.Errorf("error '%s' was not expected querying row from statement and scanning", err) + } + + err = stmt2.QueryRow(5).Scan(&id2, &title2) + if err != nil { + t.Errorf("error '%s' was not expected querying row from statement and scanning", err) + } + + if id1 != 5 || id2 != 5 { + t.Errorf("expected mocked id to be 5, but got %d instead", id2) + } + + if title1 != "hello world" || title2 != "hello world" { + t.Errorf("expected mocked title to be 'hello world', but got '%s' instead", title2) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +} + func TestUnorderedPreparedQueryExecutions(t *testing.T) { t.Parallel() db, mock, err := New()