diff --git a/README.md b/README.md index ef401fe..16a825f 100644 --- a/README.md +++ b/README.md @@ -190,6 +190,8 @@ It only asserts that argument is of `time.Time` type. ## Change Log +- **2017-09-01** - it is now possible to expect that prepared statement will be closed, + using **ExpectedPrepare.WillBeClosed**. - **2017-02-09** - implemented support for **go1.8** features. **Rows** interface was changed to struct but contains all methods as before and should maintain backwards compatibility. **ExpectedQuery.WillReturnRows** may now accept multiple row sets. diff --git a/expectations.go b/expectations.go index c00e6c6..6ff9a65 100644 --- a/expectations.go +++ b/expectations.go @@ -252,11 +252,13 @@ func (e *ExpectedExec) WillReturnResult(result driver.Result) *ExpectedExec { // Returned by *Sqlmock.ExpectPrepare. type ExpectedPrepare struct { commonExpectation - mock *sqlmock - sqlRegex *regexp.Regexp - statement driver.Stmt - closeErr error - delay time.Duration + mock *sqlmock + sqlRegex *regexp.Regexp + statement driver.Stmt + closeErr error + mustBeClosed bool + wasClosed bool + delay time.Duration } // WillReturnError allows to set an error for the expected *sql.DB.Prepare or *sql.Tx.Prepare action. @@ -278,6 +280,13 @@ func (e *ExpectedPrepare) WillDelayFor(duration time.Duration) *ExpectedPrepare return e } +// WillBeClosed expects this prepared statement to +// be closed. +func (e *ExpectedPrepare) WillBeClosed() *ExpectedPrepare { + e.mustBeClosed = true + return e +} + // ExpectQuery allows to expect Query() or QueryRow() on this prepared statement. // this method is convenient in order to prevent duplicating sql query string matching. func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery { diff --git a/sqlmock.go b/sqlmock.go index 0f1572f..fa7f624 100644 --- a/sqlmock.go +++ b/sqlmock.go @@ -154,6 +154,13 @@ func (c *sqlmock) ExpectationsWereMet() error { if !e.fulfilled() { return fmt.Errorf("there is a remaining expectation which was not matched: %s", e) } + + // for expected prepared statement check whether it was closed if expected + if prep, ok := e.(*ExpectedPrepare); ok { + if prep.mustBeClosed && !prep.wasClosed { + return fmt.Errorf("expected prepared statement to be closed, but it was not: %s", prep) + } + } } return nil } @@ -302,7 +309,7 @@ func (c *sqlmock) Prepare(query string) (driver.Stmt, error) { } time.Sleep(ex.delay) - return &statement{c, query, ex.closeErr}, nil + return &statement{c, ex, query}, nil } func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) { diff --git a/sqlmock_test.go b/sqlmock_test.go index fa2934d..9c48d3d 100644 --- a/sqlmock_test.go +++ b/sqlmock_test.go @@ -1033,3 +1033,33 @@ func TestExpectedBeginOrder(t *testing.T) { t.Error("an error was expected when calling close, but got none") } } + +func TestPreparedStatementCloseExpectation(t *testing.T) { + // Open new mock database + db, mock, err := New() + if err != nil { + fmt.Println("error creating mock database") + return + } + defer db.Close() + + ep := mock.ExpectPrepare("INSERT INTO ORDERS").WillBeClosed() + ep.ExpectExec().WillReturnResult(NewResult(1, 1)) + + stmt, err := db.Prepare("INSERT INTO ORDERS(ID, STATUS) VALUES (?, ?)") + if err != nil { + t.Fatal(err) + } + + if _, err := stmt.Exec(1, "Hello"); err != nil { + t.Fatal(err) + } + + if err := stmt.Close(); err != nil { + t.Fatal(err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +} diff --git a/statement.go b/statement.go index df73740..570efd9 100644 --- a/statement.go +++ b/statement.go @@ -6,12 +6,13 @@ import ( type statement struct { conn *sqlmock + ex *ExpectedPrepare query string - err error } func (stmt *statement) Close() error { - return stmt.err + stmt.ex.wasClosed = true + return stmt.ex.closeErr } func (stmt *statement) NumInput() int {