mirror of
https://github.com/DATA-DOG/go-sqlmock.git
synced 2024-11-28 08:49:01 +02:00
implements Context based sql driver extensions
This commit is contained in:
parent
d11f623794
commit
965003de80
88
sqlmock.go
88
sqlmock.go
@ -155,6 +155,20 @@ func (c *sqlmock) ExpectationsWereMet() error {
|
||||
|
||||
// Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface
|
||||
func (c *sqlmock) Begin() (driver.Tx, error) {
|
||||
ex, err := c.beginExpectation()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c.begin(ex)
|
||||
}
|
||||
|
||||
func (c *sqlmock) begin(expected *ExpectedBegin) (driver.Tx, error) {
|
||||
defer time.Sleep(expected.delay)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *sqlmock) beginExpectation() (*ExpectedBegin, error) {
|
||||
var expected *ExpectedBegin
|
||||
var ok bool
|
||||
var fulfilled int
|
||||
@ -185,8 +199,8 @@ func (c *sqlmock) Begin() (driver.Tx, error) {
|
||||
|
||||
expected.triggered = true
|
||||
expected.Unlock()
|
||||
defer time.Sleep(expected.delay)
|
||||
return c, expected.err
|
||||
|
||||
return expected, expected.err
|
||||
}
|
||||
|
||||
func (c *sqlmock) ExpectBegin() *ExpectedBegin {
|
||||
@ -204,10 +218,16 @@ func (c *sqlmock) Exec(query string, args []driver.Value) (driver.Result, error)
|
||||
Value: v,
|
||||
}
|
||||
}
|
||||
return c.exec(nil, query, namedArgs)
|
||||
|
||||
ex, err := c.execExpectation(query, namedArgs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c.exec(ex)
|
||||
}
|
||||
|
||||
func (c *sqlmock) exec(ctx interface{}, query string, args []namedValue) (res driver.Result, err error) {
|
||||
func (c *sqlmock) execExpectation(query string, args []namedValue) (*ExpectedExec, error) {
|
||||
query = stripQuery(query)
|
||||
var expected *ExpectedExec
|
||||
var fulfilled int
|
||||
@ -242,21 +262,17 @@ func (c *sqlmock) exec(ctx interface{}, query string, args []namedValue) (res dr
|
||||
}
|
||||
return nil, fmt.Errorf(msg, query, args)
|
||||
}
|
||||
defer expected.Unlock()
|
||||
|
||||
if !expected.queryMatches(query) {
|
||||
expected.Unlock()
|
||||
return nil, fmt.Errorf("exec query '%s', does not match regex '%s'", query, expected.sqlRegex.String())
|
||||
}
|
||||
|
||||
if err := expected.argsMatches(args); err != nil {
|
||||
expected.Unlock()
|
||||
return nil, fmt.Errorf("exec query '%s', arguments do not match: %s", query, err)
|
||||
}
|
||||
|
||||
expected.triggered = true
|
||||
defer time.Sleep(expected.delay)
|
||||
defer expected.Unlock()
|
||||
|
||||
if expected.err != nil {
|
||||
return nil, expected.err // mocked to return error
|
||||
}
|
||||
@ -265,7 +281,12 @@ func (c *sqlmock) exec(ctx interface{}, query string, args []namedValue) (res dr
|
||||
return nil, fmt.Errorf("exec query '%s' with args %+v, must return a database/sql/driver.result, but it was not set for expectation %T as %+v", query, args, expected, expected)
|
||||
}
|
||||
|
||||
return expected.result, err
|
||||
return expected, nil
|
||||
}
|
||||
|
||||
func (c *sqlmock) exec(expected *ExpectedExec) (driver.Result, error) {
|
||||
defer time.Sleep(expected.delay)
|
||||
return expected.result, nil
|
||||
}
|
||||
|
||||
func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec {
|
||||
@ -278,6 +299,15 @@ func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec {
|
||||
|
||||
// Prepare meets http://golang.org/pkg/database/sql/driver/#Conn interface
|
||||
func (c *sqlmock) Prepare(query string) (driver.Stmt, error) {
|
||||
ex, err := c.prepareExpectation(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c.prepare(ex, query)
|
||||
}
|
||||
|
||||
func (c *sqlmock) prepareExpectation(query string) (*ExpectedPrepare, error) {
|
||||
var expected *ExpectedPrepare
|
||||
var fulfilled int
|
||||
var ok bool
|
||||
@ -307,15 +337,18 @@ func (c *sqlmock) Prepare(query string) (driver.Stmt, error) {
|
||||
}
|
||||
return nil, fmt.Errorf(msg, query)
|
||||
}
|
||||
defer expected.Unlock()
|
||||
if !expected.sqlRegex.MatchString(query) {
|
||||
expected.Unlock()
|
||||
return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, expected.sqlRegex.String())
|
||||
}
|
||||
|
||||
expected.triggered = true
|
||||
return expected, expected.err
|
||||
}
|
||||
|
||||
func (c *sqlmock) prepare(expected *ExpectedPrepare, query string) (driver.Stmt, error) {
|
||||
defer time.Sleep(expected.delay)
|
||||
defer expected.Unlock()
|
||||
return &statement{c, query, expected.closeErr}, expected.err
|
||||
return &statement{c, query, expected.closeErr}, nil
|
||||
}
|
||||
|
||||
func (c *sqlmock) ExpectPrepare(sqlRegexStr string) *ExpectedPrepare {
|
||||
@ -332,7 +365,7 @@ type namedValue struct {
|
||||
}
|
||||
|
||||
// Query meets http://golang.org/pkg/database/sql/driver/#Queryer
|
||||
func (c *sqlmock) Query(query string, args []driver.Value) (rw driver.Rows, err error) {
|
||||
func (c *sqlmock) Query(query string, args []driver.Value) (driver.Rows, error) {
|
||||
namedArgs := make([]namedValue, len(args))
|
||||
for i, v := range args {
|
||||
namedArgs[i] = namedValue{
|
||||
@ -340,12 +373,16 @@ func (c *sqlmock) Query(query string, args []driver.Value) (rw driver.Rows, err
|
||||
Value: v,
|
||||
}
|
||||
}
|
||||
return c.query(nil, query, namedArgs)
|
||||
|
||||
ex, err := c.queryExpectation(query, namedArgs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c.query(ex)
|
||||
}
|
||||
|
||||
// in order to prevent dependencies, we use Context as a plain interface
|
||||
// since it is only related to internal implementation
|
||||
func (c *sqlmock) query(ctx interface{}, query string, args []namedValue) (rw driver.Rows, err error) {
|
||||
func (c *sqlmock) queryExpectation(query string, args []namedValue) (*ExpectedQuery, error) {
|
||||
query = stripQuery(query)
|
||||
var expected *ExpectedQuery
|
||||
var fulfilled int
|
||||
@ -382,21 +419,17 @@ func (c *sqlmock) query(ctx interface{}, query string, args []namedValue) (rw dr
|
||||
return nil, fmt.Errorf(msg, query, args)
|
||||
}
|
||||
|
||||
defer expected.Unlock()
|
||||
|
||||
if !expected.queryMatches(query) {
|
||||
expected.Unlock()
|
||||
return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, expected.sqlRegex.String())
|
||||
}
|
||||
|
||||
if err := expected.argsMatches(args); err != nil {
|
||||
expected.Unlock()
|
||||
return nil, fmt.Errorf("exec query '%s', arguments do not match: %s", query, err)
|
||||
}
|
||||
|
||||
expected.triggered = true
|
||||
|
||||
defer time.Sleep(expected.delay)
|
||||
defer expected.Unlock()
|
||||
|
||||
if expected.err != nil {
|
||||
return nil, expected.err // mocked to return error
|
||||
}
|
||||
@ -404,8 +437,13 @@ func (c *sqlmock) query(ctx interface{}, query string, args []namedValue) (rw dr
|
||||
if expected.rows == nil {
|
||||
return nil, fmt.Errorf("query '%s' with args %+v, must return a database/sql/driver.rows, but it was not set for expectation %T as %+v", query, args, expected, expected)
|
||||
}
|
||||
return expected, nil
|
||||
}
|
||||
|
||||
return expected.rows, err
|
||||
func (c *sqlmock) query(expected *ExpectedQuery) (driver.Rows, error) {
|
||||
defer time.Sleep(expected.delay)
|
||||
|
||||
return expected.rows, nil
|
||||
}
|
||||
|
||||
func (c *sqlmock) ExpectQuery(sqlRegexStr string) *ExpectedQuery {
|
||||
|
140
sqlmock_go18.go
140
sqlmock_go18.go
@ -2,4 +2,142 @@
|
||||
|
||||
package sqlmock
|
||||
|
||||
// @TODO context based extensions
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var CancelledStatementErr = fmt.Errorf("canceling query due to user request")
|
||||
|
||||
// Implement the "QueryerContext" interface
|
||||
func (c *sqlmock) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
|
||||
namedArgs := make([]namedValue, len(args))
|
||||
for i, nv := range args {
|
||||
namedArgs[i] = namedValue(nv)
|
||||
}
|
||||
|
||||
ex, err := c.queryExpectation(query, namedArgs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
type result struct {
|
||||
rows driver.Rows
|
||||
err error
|
||||
}
|
||||
|
||||
exec := make(chan result)
|
||||
defer func() {
|
||||
close(exec)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
rows, err := c.query(ex)
|
||||
exec <- result{rows, err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case res := <-exec:
|
||||
return res.rows, res.err
|
||||
case <-ctx.Done():
|
||||
return nil, CancelledStatementErr
|
||||
}
|
||||
}
|
||||
|
||||
// Implement the "ExecerContext" interface
|
||||
func (c *sqlmock) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
||||
namedArgs := make([]namedValue, len(args))
|
||||
for i, nv := range args {
|
||||
namedArgs[i] = namedValue(nv)
|
||||
}
|
||||
|
||||
ex, err := c.execExpectation(query, namedArgs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
type result struct {
|
||||
rs driver.Result
|
||||
err error
|
||||
}
|
||||
|
||||
exec := make(chan result)
|
||||
defer func() {
|
||||
close(exec)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
rs, err := c.exec(ex)
|
||||
exec <- result{rs, err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case res := <-exec:
|
||||
return res.rs, res.err
|
||||
case <-ctx.Done():
|
||||
return nil, CancelledStatementErr
|
||||
}
|
||||
}
|
||||
|
||||
// Implement the "ConnBeginTx" interface
|
||||
func (c *sqlmock) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
|
||||
ex, err := c.beginExpectation()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
type result struct {
|
||||
tx driver.Tx
|
||||
err error
|
||||
}
|
||||
|
||||
exec := make(chan result)
|
||||
defer func() {
|
||||
close(exec)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
tx, err := c.begin(ex)
|
||||
exec <- result{tx, err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case res := <-exec:
|
||||
return res.tx, res.err
|
||||
case <-ctx.Done():
|
||||
return nil, CancelledStatementErr
|
||||
}
|
||||
}
|
||||
|
||||
// Implement the "ConnPrepareContext" interface
|
||||
func (c *sqlmock) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
|
||||
ex, err := c.prepareExpectation(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
type result struct {
|
||||
stmt driver.Stmt
|
||||
err error
|
||||
}
|
||||
|
||||
exec := make(chan result)
|
||||
defer func() {
|
||||
close(exec)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
stmt, err := c.prepare(ex, query)
|
||||
exec <- result{stmt, err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case res := <-exec:
|
||||
return res.stmt, res.err
|
||||
case <-ctx.Done():
|
||||
return nil, CancelledStatementErr
|
||||
}
|
||||
}
|
||||
|
||||
// @TODO maybe add ExpectedBegin.WithOptions(driver.TxOptions)
|
||||
|
3
sqlmock_go18_test.go
Normal file
3
sqlmock_go18_test.go
Normal file
@ -0,0 +1,3 @@
|
||||
// +build go1.8
|
||||
|
||||
package sqlmock
|
Loading…
Reference in New Issue
Block a user