mirror of
https://github.com/DATA-DOG/go-sqlmock.git
synced 2025-04-25 12:04:40 +02:00
implements Context based sql driver extensions
This commit is contained in:
parent
d11f623794
commit
965003de80
88
sqlmock.go
88
sqlmock.go
@ -155,6 +155,20 @@ func (c *sqlmock) ExpectationsWereMet() error {
|
|||||||
|
|
||||||
// Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface
|
// Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface
|
||||||
func (c *sqlmock) Begin() (driver.Tx, error) {
|
func (c *sqlmock) Begin() (driver.Tx, error) {
|
||||||
|
ex, err := c.beginExpectation()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.begin(ex)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sqlmock) begin(expected *ExpectedBegin) (driver.Tx, error) {
|
||||||
|
defer time.Sleep(expected.delay)
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sqlmock) beginExpectation() (*ExpectedBegin, error) {
|
||||||
var expected *ExpectedBegin
|
var expected *ExpectedBegin
|
||||||
var ok bool
|
var ok bool
|
||||||
var fulfilled int
|
var fulfilled int
|
||||||
@ -185,8 +199,8 @@ func (c *sqlmock) Begin() (driver.Tx, error) {
|
|||||||
|
|
||||||
expected.triggered = true
|
expected.triggered = true
|
||||||
expected.Unlock()
|
expected.Unlock()
|
||||||
defer time.Sleep(expected.delay)
|
|
||||||
return c, expected.err
|
return expected, expected.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *sqlmock) ExpectBegin() *ExpectedBegin {
|
func (c *sqlmock) ExpectBegin() *ExpectedBegin {
|
||||||
@ -204,10 +218,16 @@ func (c *sqlmock) Exec(query string, args []driver.Value) (driver.Result, error)
|
|||||||
Value: v,
|
Value: v,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return c.exec(nil, query, namedArgs)
|
|
||||||
|
ex, err := c.execExpectation(query, namedArgs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.exec(ex)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *sqlmock) exec(ctx interface{}, query string, args []namedValue) (res driver.Result, err error) {
|
func (c *sqlmock) execExpectation(query string, args []namedValue) (*ExpectedExec, error) {
|
||||||
query = stripQuery(query)
|
query = stripQuery(query)
|
||||||
var expected *ExpectedExec
|
var expected *ExpectedExec
|
||||||
var fulfilled int
|
var fulfilled int
|
||||||
@ -242,21 +262,17 @@ func (c *sqlmock) exec(ctx interface{}, query string, args []namedValue) (res dr
|
|||||||
}
|
}
|
||||||
return nil, fmt.Errorf(msg, query, args)
|
return nil, fmt.Errorf(msg, query, args)
|
||||||
}
|
}
|
||||||
|
defer expected.Unlock()
|
||||||
|
|
||||||
if !expected.queryMatches(query) {
|
if !expected.queryMatches(query) {
|
||||||
expected.Unlock()
|
|
||||||
return nil, fmt.Errorf("exec query '%s', does not match regex '%s'", query, expected.sqlRegex.String())
|
return nil, fmt.Errorf("exec query '%s', does not match regex '%s'", query, expected.sqlRegex.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := expected.argsMatches(args); err != nil {
|
if err := expected.argsMatches(args); err != nil {
|
||||||
expected.Unlock()
|
|
||||||
return nil, fmt.Errorf("exec query '%s', arguments do not match: %s", query, err)
|
return nil, fmt.Errorf("exec query '%s', arguments do not match: %s", query, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
expected.triggered = true
|
expected.triggered = true
|
||||||
defer time.Sleep(expected.delay)
|
|
||||||
defer expected.Unlock()
|
|
||||||
|
|
||||||
if expected.err != nil {
|
if expected.err != nil {
|
||||||
return nil, expected.err // mocked to return error
|
return nil, expected.err // mocked to return error
|
||||||
}
|
}
|
||||||
@ -265,7 +281,12 @@ func (c *sqlmock) exec(ctx interface{}, query string, args []namedValue) (res dr
|
|||||||
return nil, fmt.Errorf("exec query '%s' with args %+v, must return a database/sql/driver.result, but it was not set for expectation %T as %+v", query, args, expected, expected)
|
return nil, fmt.Errorf("exec query '%s' with args %+v, must return a database/sql/driver.result, but it was not set for expectation %T as %+v", query, args, expected, expected)
|
||||||
}
|
}
|
||||||
|
|
||||||
return expected.result, err
|
return expected, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sqlmock) exec(expected *ExpectedExec) (driver.Result, error) {
|
||||||
|
defer time.Sleep(expected.delay)
|
||||||
|
return expected.result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec {
|
func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec {
|
||||||
@ -278,6 +299,15 @@ func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec {
|
|||||||
|
|
||||||
// Prepare meets http://golang.org/pkg/database/sql/driver/#Conn interface
|
// Prepare meets http://golang.org/pkg/database/sql/driver/#Conn interface
|
||||||
func (c *sqlmock) Prepare(query string) (driver.Stmt, error) {
|
func (c *sqlmock) Prepare(query string) (driver.Stmt, error) {
|
||||||
|
ex, err := c.prepareExpectation(query)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.prepare(ex, query)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sqlmock) prepareExpectation(query string) (*ExpectedPrepare, error) {
|
||||||
var expected *ExpectedPrepare
|
var expected *ExpectedPrepare
|
||||||
var fulfilled int
|
var fulfilled int
|
||||||
var ok bool
|
var ok bool
|
||||||
@ -307,15 +337,18 @@ func (c *sqlmock) Prepare(query string) (driver.Stmt, error) {
|
|||||||
}
|
}
|
||||||
return nil, fmt.Errorf(msg, query)
|
return nil, fmt.Errorf(msg, query)
|
||||||
}
|
}
|
||||||
|
defer expected.Unlock()
|
||||||
if !expected.sqlRegex.MatchString(query) {
|
if !expected.sqlRegex.MatchString(query) {
|
||||||
expected.Unlock()
|
|
||||||
return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, expected.sqlRegex.String())
|
return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, expected.sqlRegex.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
expected.triggered = true
|
expected.triggered = true
|
||||||
|
return expected, expected.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sqlmock) prepare(expected *ExpectedPrepare, query string) (driver.Stmt, error) {
|
||||||
defer time.Sleep(expected.delay)
|
defer time.Sleep(expected.delay)
|
||||||
defer expected.Unlock()
|
return &statement{c, query, expected.closeErr}, nil
|
||||||
return &statement{c, query, expected.closeErr}, expected.err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *sqlmock) ExpectPrepare(sqlRegexStr string) *ExpectedPrepare {
|
func (c *sqlmock) ExpectPrepare(sqlRegexStr string) *ExpectedPrepare {
|
||||||
@ -332,7 +365,7 @@ type namedValue struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Query meets http://golang.org/pkg/database/sql/driver/#Queryer
|
// Query meets http://golang.org/pkg/database/sql/driver/#Queryer
|
||||||
func (c *sqlmock) Query(query string, args []driver.Value) (rw driver.Rows, err error) {
|
func (c *sqlmock) Query(query string, args []driver.Value) (driver.Rows, error) {
|
||||||
namedArgs := make([]namedValue, len(args))
|
namedArgs := make([]namedValue, len(args))
|
||||||
for i, v := range args {
|
for i, v := range args {
|
||||||
namedArgs[i] = namedValue{
|
namedArgs[i] = namedValue{
|
||||||
@ -340,12 +373,16 @@ func (c *sqlmock) Query(query string, args []driver.Value) (rw driver.Rows, err
|
|||||||
Value: v,
|
Value: v,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return c.query(nil, query, namedArgs)
|
|
||||||
|
ex, err := c.queryExpectation(query, namedArgs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.query(ex)
|
||||||
}
|
}
|
||||||
|
|
||||||
// in order to prevent dependencies, we use Context as a plain interface
|
func (c *sqlmock) queryExpectation(query string, args []namedValue) (*ExpectedQuery, error) {
|
||||||
// since it is only related to internal implementation
|
|
||||||
func (c *sqlmock) query(ctx interface{}, query string, args []namedValue) (rw driver.Rows, err error) {
|
|
||||||
query = stripQuery(query)
|
query = stripQuery(query)
|
||||||
var expected *ExpectedQuery
|
var expected *ExpectedQuery
|
||||||
var fulfilled int
|
var fulfilled int
|
||||||
@ -382,21 +419,17 @@ func (c *sqlmock) query(ctx interface{}, query string, args []namedValue) (rw dr
|
|||||||
return nil, fmt.Errorf(msg, query, args)
|
return nil, fmt.Errorf(msg, query, args)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defer expected.Unlock()
|
||||||
|
|
||||||
if !expected.queryMatches(query) {
|
if !expected.queryMatches(query) {
|
||||||
expected.Unlock()
|
|
||||||
return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, expected.sqlRegex.String())
|
return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, expected.sqlRegex.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := expected.argsMatches(args); err != nil {
|
if err := expected.argsMatches(args); err != nil {
|
||||||
expected.Unlock()
|
|
||||||
return nil, fmt.Errorf("exec query '%s', arguments do not match: %s", query, err)
|
return nil, fmt.Errorf("exec query '%s', arguments do not match: %s", query, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
expected.triggered = true
|
expected.triggered = true
|
||||||
|
|
||||||
defer time.Sleep(expected.delay)
|
|
||||||
defer expected.Unlock()
|
|
||||||
|
|
||||||
if expected.err != nil {
|
if expected.err != nil {
|
||||||
return nil, expected.err // mocked to return error
|
return nil, expected.err // mocked to return error
|
||||||
}
|
}
|
||||||
@ -404,8 +437,13 @@ func (c *sqlmock) query(ctx interface{}, query string, args []namedValue) (rw dr
|
|||||||
if expected.rows == nil {
|
if expected.rows == nil {
|
||||||
return nil, fmt.Errorf("query '%s' with args %+v, must return a database/sql/driver.rows, but it was not set for expectation %T as %+v", query, args, expected, expected)
|
return nil, fmt.Errorf("query '%s' with args %+v, must return a database/sql/driver.rows, but it was not set for expectation %T as %+v", query, args, expected, expected)
|
||||||
}
|
}
|
||||||
|
return expected, nil
|
||||||
|
}
|
||||||
|
|
||||||
return expected.rows, err
|
func (c *sqlmock) query(expected *ExpectedQuery) (driver.Rows, error) {
|
||||||
|
defer time.Sleep(expected.delay)
|
||||||
|
|
||||||
|
return expected.rows, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *sqlmock) ExpectQuery(sqlRegexStr string) *ExpectedQuery {
|
func (c *sqlmock) ExpectQuery(sqlRegexStr string) *ExpectedQuery {
|
||||||
|
140
sqlmock_go18.go
140
sqlmock_go18.go
@ -2,4 +2,142 @@
|
|||||||
|
|
||||||
package sqlmock
|
package sqlmock
|
||||||
|
|
||||||
// @TODO context based extensions
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql/driver"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
var CancelledStatementErr = fmt.Errorf("canceling query due to user request")
|
||||||
|
|
||||||
|
// Implement the "QueryerContext" interface
|
||||||
|
func (c *sqlmock) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
|
||||||
|
namedArgs := make([]namedValue, len(args))
|
||||||
|
for i, nv := range args {
|
||||||
|
namedArgs[i] = namedValue(nv)
|
||||||
|
}
|
||||||
|
|
||||||
|
ex, err := c.queryExpectation(query, namedArgs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
type result struct {
|
||||||
|
rows driver.Rows
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
exec := make(chan result)
|
||||||
|
defer func() {
|
||||||
|
close(exec)
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
rows, err := c.query(ex)
|
||||||
|
exec <- result{rows, err}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case res := <-exec:
|
||||||
|
return res.rows, res.err
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, CancelledStatementErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement the "ExecerContext" interface
|
||||||
|
func (c *sqlmock) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
||||||
|
namedArgs := make([]namedValue, len(args))
|
||||||
|
for i, nv := range args {
|
||||||
|
namedArgs[i] = namedValue(nv)
|
||||||
|
}
|
||||||
|
|
||||||
|
ex, err := c.execExpectation(query, namedArgs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
type result struct {
|
||||||
|
rs driver.Result
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
exec := make(chan result)
|
||||||
|
defer func() {
|
||||||
|
close(exec)
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
rs, err := c.exec(ex)
|
||||||
|
exec <- result{rs, err}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case res := <-exec:
|
||||||
|
return res.rs, res.err
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, CancelledStatementErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement the "ConnBeginTx" interface
|
||||||
|
func (c *sqlmock) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
|
||||||
|
ex, err := c.beginExpectation()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
type result struct {
|
||||||
|
tx driver.Tx
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
exec := make(chan result)
|
||||||
|
defer func() {
|
||||||
|
close(exec)
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
tx, err := c.begin(ex)
|
||||||
|
exec <- result{tx, err}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case res := <-exec:
|
||||||
|
return res.tx, res.err
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, CancelledStatementErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement the "ConnPrepareContext" interface
|
||||||
|
func (c *sqlmock) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
|
||||||
|
ex, err := c.prepareExpectation(query)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
type result struct {
|
||||||
|
stmt driver.Stmt
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
exec := make(chan result)
|
||||||
|
defer func() {
|
||||||
|
close(exec)
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
stmt, err := c.prepare(ex, query)
|
||||||
|
exec <- result{stmt, err}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case res := <-exec:
|
||||||
|
return res.stmt, res.err
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, CancelledStatementErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// @TODO maybe add ExpectedBegin.WithOptions(driver.TxOptions)
|
||||||
|
3
sqlmock_go18_test.go
Normal file
3
sqlmock_go18_test.go
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
// +build go1.8
|
||||||
|
|
||||||
|
package sqlmock
|
Loading…
x
Reference in New Issue
Block a user