1
0
mirror of https://github.com/DATA-DOG/go-sqlmock.git synced 2025-04-04 21:54:20 +02:00

Adding feature to allow repeatable expectations

This commit is contained in:
Michael MacDonald 2017-06-06 11:17:58 -04:00 committed by gedi
parent d76b18b42f
commit fbf1c7c325
No known key found for this signature in database
GPG Key ID: 56604CDCCC201556
2 changed files with 105 additions and 14 deletions

View File

@ -73,10 +73,23 @@ type Sqlmock interface {
// in any order. Or otherwise if switched to true, any unmatched // in any order. Or otherwise if switched to true, any unmatched
// expectations will be expected in order // expectations will be expected in order
MatchExpectationsInOrder(bool) MatchExpectationsInOrder(bool)
// AllowRepeatedExpectationMatching gives an option whether or not to
// allow expectations to be matched more than once.
//
// By default it is set to - false.
//
// This option may be turned on anytime during tests. As soon
// as it is switched to true, expectations will be allowed to match
// regardless of it has been previously matched against.
//
// When setting this true, consider if you will need to set MatchExpectationsInOrder(false)
AllowRepeatedExpectationMatching(bool)
} }
type sqlmock struct { type sqlmock struct {
ordered bool ordered bool
repeatable bool
dsn string dsn string
opened int opened int
drv *mockDriver drv *mockDriver
@ -102,6 +115,10 @@ func (c *sqlmock) MatchExpectationsInOrder(b bool) {
c.ordered = b c.ordered = b
} }
func (c *sqlmock) AllowRepeatedExpectationMatching(b bool) {
c.repeatable = b
}
// Close a mock database driver connection. It may or may not // Close a mock database driver connection. It may or may not
// be called depending on the sircumstances, but if it is called // be called depending on the sircumstances, but if it is called
// there must be an *ExpectedClose expectation satisfied. // there must be an *ExpectedClose expectation satisfied.
@ -121,10 +138,12 @@ func (c *sqlmock) Close() error {
for _, next := range c.expected { for _, next := range c.expected {
next.Lock() next.Lock()
if next.fulfilled() { if next.fulfilled() {
next.Unlock()
fulfilled++ fulfilled++
if !c.repeatable {
next.Unlock()
continue continue
} }
}
if expected, ok = next.(*ExpectedClose); ok { if expected, ok = next.(*ExpectedClose); ok {
break break
@ -185,8 +204,10 @@ func (c *sqlmock) begin() (*ExpectedBegin, error) {
if next.fulfilled() { if next.fulfilled() {
next.Unlock() next.Unlock()
fulfilled++ fulfilled++
if !c.repeatable {
continue continue
} }
}
if expected, ok = next.(*ExpectedBegin); ok { if expected, ok = next.(*ExpectedBegin); ok {
break break
@ -246,8 +267,10 @@ func (c *sqlmock) exec(query string, args []namedValue) (*ExpectedExec, error) {
if next.fulfilled() { if next.fulfilled() {
next.Unlock() next.Unlock()
fulfilled++ fulfilled++
if !c.repeatable {
continue continue
} }
}
if c.ordered { if c.ordered {
if expected, ok = next.(*ExpectedExec); ok { if expected, ok = next.(*ExpectedExec); ok {
@ -322,10 +345,13 @@ func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) {
for _, next := range c.expected { for _, next := range c.expected {
next.Lock() next.Lock()
if next.fulfilled() { if next.fulfilled() {
next.Unlock()
fulfilled++ fulfilled++
if !c.repeatable {
next.Unlock()
continue continue
} }
}
if c.ordered { if c.ordered {
if expected, ok = next.(*ExpectedPrepare); ok { if expected, ok = next.(*ExpectedPrepare); ok {
@ -401,10 +427,12 @@ func (c *sqlmock) query(query string, args []namedValue) (*ExpectedQuery, error)
for _, next := range c.expected { for _, next := range c.expected {
next.Lock() next.Lock()
if next.fulfilled() { if next.fulfilled() {
next.Unlock()
fulfilled++ fulfilled++
if !c.repeatable {
next.Unlock()
continue continue
} }
}
if c.ordered { if c.ordered {
if expected, ok = next.(*ExpectedQuery); ok { if expected, ok = next.(*ExpectedQuery); ok {
@ -448,6 +476,14 @@ func (c *sqlmock) query(query string, args []namedValue) (*ExpectedQuery, error)
if expected.rows == nil { if expected.rows == nil {
return nil, fmt.Errorf("Query '%s' with args %+v, must return a database/sql/driver.Rows, but it was not set for expectation %T as %+v", query, args, expected, expected) return nil, fmt.Errorf("Query '%s' with args %+v, must return a database/sql/driver.Rows, but it was not set for expectation %T as %+v", query, args, expected, expected)
} }
// reset rows for next use if allowed
if rs, ok := expected.rows.(*rowSets); ok {
rs.pos = 0
for _, set := range rs.sets {
set.pos = 0
}
}
return expected, nil return expected, nil
} }
@ -481,8 +517,10 @@ func (c *sqlmock) Commit() error {
if next.fulfilled() { if next.fulfilled() {
next.Unlock() next.Unlock()
fulfilled++ fulfilled++
if !c.repeatable {
continue continue
} }
}
if expected, ok = next.(*ExpectedCommit); ok { if expected, ok = next.(*ExpectedCommit); ok {
break break
@ -516,8 +554,10 @@ func (c *sqlmock) Rollback() error {
if next.fulfilled() { if next.fulfilled() {
next.Unlock() next.Unlock()
fulfilled++ fulfilled++
if !c.repeatable {
continue continue
} }
}
if expected, ok = next.(*ExpectedRollback); ok { if expected, ok = next.(*ExpectedRollback); ok {
break break

View File

@ -354,6 +354,57 @@ func TestPreparedQueryExecutions(t *testing.T) {
} }
} }
func TestPreparedQueryMultipleExecutions(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
mock.MatchExpectationsInOrder(false)
mock.AllowRepeatedExpectationMatching(true)
rs := NewRows([]string{"id", "title"}).FromCSVString("5,hello world")
mock.ExpectPrepare("SELECT (.+) FROM articles WHERE id = ?").ExpectQuery().
WithArgs(5).
WillReturnRows(rs)
stmt, err := db.Prepare("SELECT id, title FROM articles WHERE id = ?")
if err != nil {
t.Errorf("error '%s' was not expected while creating a prepared statement", err)
}
stmt2, err := db.Prepare("SELECT id, title FROM articles WHERE id = ?")
if err != nil {
t.Errorf("error '%s' was not expected while creating a prepared statement", err)
}
var id1, id2 int
var title1, title2 string
err = stmt.QueryRow(5).Scan(&id1, &title1)
if err != nil {
t.Errorf("error '%s' was not expected querying row from statement and scanning", err)
}
err = stmt2.QueryRow(5).Scan(&id2, &title2)
if err != nil {
t.Errorf("error '%s' was not expected querying row from statement and scanning", err)
}
if id1 != 5 || id2 != 5 {
t.Errorf("expected mocked id to be 5, but got %d instead", id2)
}
if title1 != "hello world" || title2 != "hello world" {
t.Errorf("expected mocked title to be 'hello world', but got '%s' instead", title2)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}
func TestUnorderedPreparedQueryExecutions(t *testing.T) { func TestUnorderedPreparedQueryExecutions(t *testing.T) {
t.Parallel() t.Parallel()
db, mock, err := New() db, mock, err := New()