diff --git a/connection.go b/connection.go index 171fdbe..493cf42 100644 --- a/connection.go +++ b/connection.go @@ -85,6 +85,22 @@ func (c *conn) Exec(query string, args []driver.Value) (res driver.Result, err e } func (c *conn) Prepare(query string) (driver.Stmt, error) { + e := c.next() + + // for backwards compatibility, ignore when Prepare not expected + if e == nil { + return &statement{mock.conn, stripQuery(query)}, nil + } + eq, ok := e.(*expectedPrepare) + if !ok { + return &statement{mock.conn, stripQuery(query)}, nil + } + + eq.triggered = true + if eq.err != nil { + return nil, eq.err // mocked to return error + } + return &statement{mock.conn, stripQuery(query)}, nil } diff --git a/expectations.go b/expectations.go index 360338b..d778afd 100644 --- a/expectations.go +++ b/expectations.go @@ -117,3 +117,10 @@ type expectedExec struct { result driver.Result } + +// Prepare expectation +type expectedPrepare struct { + commonExpectation + + statement driver.Stmt +} diff --git a/sqlmock.go b/sqlmock.go index 2386363..a1698d0 100644 --- a/sqlmock.go +++ b/sqlmock.go @@ -122,6 +122,14 @@ func ExpectRollback() Mock { return mock.conn } +// ExpectPrepare expects Query to be prepared +func ExpectPrepare() Mock { + e := &expectedPrepare{} + mock.conn.expectations = append(mock.conn.expectations, e) + mock.conn.active = e + return mock.conn +} + // WillReturnError the expectation will return an error func (c *conn) WillReturnError(err error) Mock { c.active.setError(err)