mirror of
https://github.com/DATA-DOG/go-sqlmock.git
synced 2025-04-21 11:56:50 +02:00
tests Context sql driver extensions
This commit is contained in:
parent
965003de80
commit
cfb2877c66
45
arg_matcher_before_go18.go
Normal file
45
arg_matcher_before_go18.go
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
// +build !go1.8
|
||||||
|
|
||||||
|
package sqlmock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (e *queryBasedExpectation) argsMatches(args []namedValue) error {
|
||||||
|
if nil == e.args {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if len(args) != len(e.args) {
|
||||||
|
return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args))
|
||||||
|
}
|
||||||
|
for k, v := range args {
|
||||||
|
// custom argument matcher
|
||||||
|
matcher, ok := e.args[k].(Argument)
|
||||||
|
if ok {
|
||||||
|
// @TODO: does it make sense to pass value instead of named value?
|
||||||
|
if !matcher.Match(v.Value) {
|
||||||
|
return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k])
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
dval := e.args[k]
|
||||||
|
// convert to driver converter
|
||||||
|
darg, err := driver.DefaultParameterConverter.ConvertValue(dval)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !driver.IsValue(darg) {
|
||||||
|
return fmt.Errorf("argument %d: non-subset type %T returned from Value", k, darg)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(darg, v.Value) {
|
||||||
|
return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v.Value, v.Value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
54
arg_matcher_go18.go
Normal file
54
arg_matcher_go18.go
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
// +build go1.8
|
||||||
|
|
||||||
|
package sqlmock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (e *queryBasedExpectation) argsMatches(args []namedValue) error {
|
||||||
|
if nil == e.args {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if len(args) != len(e.args) {
|
||||||
|
return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args))
|
||||||
|
}
|
||||||
|
// @TODO should we assert either all args are named or ordinal?
|
||||||
|
for k, v := range args {
|
||||||
|
// custom argument matcher
|
||||||
|
matcher, ok := e.args[k].(Argument)
|
||||||
|
if ok {
|
||||||
|
// @TODO: does it make sense to pass value instead of named value?
|
||||||
|
if !matcher.Match(v.Value) {
|
||||||
|
return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k])
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
dval := e.args[k]
|
||||||
|
if named, isNamed := dval.(sql.NamedArg); isNamed {
|
||||||
|
dval = named.Value
|
||||||
|
if v.Name != named.Name {
|
||||||
|
return fmt.Errorf("named argument %d: name: \"%s\" does not match expected: \"%s\"", k, v.Name, named.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// convert to driver converter
|
||||||
|
darg, err := driver.DefaultParameterConverter.ConvertValue(dval)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !driver.IsValue(darg) {
|
||||||
|
return fmt.Errorf("argument %d: non-subset type %T returned from Value", k, darg)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(darg, v.Value) {
|
||||||
|
return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v.Value, v.Value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
@ -3,7 +3,6 @@ package sqlmock
|
|||||||
import (
|
import (
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@ -355,49 +354,3 @@ func (e *queryBasedExpectation) attemptMatch(sql string, args []namedValue) (err
|
|||||||
func (e *queryBasedExpectation) queryMatches(sql string) bool {
|
func (e *queryBasedExpectation) queryMatches(sql string) bool {
|
||||||
return e.sqlRegex.MatchString(sql)
|
return e.sqlRegex.MatchString(sql)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *queryBasedExpectation) argsMatches(args []namedValue) error {
|
|
||||||
if nil == e.args {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if len(args) != len(e.args) {
|
|
||||||
return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args))
|
|
||||||
}
|
|
||||||
for k, v := range args {
|
|
||||||
// custom argument matcher
|
|
||||||
matcher, ok := e.args[k].(Argument)
|
|
||||||
if ok {
|
|
||||||
// @TODO: does it make sense to pass value instead of named value?
|
|
||||||
if !matcher.Match(v.Value) {
|
|
||||||
return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k])
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
dval := e.args[k]
|
|
||||||
if named, isNamed := dval.(namedValue); isNamed {
|
|
||||||
dval = named.Value
|
|
||||||
if v.Name != named.Name {
|
|
||||||
return fmt.Errorf("named argument %d: name: \"%s\" does not match expected: \"%s\"", k, v.Name, named.Name)
|
|
||||||
}
|
|
||||||
if v.Ordinal != named.Ordinal {
|
|
||||||
return fmt.Errorf("named argument %d: ordinal position: \"%d\" does not match expected: \"%d\"", k, v.Ordinal, named.Ordinal)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// convert to driver converter
|
|
||||||
darg, err := driver.DefaultParameterConverter.ConvertValue(dval)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !driver.IsValue(darg) {
|
|
||||||
return fmt.Errorf("argument %d: non-subset type %T returned from Value", k, darg)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(darg, v.Value) {
|
|
||||||
return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v.Value, v.Value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
@ -64,64 +64,6 @@ func TestQueryExpectationArgComparison(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQueryExpectationNamedArgComparison(t *testing.T) {
|
|
||||||
e := &queryBasedExpectation{}
|
|
||||||
against := []namedValue{{Value: int64(5), Name: "id"}}
|
|
||||||
if err := e.argsMatches(against); err != nil {
|
|
||||||
t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
e.args = []driver.Value{
|
|
||||||
namedValue{Name: "id", Value: int64(5)},
|
|
||||||
namedValue{Name: "s", Value: "str"},
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := e.argsMatches(against); err == nil {
|
|
||||||
t.Error("arguments should not match, since the size is not the same")
|
|
||||||
}
|
|
||||||
|
|
||||||
against = []namedValue{
|
|
||||||
{Value: int64(5), Name: "id"},
|
|
||||||
{Value: "str", Name: "s"},
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := e.argsMatches(against); err != nil {
|
|
||||||
t.Errorf("arguments should have matched, but it did not: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
against = []namedValue{
|
|
||||||
{Value: int64(5), Name: "id"},
|
|
||||||
{Value: "str", Name: "username"},
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := e.argsMatches(against); err == nil {
|
|
||||||
t.Error("arguments matched, but it should have not due to Name")
|
|
||||||
}
|
|
||||||
|
|
||||||
e.args = []driver.Value{
|
|
||||||
namedValue{Ordinal: 1, Value: int64(5)},
|
|
||||||
namedValue{Ordinal: 2, Value: "str"},
|
|
||||||
}
|
|
||||||
|
|
||||||
against = []namedValue{
|
|
||||||
{Value: int64(5), Ordinal: 0},
|
|
||||||
{Value: "str", Ordinal: 1},
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := e.argsMatches(against); err == nil {
|
|
||||||
t.Error("arguments matched, but it should have not due to wrong Ordinal position")
|
|
||||||
}
|
|
||||||
|
|
||||||
against = []namedValue{
|
|
||||||
{Value: int64(5), Ordinal: 1},
|
|
||||||
{Value: "str", Ordinal: 2},
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := e.argsMatches(against); err != nil {
|
|
||||||
t.Errorf("arguments should have matched, but it did not: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryExpectationArgComparisonBool(t *testing.T) {
|
func TestQueryExpectationArgComparisonBool(t *testing.T) {
|
||||||
var e *queryBasedExpectation
|
var e *queryBasedExpectation
|
||||||
|
|
||||||
|
64
expectations_test_go18.go
Normal file
64
expectations_test_go18.go
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
// +build go1.8
|
||||||
|
|
||||||
|
package sqlmock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestQueryExpectationNamedArgComparison(t *testing.T) {
|
||||||
|
e := &queryBasedExpectation{}
|
||||||
|
against := []namedValue{{Value: int64(5), Name: "id"}}
|
||||||
|
if err := e.argsMatches(against); err != nil {
|
||||||
|
t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.args = []driver.Value{
|
||||||
|
sql.Named("id", 5),
|
||||||
|
sql.Named("s", "str"),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.argsMatches(against); err == nil {
|
||||||
|
t.Error("arguments should not match, since the size is not the same")
|
||||||
|
}
|
||||||
|
|
||||||
|
against = []namedValue{
|
||||||
|
{Value: int64(5), Name: "id"},
|
||||||
|
{Value: "str", Name: "s"},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.argsMatches(against); err != nil {
|
||||||
|
t.Errorf("arguments should have matched, but it did not: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
against = []namedValue{
|
||||||
|
{Value: int64(5), Name: "id"},
|
||||||
|
{Value: "str", Name: "username"},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.argsMatches(against); err == nil {
|
||||||
|
t.Error("arguments matched, but it should have not due to Name")
|
||||||
|
}
|
||||||
|
|
||||||
|
e.args = []driver.Value{int64(5), "str"}
|
||||||
|
|
||||||
|
against = []namedValue{
|
||||||
|
{Value: int64(5), Ordinal: 0},
|
||||||
|
{Value: "str", Ordinal: 1},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.argsMatches(against); err == nil {
|
||||||
|
t.Error("arguments matched, but it should have not due to wrong Ordinal position")
|
||||||
|
}
|
||||||
|
|
||||||
|
against = []namedValue{
|
||||||
|
{Value: int64(5), Ordinal: 1},
|
||||||
|
{Value: "str", Ordinal: 2},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.argsMatches(against); err != nil {
|
||||||
|
t.Errorf("arguments should have matched, but it did not: %v", err)
|
||||||
|
}
|
||||||
|
}
|
47
sqlmock.go
47
sqlmock.go
@ -155,20 +155,16 @@ func (c *sqlmock) ExpectationsWereMet() error {
|
|||||||
|
|
||||||
// Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface
|
// Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface
|
||||||
func (c *sqlmock) Begin() (driver.Tx, error) {
|
func (c *sqlmock) Begin() (driver.Tx, error) {
|
||||||
ex, err := c.beginExpectation()
|
ex, err := c.begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.begin(ex)
|
time.Sleep(ex.delay)
|
||||||
}
|
|
||||||
|
|
||||||
func (c *sqlmock) begin(expected *ExpectedBegin) (driver.Tx, error) {
|
|
||||||
defer time.Sleep(expected.delay)
|
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *sqlmock) beginExpectation() (*ExpectedBegin, error) {
|
func (c *sqlmock) begin() (*ExpectedBegin, error) {
|
||||||
var expected *ExpectedBegin
|
var expected *ExpectedBegin
|
||||||
var ok bool
|
var ok bool
|
||||||
var fulfilled int
|
var fulfilled int
|
||||||
@ -219,15 +215,16 @@ func (c *sqlmock) Exec(query string, args []driver.Value) (driver.Result, error)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ex, err := c.execExpectation(query, namedArgs)
|
ex, err := c.exec(query, namedArgs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.exec(ex)
|
time.Sleep(ex.delay)
|
||||||
|
return ex.result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *sqlmock) execExpectation(query string, args []namedValue) (*ExpectedExec, error) {
|
func (c *sqlmock) exec(query string, args []namedValue) (*ExpectedExec, error) {
|
||||||
query = stripQuery(query)
|
query = stripQuery(query)
|
||||||
var expected *ExpectedExec
|
var expected *ExpectedExec
|
||||||
var fulfilled int
|
var fulfilled int
|
||||||
@ -284,11 +281,6 @@ func (c *sqlmock) execExpectation(query string, args []namedValue) (*ExpectedExe
|
|||||||
return expected, nil
|
return expected, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *sqlmock) exec(expected *ExpectedExec) (driver.Result, error) {
|
|
||||||
defer time.Sleep(expected.delay)
|
|
||||||
return expected.result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec {
|
func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec {
|
||||||
e := &ExpectedExec{}
|
e := &ExpectedExec{}
|
||||||
sqlRegexStr = stripQuery(sqlRegexStr)
|
sqlRegexStr = stripQuery(sqlRegexStr)
|
||||||
@ -299,15 +291,16 @@ func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec {
|
|||||||
|
|
||||||
// Prepare meets http://golang.org/pkg/database/sql/driver/#Conn interface
|
// Prepare meets http://golang.org/pkg/database/sql/driver/#Conn interface
|
||||||
func (c *sqlmock) Prepare(query string) (driver.Stmt, error) {
|
func (c *sqlmock) Prepare(query string) (driver.Stmt, error) {
|
||||||
ex, err := c.prepareExpectation(query)
|
ex, err := c.prepare(query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.prepare(ex, query)
|
time.Sleep(ex.delay)
|
||||||
|
return &statement{c, query, ex.closeErr}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *sqlmock) prepareExpectation(query string) (*ExpectedPrepare, error) {
|
func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) {
|
||||||
var expected *ExpectedPrepare
|
var expected *ExpectedPrepare
|
||||||
var fulfilled int
|
var fulfilled int
|
||||||
var ok bool
|
var ok bool
|
||||||
@ -346,11 +339,6 @@ func (c *sqlmock) prepareExpectation(query string) (*ExpectedPrepare, error) {
|
|||||||
return expected, expected.err
|
return expected, expected.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *sqlmock) prepare(expected *ExpectedPrepare, query string) (driver.Stmt, error) {
|
|
||||||
defer time.Sleep(expected.delay)
|
|
||||||
return &statement{c, query, expected.closeErr}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *sqlmock) ExpectPrepare(sqlRegexStr string) *ExpectedPrepare {
|
func (c *sqlmock) ExpectPrepare(sqlRegexStr string) *ExpectedPrepare {
|
||||||
sqlRegexStr = stripQuery(sqlRegexStr)
|
sqlRegexStr = stripQuery(sqlRegexStr)
|
||||||
e := &ExpectedPrepare{sqlRegex: regexp.MustCompile(sqlRegexStr), mock: c}
|
e := &ExpectedPrepare{sqlRegex: regexp.MustCompile(sqlRegexStr), mock: c}
|
||||||
@ -374,15 +362,16 @@ func (c *sqlmock) Query(query string, args []driver.Value) (driver.Rows, error)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ex, err := c.queryExpectation(query, namedArgs)
|
ex, err := c.query(query, namedArgs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.query(ex)
|
time.Sleep(ex.delay)
|
||||||
|
return ex.rows, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *sqlmock) queryExpectation(query string, args []namedValue) (*ExpectedQuery, error) {
|
func (c *sqlmock) query(query string, args []namedValue) (*ExpectedQuery, error) {
|
||||||
query = stripQuery(query)
|
query = stripQuery(query)
|
||||||
var expected *ExpectedQuery
|
var expected *ExpectedQuery
|
||||||
var fulfilled int
|
var fulfilled int
|
||||||
@ -440,12 +429,6 @@ func (c *sqlmock) queryExpectation(query string, args []namedValue) (*ExpectedQu
|
|||||||
return expected, nil
|
return expected, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *sqlmock) query(expected *ExpectedQuery) (driver.Rows, error) {
|
|
||||||
defer time.Sleep(expected.delay)
|
|
||||||
|
|
||||||
return expected.rows, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *sqlmock) ExpectQuery(sqlRegexStr string) *ExpectedQuery {
|
func (c *sqlmock) ExpectQuery(sqlRegexStr string) *ExpectedQuery {
|
||||||
e := &ExpectedQuery{}
|
e := &ExpectedQuery{}
|
||||||
sqlRegexStr = stripQuery(sqlRegexStr)
|
sqlRegexStr = stripQuery(sqlRegexStr)
|
||||||
|
@ -5,10 +5,11 @@ package sqlmock
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"fmt"
|
"errors"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var CancelledStatementErr = fmt.Errorf("canceling query due to user request")
|
var ErrCancelled = errors.New("canceling query due to user request")
|
||||||
|
|
||||||
// Implement the "QueryerContext" interface
|
// Implement the "QueryerContext" interface
|
||||||
func (c *sqlmock) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
|
func (c *sqlmock) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
|
||||||
@ -17,31 +18,16 @@ func (c *sqlmock) QueryContext(ctx context.Context, query string, args []driver.
|
|||||||
namedArgs[i] = namedValue(nv)
|
namedArgs[i] = namedValue(nv)
|
||||||
}
|
}
|
||||||
|
|
||||||
ex, err := c.queryExpectation(query, namedArgs)
|
ex, err := c.query(query, namedArgs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
type result struct {
|
|
||||||
rows driver.Rows
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
exec := make(chan result)
|
|
||||||
defer func() {
|
|
||||||
close(exec)
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
rows, err := c.query(ex)
|
|
||||||
exec <- result{rows, err}
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case res := <-exec:
|
case <-time.After(ex.delay):
|
||||||
return res.rows, res.err
|
return ex.rows, nil
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return nil, CancelledStatementErr
|
return nil, ErrCancelled
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -52,91 +38,46 @@ func (c *sqlmock) ExecContext(ctx context.Context, query string, args []driver.N
|
|||||||
namedArgs[i] = namedValue(nv)
|
namedArgs[i] = namedValue(nv)
|
||||||
}
|
}
|
||||||
|
|
||||||
ex, err := c.execExpectation(query, namedArgs)
|
ex, err := c.exec(query, namedArgs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
type result struct {
|
|
||||||
rs driver.Result
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
exec := make(chan result)
|
|
||||||
defer func() {
|
|
||||||
close(exec)
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
rs, err := c.exec(ex)
|
|
||||||
exec <- result{rs, err}
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case res := <-exec:
|
case <-time.After(ex.delay):
|
||||||
return res.rs, res.err
|
return ex.result, nil
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return nil, CancelledStatementErr
|
return nil, ErrCancelled
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Implement the "ConnBeginTx" interface
|
// Implement the "ConnBeginTx" interface
|
||||||
func (c *sqlmock) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
|
func (c *sqlmock) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
|
||||||
ex, err := c.beginExpectation()
|
ex, err := c.begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
type result struct {
|
|
||||||
tx driver.Tx
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
exec := make(chan result)
|
|
||||||
defer func() {
|
|
||||||
close(exec)
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
tx, err := c.begin(ex)
|
|
||||||
exec <- result{tx, err}
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case res := <-exec:
|
case <-time.After(ex.delay):
|
||||||
return res.tx, res.err
|
return c, nil
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return nil, CancelledStatementErr
|
return nil, ErrCancelled
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Implement the "ConnPrepareContext" interface
|
// Implement the "ConnPrepareContext" interface
|
||||||
func (c *sqlmock) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
|
func (c *sqlmock) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
|
||||||
ex, err := c.prepareExpectation(query)
|
ex, err := c.prepare(query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
type result struct {
|
|
||||||
stmt driver.Stmt
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
exec := make(chan result)
|
|
||||||
defer func() {
|
|
||||||
close(exec)
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
stmt, err := c.prepare(ex, query)
|
|
||||||
exec <- result{stmt, err}
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case res := <-exec:
|
case <-time.After(ex.delay):
|
||||||
return res.stmt, res.err
|
return &statement{c, query, ex.closeErr}, nil
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return nil, CancelledStatementErr
|
return nil, ErrCancelled
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,3 +1,335 @@
|
|||||||
// +build go1.8
|
// +build go1.8
|
||||||
|
|
||||||
package sqlmock
|
package sqlmock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestContextExecCancel(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
db, mock, err := New()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
mock.ExpectExec("DELETE FROM users").
|
||||||
|
WillDelayFor(time.Second).
|
||||||
|
WillReturnResult(NewResult(1, 1))
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
time.Sleep(time.Millisecond * 10)
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err = db.ExecContext(ctx, "DELETE FROM users")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("error was expected, but there was none")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != ErrCancelled {
|
||||||
|
t.Errorf("was expecting cancel error, but got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.ExecContext(ctx, "DELETE FROM users")
|
||||||
|
if err != context.Canceled {
|
||||||
|
t.Error("error was expected since context was already done, but there was none")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Errorf("there were unfulfilled expections: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextExecWithNamedArg(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
db, mock, err := New()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
mock.ExpectExec("DELETE FROM users").
|
||||||
|
WithArgs(sql.Named("id", 5)).
|
||||||
|
WillDelayFor(time.Second).
|
||||||
|
WillReturnResult(NewResult(1, 1))
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
time.Sleep(time.Millisecond * 10)
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err = db.ExecContext(ctx, "DELETE FROM users WHERE id = :id", sql.Named("id", 5))
|
||||||
|
if err == nil {
|
||||||
|
t.Error("error was expected, but there was none")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != ErrCancelled {
|
||||||
|
t.Errorf("was expecting cancel error, but got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.ExecContext(ctx, "DELETE FROM users WHERE id = :id", sql.Named("id", 5))
|
||||||
|
if err != context.Canceled {
|
||||||
|
t.Error("error was expected since context was already done, but there was none")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Errorf("there were unfulfilled expections: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextExec(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
db, mock, err := New()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
mock.ExpectExec("DELETE FROM users").
|
||||||
|
WillReturnResult(NewResult(1, 1))
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
time.Sleep(time.Millisecond * 10)
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
res, err := db.ExecContext(ctx, "DELETE FROM users")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error was not expected, but got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
affected, err := res.RowsAffected()
|
||||||
|
if affected != 1 {
|
||||||
|
t.Errorf("expected affected rows 1, but got %v", affected)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error was not expected, but got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Errorf("there were unfulfilled expections: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextQueryCancel(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
db, mock, err := New()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
rs := NewRows([]string{"id", "title"}).AddRow(5, "hello world")
|
||||||
|
|
||||||
|
mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?").
|
||||||
|
WithArgs(5).
|
||||||
|
WillDelayFor(time.Second).
|
||||||
|
WillReturnRows(rs)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
time.Sleep(time.Millisecond * 10)
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err = db.QueryContext(ctx, "SELECT id, title FROM articles WHERE id = ?", 5)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("error was expected, but there was none")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != ErrCancelled {
|
||||||
|
t.Errorf("was expecting cancel error, but got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.QueryContext(ctx, "SELECT id, title FROM articles WHERE id = ?", 5)
|
||||||
|
if err != context.Canceled {
|
||||||
|
t.Error("error was expected since context was already done, but there was none")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Errorf("there were unfulfilled expections: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextQuery(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
db, mock, err := New()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
rs := NewRows([]string{"id", "title"}).AddRow(5, "hello world")
|
||||||
|
|
||||||
|
mock.ExpectQuery("SELECT (.+) FROM articles WHERE id =").
|
||||||
|
WithArgs(sql.Named("id", 5)).
|
||||||
|
WillDelayFor(time.Millisecond * 3).
|
||||||
|
WillReturnRows(rs)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
time.Sleep(time.Millisecond * 10)
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
rows, err := db.QueryContext(ctx, "SELECT id, title FROM articles WHERE id = :id", sql.Named("id", 5))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error was not expected, but got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !rows.Next() {
|
||||||
|
t.Error("expected one row, but there was none")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Errorf("there were unfulfilled expections: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextBeginCancel(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
db, mock, err := New()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
mock.ExpectBegin().WillDelayFor(time.Second)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
time.Sleep(time.Millisecond * 10)
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err = db.BeginTx(ctx, nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("error was expected, but there was none")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != ErrCancelled {
|
||||||
|
t.Errorf("was expecting cancel error, but got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.BeginTx(ctx, nil)
|
||||||
|
if err != context.Canceled {
|
||||||
|
t.Error("error was expected since context was already done, but there was none")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Errorf("there were unfulfilled expections: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextBegin(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
db, mock, err := New()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
mock.ExpectBegin().WillDelayFor(time.Millisecond * 3)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
time.Sleep(time.Millisecond * 10)
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
tx, err := db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error was not expected, but got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tx == nil {
|
||||||
|
t.Error("expected tx, but there was nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Errorf("there were unfulfilled expections: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextPrepareCancel(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
db, mock, err := New()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
mock.ExpectPrepare("SELECT").WillDelayFor(time.Second)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
time.Sleep(time.Millisecond * 10)
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err = db.PrepareContext(ctx, "SELECT")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("error was expected, but there was none")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != ErrCancelled {
|
||||||
|
t.Errorf("was expecting cancel error, but got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.PrepareContext(ctx, "SELECT")
|
||||||
|
if err != context.Canceled {
|
||||||
|
t.Error("error was expected since context was already done, but there was none")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Errorf("there were unfulfilled expections: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextPrepare(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
db, mock, err := New()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
mock.ExpectPrepare("SELECT").WillDelayFor(time.Millisecond * 3)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
time.Sleep(time.Millisecond * 10)
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
stmt, err := db.PrepareContext(ctx, "SELECT")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error was not expected, but got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if stmt == nil {
|
||||||
|
t.Error("expected stmt, but there was nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Errorf("there were unfulfilled expections: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user