1
0
mirror of https://github.com/DATA-DOG/go-sqlmock.git synced 2024-11-28 08:49:01 +02:00

implements Context based sql driver extensions

This commit is contained in:
gedi 2017-02-07 12:20:08 +02:00
parent d11f623794
commit 965003de80
3 changed files with 205 additions and 26 deletions

View File

@ -155,6 +155,20 @@ func (c *sqlmock) ExpectationsWereMet() error {
// Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface
func (c *sqlmock) Begin() (driver.Tx, error) {
ex, err := c.beginExpectation()
if err != nil {
return nil, err
}
return c.begin(ex)
}
func (c *sqlmock) begin(expected *ExpectedBegin) (driver.Tx, error) {
defer time.Sleep(expected.delay)
return c, nil
}
func (c *sqlmock) beginExpectation() (*ExpectedBegin, error) {
var expected *ExpectedBegin
var ok bool
var fulfilled int
@ -185,8 +199,8 @@ func (c *sqlmock) Begin() (driver.Tx, error) {
expected.triggered = true
expected.Unlock()
defer time.Sleep(expected.delay)
return c, expected.err
return expected, expected.err
}
func (c *sqlmock) ExpectBegin() *ExpectedBegin {
@ -204,10 +218,16 @@ func (c *sqlmock) Exec(query string, args []driver.Value) (driver.Result, error)
Value: v,
}
}
return c.exec(nil, query, namedArgs)
ex, err := c.execExpectation(query, namedArgs)
if err != nil {
return nil, err
}
return c.exec(ex)
}
func (c *sqlmock) exec(ctx interface{}, query string, args []namedValue) (res driver.Result, err error) {
func (c *sqlmock) execExpectation(query string, args []namedValue) (*ExpectedExec, error) {
query = stripQuery(query)
var expected *ExpectedExec
var fulfilled int
@ -242,21 +262,17 @@ func (c *sqlmock) exec(ctx interface{}, query string, args []namedValue) (res dr
}
return nil, fmt.Errorf(msg, query, args)
}
defer expected.Unlock()
if !expected.queryMatches(query) {
expected.Unlock()
return nil, fmt.Errorf("exec query '%s', does not match regex '%s'", query, expected.sqlRegex.String())
}
if err := expected.argsMatches(args); err != nil {
expected.Unlock()
return nil, fmt.Errorf("exec query '%s', arguments do not match: %s", query, err)
}
expected.triggered = true
defer time.Sleep(expected.delay)
defer expected.Unlock()
if expected.err != nil {
return nil, expected.err // mocked to return error
}
@ -265,7 +281,12 @@ func (c *sqlmock) exec(ctx interface{}, query string, args []namedValue) (res dr
return nil, fmt.Errorf("exec query '%s' with args %+v, must return a database/sql/driver.result, but it was not set for expectation %T as %+v", query, args, expected, expected)
}
return expected.result, err
return expected, nil
}
func (c *sqlmock) exec(expected *ExpectedExec) (driver.Result, error) {
defer time.Sleep(expected.delay)
return expected.result, nil
}
func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec {
@ -278,6 +299,15 @@ func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec {
// Prepare meets http://golang.org/pkg/database/sql/driver/#Conn interface
func (c *sqlmock) Prepare(query string) (driver.Stmt, error) {
ex, err := c.prepareExpectation(query)
if err != nil {
return nil, err
}
return c.prepare(ex, query)
}
func (c *sqlmock) prepareExpectation(query string) (*ExpectedPrepare, error) {
var expected *ExpectedPrepare
var fulfilled int
var ok bool
@ -307,15 +337,18 @@ func (c *sqlmock) Prepare(query string) (driver.Stmt, error) {
}
return nil, fmt.Errorf(msg, query)
}
defer expected.Unlock()
if !expected.sqlRegex.MatchString(query) {
expected.Unlock()
return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, expected.sqlRegex.String())
}
expected.triggered = true
return expected, expected.err
}
func (c *sqlmock) prepare(expected *ExpectedPrepare, query string) (driver.Stmt, error) {
defer time.Sleep(expected.delay)
defer expected.Unlock()
return &statement{c, query, expected.closeErr}, expected.err
return &statement{c, query, expected.closeErr}, nil
}
func (c *sqlmock) ExpectPrepare(sqlRegexStr string) *ExpectedPrepare {
@ -332,7 +365,7 @@ type namedValue struct {
}
// Query meets http://golang.org/pkg/database/sql/driver/#Queryer
func (c *sqlmock) Query(query string, args []driver.Value) (rw driver.Rows, err error) {
func (c *sqlmock) Query(query string, args []driver.Value) (driver.Rows, error) {
namedArgs := make([]namedValue, len(args))
for i, v := range args {
namedArgs[i] = namedValue{
@ -340,12 +373,16 @@ func (c *sqlmock) Query(query string, args []driver.Value) (rw driver.Rows, err
Value: v,
}
}
return c.query(nil, query, namedArgs)
ex, err := c.queryExpectation(query, namedArgs)
if err != nil {
return nil, err
}
return c.query(ex)
}
// in order to prevent dependencies, we use Context as a plain interface
// since it is only related to internal implementation
func (c *sqlmock) query(ctx interface{}, query string, args []namedValue) (rw driver.Rows, err error) {
func (c *sqlmock) queryExpectation(query string, args []namedValue) (*ExpectedQuery, error) {
query = stripQuery(query)
var expected *ExpectedQuery
var fulfilled int
@ -382,21 +419,17 @@ func (c *sqlmock) query(ctx interface{}, query string, args []namedValue) (rw dr
return nil, fmt.Errorf(msg, query, args)
}
defer expected.Unlock()
if !expected.queryMatches(query) {
expected.Unlock()
return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, expected.sqlRegex.String())
}
if err := expected.argsMatches(args); err != nil {
expected.Unlock()
return nil, fmt.Errorf("exec query '%s', arguments do not match: %s", query, err)
}
expected.triggered = true
defer time.Sleep(expected.delay)
defer expected.Unlock()
if expected.err != nil {
return nil, expected.err // mocked to return error
}
@ -404,8 +437,13 @@ func (c *sqlmock) query(ctx interface{}, query string, args []namedValue) (rw dr
if expected.rows == nil {
return nil, fmt.Errorf("query '%s' with args %+v, must return a database/sql/driver.rows, but it was not set for expectation %T as %+v", query, args, expected, expected)
}
return expected, nil
}
return expected.rows, err
func (c *sqlmock) query(expected *ExpectedQuery) (driver.Rows, error) {
defer time.Sleep(expected.delay)
return expected.rows, nil
}
func (c *sqlmock) ExpectQuery(sqlRegexStr string) *ExpectedQuery {

View File

@ -2,4 +2,142 @@
package sqlmock
// @TODO context based extensions
import (
"context"
"database/sql/driver"
"fmt"
)
var CancelledStatementErr = fmt.Errorf("canceling query due to user request")
// Implement the "QueryerContext" interface
func (c *sqlmock) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
namedArgs := make([]namedValue, len(args))
for i, nv := range args {
namedArgs[i] = namedValue(nv)
}
ex, err := c.queryExpectation(query, namedArgs)
if err != nil {
return nil, err
}
type result struct {
rows driver.Rows
err error
}
exec := make(chan result)
defer func() {
close(exec)
}()
go func() {
rows, err := c.query(ex)
exec <- result{rows, err}
}()
select {
case res := <-exec:
return res.rows, res.err
case <-ctx.Done():
return nil, CancelledStatementErr
}
}
// Implement the "ExecerContext" interface
func (c *sqlmock) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
namedArgs := make([]namedValue, len(args))
for i, nv := range args {
namedArgs[i] = namedValue(nv)
}
ex, err := c.execExpectation(query, namedArgs)
if err != nil {
return nil, err
}
type result struct {
rs driver.Result
err error
}
exec := make(chan result)
defer func() {
close(exec)
}()
go func() {
rs, err := c.exec(ex)
exec <- result{rs, err}
}()
select {
case res := <-exec:
return res.rs, res.err
case <-ctx.Done():
return nil, CancelledStatementErr
}
}
// Implement the "ConnBeginTx" interface
func (c *sqlmock) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
ex, err := c.beginExpectation()
if err != nil {
return nil, err
}
type result struct {
tx driver.Tx
err error
}
exec := make(chan result)
defer func() {
close(exec)
}()
go func() {
tx, err := c.begin(ex)
exec <- result{tx, err}
}()
select {
case res := <-exec:
return res.tx, res.err
case <-ctx.Done():
return nil, CancelledStatementErr
}
}
// Implement the "ConnPrepareContext" interface
func (c *sqlmock) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
ex, err := c.prepareExpectation(query)
if err != nil {
return nil, err
}
type result struct {
stmt driver.Stmt
err error
}
exec := make(chan result)
defer func() {
close(exec)
}()
go func() {
stmt, err := c.prepare(ex, query)
exec <- result{stmt, err}
}()
select {
case res := <-exec:
return res.stmt, res.err
case <-ctx.Done():
return nil, CancelledStatementErr
}
}
// @TODO maybe add ExpectedBegin.WithOptions(driver.TxOptions)

3
sqlmock_go18_test.go Normal file
View File

@ -0,0 +1,3 @@
// +build go1.8
package sqlmock