diff --git a/README.md b/README.md index eee8aa9..58d4c63 100644 --- a/README.md +++ b/README.md @@ -315,6 +315,13 @@ rs := sqlmock.NewRows([]string{"column1", "column2"}). AddRow("three", 3) ``` +**Prepare** will ignore other expectations if ExpectPrepare not set. When set, can expect normal result or simulate an error: + +``` go +rs := sqlmock.ExpectPrepare(). + WillReturnError(fmt.Errorf("Query prepare failed")) +``` + ## Run tests go test diff --git a/connection.go b/connection.go index 171fdbe..493cf42 100644 --- a/connection.go +++ b/connection.go @@ -85,6 +85,22 @@ func (c *conn) Exec(query string, args []driver.Value) (res driver.Result, err e } func (c *conn) Prepare(query string) (driver.Stmt, error) { + e := c.next() + + // for backwards compatibility, ignore when Prepare not expected + if e == nil { + return &statement{mock.conn, stripQuery(query)}, nil + } + eq, ok := e.(*expectedPrepare) + if !ok { + return &statement{mock.conn, stripQuery(query)}, nil + } + + eq.triggered = true + if eq.err != nil { + return nil, eq.err // mocked to return error + } + return &statement{mock.conn, stripQuery(query)}, nil } diff --git a/expectations.go b/expectations.go index 360338b..d778afd 100644 --- a/expectations.go +++ b/expectations.go @@ -117,3 +117,10 @@ type expectedExec struct { result driver.Result } + +// Prepare expectation +type expectedPrepare struct { + commonExpectation + + statement driver.Stmt +} diff --git a/sqlmock.go b/sqlmock.go index 2386363..a1698d0 100644 --- a/sqlmock.go +++ b/sqlmock.go @@ -122,6 +122,14 @@ func ExpectRollback() Mock { return mock.conn } +// ExpectPrepare expects Query to be prepared +func ExpectPrepare() Mock { + e := &expectedPrepare{} + mock.conn.expectations = append(mock.conn.expectations, e) + mock.conn.active = e + return mock.conn +} + // WillReturnError the expectation will return an error func (c *conn) WillReturnError(err error) Mock { c.active.setError(err) diff --git a/sqlmock_test.go b/sqlmock_test.go index ff75017..c956c45 100644 --- a/sqlmock_test.go +++ b/sqlmock_test.go @@ -163,6 +163,68 @@ func TestTransactionExpectations(t *testing.T) { } } +func TestPrepareExpectations(t *testing.T) { + db, err := sql.Open("mock", "") + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + + // no expectations, w/o ExpectPrepare() + stmt, err := db.Prepare("SELECT (.+) FROM articles WHERE id = ?") + if err != nil { + t.Errorf("error '%s' was not expected while creating a prepared statement", err) + } + if stmt == nil { + t.Errorf("stmt was expected while creating a prepared statement") + } + + // expect something else, w/o ExpectPrepare() + var id int + var title string + rs := NewRows([]string{"id", "title"}).FromCSVString("5,hello world") + + ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). + WithArgs(5). + WillReturnRows(rs) + + stmt, err = db.Prepare("SELECT (.+) FROM articles WHERE id = ?") + if err != nil { + t.Errorf("error '%s' was not expected while creating a prepared statement", err) + } + if stmt == nil { + t.Errorf("stmt was expected while creating a prepared statement") + } + + err = stmt.QueryRow(5).Scan(&id, &title) + if err != nil { + t.Errorf("error '%s' was not expected while retrieving mock rows", err) + } + + // expect normal result + ExpectPrepare() + stmt, err = db.Prepare("SELECT (.+) FROM articles WHERE id = ?") + if err != nil { + t.Errorf("error '%s' was not expected while creating a prepared statement", err) + } + if stmt == nil { + t.Errorf("stmt was expected while creating a prepared statement") + } + + // expect error result + ExpectPrepare().WillReturnError(fmt.Errorf("Some DB error occurred")) + stmt, err = db.Prepare("SELECT (.+) FROM articles WHERE id = ?") + if err == nil { + t.Error("error was expected while creating a prepared statement") + } + if stmt != nil { + t.Errorf("stmt was not expected while creating a prepared statement returning error") + } + + if err = db.Close(); err != nil { + t.Errorf("error '%s' was not expected while closing the database", err) + } +} + func TestPreparedQueryExecutions(t *testing.T) { db, err := sql.Open("mock", "") if err != nil {