diff --git a/sqlmock.go b/sqlmock.go index 500b5c6..536fa13 100644 --- a/sqlmock.go +++ b/sqlmock.go @@ -155,6 +155,20 @@ func (c *sqlmock) ExpectationsWereMet() error { // Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface func (c *sqlmock) Begin() (driver.Tx, error) { + ex, err := c.beginExpectation() + if err != nil { + return nil, err + } + + return c.begin(ex) +} + +func (c *sqlmock) begin(expected *ExpectedBegin) (driver.Tx, error) { + defer time.Sleep(expected.delay) + return c, nil +} + +func (c *sqlmock) beginExpectation() (*ExpectedBegin, error) { var expected *ExpectedBegin var ok bool var fulfilled int @@ -185,8 +199,8 @@ func (c *sqlmock) Begin() (driver.Tx, error) { expected.triggered = true expected.Unlock() - defer time.Sleep(expected.delay) - return c, expected.err + + return expected, expected.err } func (c *sqlmock) ExpectBegin() *ExpectedBegin { @@ -204,10 +218,16 @@ func (c *sqlmock) Exec(query string, args []driver.Value) (driver.Result, error) Value: v, } } - return c.exec(nil, query, namedArgs) + + ex, err := c.execExpectation(query, namedArgs) + if err != nil { + return nil, err + } + + return c.exec(ex) } -func (c *sqlmock) exec(ctx interface{}, query string, args []namedValue) (res driver.Result, err error) { +func (c *sqlmock) execExpectation(query string, args []namedValue) (*ExpectedExec, error) { query = stripQuery(query) var expected *ExpectedExec var fulfilled int @@ -242,21 +262,17 @@ func (c *sqlmock) exec(ctx interface{}, query string, args []namedValue) (res dr } return nil, fmt.Errorf(msg, query, args) } + defer expected.Unlock() if !expected.queryMatches(query) { - expected.Unlock() return nil, fmt.Errorf("exec query '%s', does not match regex '%s'", query, expected.sqlRegex.String()) } if err := expected.argsMatches(args); err != nil { - expected.Unlock() return nil, fmt.Errorf("exec query '%s', arguments do not match: %s", query, err) } expected.triggered = true - defer time.Sleep(expected.delay) - defer expected.Unlock() - if expected.err != nil { return nil, expected.err // mocked to return error } @@ -265,7 +281,12 @@ func (c *sqlmock) exec(ctx interface{}, query string, args []namedValue) (res dr return nil, fmt.Errorf("exec query '%s' with args %+v, must return a database/sql/driver.result, but it was not set for expectation %T as %+v", query, args, expected, expected) } - return expected.result, err + return expected, nil +} + +func (c *sqlmock) exec(expected *ExpectedExec) (driver.Result, error) { + defer time.Sleep(expected.delay) + return expected.result, nil } func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec { @@ -278,6 +299,15 @@ func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec { // Prepare meets http://golang.org/pkg/database/sql/driver/#Conn interface func (c *sqlmock) Prepare(query string) (driver.Stmt, error) { + ex, err := c.prepareExpectation(query) + if err != nil { + return nil, err + } + + return c.prepare(ex, query) +} + +func (c *sqlmock) prepareExpectation(query string) (*ExpectedPrepare, error) { var expected *ExpectedPrepare var fulfilled int var ok bool @@ -307,15 +337,18 @@ func (c *sqlmock) Prepare(query string) (driver.Stmt, error) { } return nil, fmt.Errorf(msg, query) } + defer expected.Unlock() if !expected.sqlRegex.MatchString(query) { - expected.Unlock() return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, expected.sqlRegex.String()) } expected.triggered = true + return expected, expected.err +} + +func (c *sqlmock) prepare(expected *ExpectedPrepare, query string) (driver.Stmt, error) { defer time.Sleep(expected.delay) - defer expected.Unlock() - return &statement{c, query, expected.closeErr}, expected.err + return &statement{c, query, expected.closeErr}, nil } func (c *sqlmock) ExpectPrepare(sqlRegexStr string) *ExpectedPrepare { @@ -332,7 +365,7 @@ type namedValue struct { } // Query meets http://golang.org/pkg/database/sql/driver/#Queryer -func (c *sqlmock) Query(query string, args []driver.Value) (rw driver.Rows, err error) { +func (c *sqlmock) Query(query string, args []driver.Value) (driver.Rows, error) { namedArgs := make([]namedValue, len(args)) for i, v := range args { namedArgs[i] = namedValue{ @@ -340,12 +373,16 @@ func (c *sqlmock) Query(query string, args []driver.Value) (rw driver.Rows, err Value: v, } } - return c.query(nil, query, namedArgs) + + ex, err := c.queryExpectation(query, namedArgs) + if err != nil { + return nil, err + } + + return c.query(ex) } -// in order to prevent dependencies, we use Context as a plain interface -// since it is only related to internal implementation -func (c *sqlmock) query(ctx interface{}, query string, args []namedValue) (rw driver.Rows, err error) { +func (c *sqlmock) queryExpectation(query string, args []namedValue) (*ExpectedQuery, error) { query = stripQuery(query) var expected *ExpectedQuery var fulfilled int @@ -382,21 +419,17 @@ func (c *sqlmock) query(ctx interface{}, query string, args []namedValue) (rw dr return nil, fmt.Errorf(msg, query, args) } + defer expected.Unlock() + if !expected.queryMatches(query) { - expected.Unlock() return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, expected.sqlRegex.String()) } if err := expected.argsMatches(args); err != nil { - expected.Unlock() return nil, fmt.Errorf("exec query '%s', arguments do not match: %s", query, err) } expected.triggered = true - - defer time.Sleep(expected.delay) - defer expected.Unlock() - if expected.err != nil { return nil, expected.err // mocked to return error } @@ -404,8 +437,13 @@ func (c *sqlmock) query(ctx interface{}, query string, args []namedValue) (rw dr if expected.rows == nil { return nil, fmt.Errorf("query '%s' with args %+v, must return a database/sql/driver.rows, but it was not set for expectation %T as %+v", query, args, expected, expected) } + return expected, nil +} - return expected.rows, err +func (c *sqlmock) query(expected *ExpectedQuery) (driver.Rows, error) { + defer time.Sleep(expected.delay) + + return expected.rows, nil } func (c *sqlmock) ExpectQuery(sqlRegexStr string) *ExpectedQuery { diff --git a/sqlmock_go18.go b/sqlmock_go18.go index 2021007..0c82a3b 100644 --- a/sqlmock_go18.go +++ b/sqlmock_go18.go @@ -2,4 +2,142 @@ package sqlmock -// @TODO context based extensions +import ( + "context" + "database/sql/driver" + "fmt" +) + +var CancelledStatementErr = fmt.Errorf("canceling query due to user request") + +// Implement the "QueryerContext" interface +func (c *sqlmock) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + namedArgs := make([]namedValue, len(args)) + for i, nv := range args { + namedArgs[i] = namedValue(nv) + } + + ex, err := c.queryExpectation(query, namedArgs) + if err != nil { + return nil, err + } + + type result struct { + rows driver.Rows + err error + } + + exec := make(chan result) + defer func() { + close(exec) + }() + + go func() { + rows, err := c.query(ex) + exec <- result{rows, err} + }() + + select { + case res := <-exec: + return res.rows, res.err + case <-ctx.Done(): + return nil, CancelledStatementErr + } +} + +// Implement the "ExecerContext" interface +func (c *sqlmock) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + namedArgs := make([]namedValue, len(args)) + for i, nv := range args { + namedArgs[i] = namedValue(nv) + } + + ex, err := c.execExpectation(query, namedArgs) + if err != nil { + return nil, err + } + + type result struct { + rs driver.Result + err error + } + + exec := make(chan result) + defer func() { + close(exec) + }() + + go func() { + rs, err := c.exec(ex) + exec <- result{rs, err} + }() + + select { + case res := <-exec: + return res.rs, res.err + case <-ctx.Done(): + return nil, CancelledStatementErr + } +} + +// Implement the "ConnBeginTx" interface +func (c *sqlmock) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + ex, err := c.beginExpectation() + if err != nil { + return nil, err + } + + type result struct { + tx driver.Tx + err error + } + + exec := make(chan result) + defer func() { + close(exec) + }() + + go func() { + tx, err := c.begin(ex) + exec <- result{tx, err} + }() + + select { + case res := <-exec: + return res.tx, res.err + case <-ctx.Done(): + return nil, CancelledStatementErr + } +} + +// Implement the "ConnPrepareContext" interface +func (c *sqlmock) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + ex, err := c.prepareExpectation(query) + if err != nil { + return nil, err + } + + type result struct { + stmt driver.Stmt + err error + } + + exec := make(chan result) + defer func() { + close(exec) + }() + + go func() { + stmt, err := c.prepare(ex, query) + exec <- result{stmt, err} + }() + + select { + case res := <-exec: + return res.stmt, res.err + case <-ctx.Done(): + return nil, CancelledStatementErr + } +} + +// @TODO maybe add ExpectedBegin.WithOptions(driver.TxOptions) diff --git a/sqlmock_go18_test.go b/sqlmock_go18_test.go new file mode 100644 index 0000000..713ed0f --- /dev/null +++ b/sqlmock_go18_test.go @@ -0,0 +1,3 @@ +// +build go1.8 + +package sqlmock