From cfb2877c66e3438409cdc4c6fe478c36e56ff274 Mon Sep 17 00:00:00 2001 From: gedi Date: Tue, 7 Feb 2017 15:03:05 +0200 Subject: [PATCH] tests Context sql driver extensions --- arg_matcher_before_go18.go | 45 +++++ arg_matcher_go18.go | 54 ++++++ expectations.go | 47 ------ expectations_test.go | 58 ------- expectations_test_go18.go | 64 +++++++ sqlmock.go | 47 ++---- sqlmock_go18.go | 97 +++-------- sqlmock_go18_test.go | 332 +++++++++++++++++++++++++++++++++++++ 8 files changed, 529 insertions(+), 215 deletions(-) create mode 100644 arg_matcher_before_go18.go create mode 100644 arg_matcher_go18.go create mode 100644 expectations_test_go18.go diff --git a/arg_matcher_before_go18.go b/arg_matcher_before_go18.go new file mode 100644 index 0000000..52eb369 --- /dev/null +++ b/arg_matcher_before_go18.go @@ -0,0 +1,45 @@ +// +build !go1.8 + +package sqlmock + +import ( + "database/sql/driver" + "fmt" + "reflect" +) + +func (e *queryBasedExpectation) argsMatches(args []namedValue) error { + if nil == e.args { + return nil + } + if len(args) != len(e.args) { + return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args)) + } + for k, v := range args { + // custom argument matcher + matcher, ok := e.args[k].(Argument) + if ok { + // @TODO: does it make sense to pass value instead of named value? + if !matcher.Match(v.Value) { + return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k]) + } + continue + } + + dval := e.args[k] + // convert to driver converter + darg, err := driver.DefaultParameterConverter.ConvertValue(dval) + if err != nil { + return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err) + } + + if !driver.IsValue(darg) { + return fmt.Errorf("argument %d: non-subset type %T returned from Value", k, darg) + } + + if !reflect.DeepEqual(darg, v.Value) { + return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v.Value, v.Value) + } + } + return nil +} diff --git a/arg_matcher_go18.go b/arg_matcher_go18.go new file mode 100644 index 0000000..610eac3 --- /dev/null +++ b/arg_matcher_go18.go @@ -0,0 +1,54 @@ +// +build go1.8 + +package sqlmock + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "reflect" +) + +func (e *queryBasedExpectation) argsMatches(args []namedValue) error { + if nil == e.args { + return nil + } + if len(args) != len(e.args) { + return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args)) + } + // @TODO should we assert either all args are named or ordinal? + for k, v := range args { + // custom argument matcher + matcher, ok := e.args[k].(Argument) + if ok { + // @TODO: does it make sense to pass value instead of named value? + if !matcher.Match(v.Value) { + return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k]) + } + continue + } + + dval := e.args[k] + if named, isNamed := dval.(sql.NamedArg); isNamed { + dval = named.Value + if v.Name != named.Name { + return fmt.Errorf("named argument %d: name: \"%s\" does not match expected: \"%s\"", k, v.Name, named.Name) + } + } + + // convert to driver converter + darg, err := driver.DefaultParameterConverter.ConvertValue(dval) + if err != nil { + return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err) + } + + if !driver.IsValue(darg) { + return fmt.Errorf("argument %d: non-subset type %T returned from Value", k, darg) + } + + if !reflect.DeepEqual(darg, v.Value) { + return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v.Value, v.Value) + } + } + return nil +} diff --git a/expectations.go b/expectations.go index 3902ff3..b19fbc9 100644 --- a/expectations.go +++ b/expectations.go @@ -3,7 +3,6 @@ package sqlmock import ( "database/sql/driver" "fmt" - "reflect" "regexp" "strings" "sync" @@ -355,49 +354,3 @@ func (e *queryBasedExpectation) attemptMatch(sql string, args []namedValue) (err func (e *queryBasedExpectation) queryMatches(sql string) bool { return e.sqlRegex.MatchString(sql) } - -func (e *queryBasedExpectation) argsMatches(args []namedValue) error { - if nil == e.args { - return nil - } - if len(args) != len(e.args) { - return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args)) - } - for k, v := range args { - // custom argument matcher - matcher, ok := e.args[k].(Argument) - if ok { - // @TODO: does it make sense to pass value instead of named value? - if !matcher.Match(v.Value) { - return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k]) - } - continue - } - - dval := e.args[k] - if named, isNamed := dval.(namedValue); isNamed { - dval = named.Value - if v.Name != named.Name { - return fmt.Errorf("named argument %d: name: \"%s\" does not match expected: \"%s\"", k, v.Name, named.Name) - } - if v.Ordinal != named.Ordinal { - return fmt.Errorf("named argument %d: ordinal position: \"%d\" does not match expected: \"%d\"", k, v.Ordinal, named.Ordinal) - } - } - - // convert to driver converter - darg, err := driver.DefaultParameterConverter.ConvertValue(dval) - if err != nil { - return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err) - } - - if !driver.IsValue(darg) { - return fmt.Errorf("argument %d: non-subset type %T returned from Value", k, darg) - } - - if !reflect.DeepEqual(darg, v.Value) { - return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v.Value, v.Value) - } - } - return nil -} diff --git a/expectations_test.go b/expectations_test.go index 6238532..2e3c097 100644 --- a/expectations_test.go +++ b/expectations_test.go @@ -64,64 +64,6 @@ func TestQueryExpectationArgComparison(t *testing.T) { } } -func TestQueryExpectationNamedArgComparison(t *testing.T) { - e := &queryBasedExpectation{} - against := []namedValue{{Value: int64(5), Name: "id"}} - if err := e.argsMatches(against); err != nil { - t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err) - } - - e.args = []driver.Value{ - namedValue{Name: "id", Value: int64(5)}, - namedValue{Name: "s", Value: "str"}, - } - - if err := e.argsMatches(against); err == nil { - t.Error("arguments should not match, since the size is not the same") - } - - against = []namedValue{ - {Value: int64(5), Name: "id"}, - {Value: "str", Name: "s"}, - } - - if err := e.argsMatches(against); err != nil { - t.Errorf("arguments should have matched, but it did not: %v", err) - } - - against = []namedValue{ - {Value: int64(5), Name: "id"}, - {Value: "str", Name: "username"}, - } - - if err := e.argsMatches(against); err == nil { - t.Error("arguments matched, but it should have not due to Name") - } - - e.args = []driver.Value{ - namedValue{Ordinal: 1, Value: int64(5)}, - namedValue{Ordinal: 2, Value: "str"}, - } - - against = []namedValue{ - {Value: int64(5), Ordinal: 0}, - {Value: "str", Ordinal: 1}, - } - - if err := e.argsMatches(against); err == nil { - t.Error("arguments matched, but it should have not due to wrong Ordinal position") - } - - against = []namedValue{ - {Value: int64(5), Ordinal: 1}, - {Value: "str", Ordinal: 2}, - } - - if err := e.argsMatches(against); err != nil { - t.Errorf("arguments should have matched, but it did not: %v", err) - } -} - func TestQueryExpectationArgComparisonBool(t *testing.T) { var e *queryBasedExpectation diff --git a/expectations_test_go18.go b/expectations_test_go18.go new file mode 100644 index 0000000..5f30d2f --- /dev/null +++ b/expectations_test_go18.go @@ -0,0 +1,64 @@ +// +build go1.8 + +package sqlmock + +import ( + "database/sql" + "database/sql/driver" + "testing" +) + +func TestQueryExpectationNamedArgComparison(t *testing.T) { + e := &queryBasedExpectation{} + against := []namedValue{{Value: int64(5), Name: "id"}} + if err := e.argsMatches(against); err != nil { + t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err) + } + + e.args = []driver.Value{ + sql.Named("id", 5), + sql.Named("s", "str"), + } + + if err := e.argsMatches(against); err == nil { + t.Error("arguments should not match, since the size is not the same") + } + + against = []namedValue{ + {Value: int64(5), Name: "id"}, + {Value: "str", Name: "s"}, + } + + if err := e.argsMatches(against); err != nil { + t.Errorf("arguments should have matched, but it did not: %v", err) + } + + against = []namedValue{ + {Value: int64(5), Name: "id"}, + {Value: "str", Name: "username"}, + } + + if err := e.argsMatches(against); err == nil { + t.Error("arguments matched, but it should have not due to Name") + } + + e.args = []driver.Value{int64(5), "str"} + + against = []namedValue{ + {Value: int64(5), Ordinal: 0}, + {Value: "str", Ordinal: 1}, + } + + if err := e.argsMatches(against); err == nil { + t.Error("arguments matched, but it should have not due to wrong Ordinal position") + } + + against = []namedValue{ + {Value: int64(5), Ordinal: 1}, + {Value: "str", Ordinal: 2}, + } + + if err := e.argsMatches(against); err != nil { + t.Errorf("arguments should have matched, but it did not: %v", err) + } +} diff --git a/sqlmock.go b/sqlmock.go index 536fa13..2052174 100644 --- a/sqlmock.go +++ b/sqlmock.go @@ -155,20 +155,16 @@ func (c *sqlmock) ExpectationsWereMet() error { // Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface func (c *sqlmock) Begin() (driver.Tx, error) { - ex, err := c.beginExpectation() + ex, err := c.begin() if err != nil { return nil, err } - return c.begin(ex) -} - -func (c *sqlmock) begin(expected *ExpectedBegin) (driver.Tx, error) { - defer time.Sleep(expected.delay) + time.Sleep(ex.delay) return c, nil } -func (c *sqlmock) beginExpectation() (*ExpectedBegin, error) { +func (c *sqlmock) begin() (*ExpectedBegin, error) { var expected *ExpectedBegin var ok bool var fulfilled int @@ -219,15 +215,16 @@ func (c *sqlmock) Exec(query string, args []driver.Value) (driver.Result, error) } } - ex, err := c.execExpectation(query, namedArgs) + ex, err := c.exec(query, namedArgs) if err != nil { return nil, err } - return c.exec(ex) + time.Sleep(ex.delay) + return ex.result, nil } -func (c *sqlmock) execExpectation(query string, args []namedValue) (*ExpectedExec, error) { +func (c *sqlmock) exec(query string, args []namedValue) (*ExpectedExec, error) { query = stripQuery(query) var expected *ExpectedExec var fulfilled int @@ -284,11 +281,6 @@ func (c *sqlmock) execExpectation(query string, args []namedValue) (*ExpectedExe return expected, nil } -func (c *sqlmock) exec(expected *ExpectedExec) (driver.Result, error) { - defer time.Sleep(expected.delay) - return expected.result, nil -} - func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec { e := &ExpectedExec{} sqlRegexStr = stripQuery(sqlRegexStr) @@ -299,15 +291,16 @@ func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec { // Prepare meets http://golang.org/pkg/database/sql/driver/#Conn interface func (c *sqlmock) Prepare(query string) (driver.Stmt, error) { - ex, err := c.prepareExpectation(query) + ex, err := c.prepare(query) if err != nil { return nil, err } - return c.prepare(ex, query) + time.Sleep(ex.delay) + return &statement{c, query, ex.closeErr}, nil } -func (c *sqlmock) prepareExpectation(query string) (*ExpectedPrepare, error) { +func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) { var expected *ExpectedPrepare var fulfilled int var ok bool @@ -346,11 +339,6 @@ func (c *sqlmock) prepareExpectation(query string) (*ExpectedPrepare, error) { return expected, expected.err } -func (c *sqlmock) prepare(expected *ExpectedPrepare, query string) (driver.Stmt, error) { - defer time.Sleep(expected.delay) - return &statement{c, query, expected.closeErr}, nil -} - func (c *sqlmock) ExpectPrepare(sqlRegexStr string) *ExpectedPrepare { sqlRegexStr = stripQuery(sqlRegexStr) e := &ExpectedPrepare{sqlRegex: regexp.MustCompile(sqlRegexStr), mock: c} @@ -374,15 +362,16 @@ func (c *sqlmock) Query(query string, args []driver.Value) (driver.Rows, error) } } - ex, err := c.queryExpectation(query, namedArgs) + ex, err := c.query(query, namedArgs) if err != nil { return nil, err } - return c.query(ex) + time.Sleep(ex.delay) + return ex.rows, nil } -func (c *sqlmock) queryExpectation(query string, args []namedValue) (*ExpectedQuery, error) { +func (c *sqlmock) query(query string, args []namedValue) (*ExpectedQuery, error) { query = stripQuery(query) var expected *ExpectedQuery var fulfilled int @@ -440,12 +429,6 @@ func (c *sqlmock) queryExpectation(query string, args []namedValue) (*ExpectedQu return expected, nil } -func (c *sqlmock) query(expected *ExpectedQuery) (driver.Rows, error) { - defer time.Sleep(expected.delay) - - return expected.rows, nil -} - func (c *sqlmock) ExpectQuery(sqlRegexStr string) *ExpectedQuery { e := &ExpectedQuery{} sqlRegexStr = stripQuery(sqlRegexStr) diff --git a/sqlmock_go18.go b/sqlmock_go18.go index 0c82a3b..7b3c949 100644 --- a/sqlmock_go18.go +++ b/sqlmock_go18.go @@ -5,10 +5,11 @@ package sqlmock import ( "context" "database/sql/driver" - "fmt" + "errors" + "time" ) -var CancelledStatementErr = fmt.Errorf("canceling query due to user request") +var ErrCancelled = errors.New("canceling query due to user request") // Implement the "QueryerContext" interface func (c *sqlmock) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { @@ -17,31 +18,16 @@ func (c *sqlmock) QueryContext(ctx context.Context, query string, args []driver. namedArgs[i] = namedValue(nv) } - ex, err := c.queryExpectation(query, namedArgs) + ex, err := c.query(query, namedArgs) if err != nil { return nil, err } - type result struct { - rows driver.Rows - err error - } - - exec := make(chan result) - defer func() { - close(exec) - }() - - go func() { - rows, err := c.query(ex) - exec <- result{rows, err} - }() - select { - case res := <-exec: - return res.rows, res.err + case <-time.After(ex.delay): + return ex.rows, nil case <-ctx.Done(): - return nil, CancelledStatementErr + return nil, ErrCancelled } } @@ -52,91 +38,46 @@ func (c *sqlmock) ExecContext(ctx context.Context, query string, args []driver.N namedArgs[i] = namedValue(nv) } - ex, err := c.execExpectation(query, namedArgs) + ex, err := c.exec(query, namedArgs) if err != nil { return nil, err } - type result struct { - rs driver.Result - err error - } - - exec := make(chan result) - defer func() { - close(exec) - }() - - go func() { - rs, err := c.exec(ex) - exec <- result{rs, err} - }() - select { - case res := <-exec: - return res.rs, res.err + case <-time.After(ex.delay): + return ex.result, nil case <-ctx.Done(): - return nil, CancelledStatementErr + return nil, ErrCancelled } } // Implement the "ConnBeginTx" interface func (c *sqlmock) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { - ex, err := c.beginExpectation() + ex, err := c.begin() if err != nil { return nil, err } - type result struct { - tx driver.Tx - err error - } - - exec := make(chan result) - defer func() { - close(exec) - }() - - go func() { - tx, err := c.begin(ex) - exec <- result{tx, err} - }() - select { - case res := <-exec: - return res.tx, res.err + case <-time.After(ex.delay): + return c, nil case <-ctx.Done(): - return nil, CancelledStatementErr + return nil, ErrCancelled } } // Implement the "ConnPrepareContext" interface func (c *sqlmock) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { - ex, err := c.prepareExpectation(query) + ex, err := c.prepare(query) if err != nil { return nil, err } - type result struct { - stmt driver.Stmt - err error - } - - exec := make(chan result) - defer func() { - close(exec) - }() - - go func() { - stmt, err := c.prepare(ex, query) - exec <- result{stmt, err} - }() - select { - case res := <-exec: - return res.stmt, res.err + case <-time.After(ex.delay): + return &statement{c, query, ex.closeErr}, nil case <-ctx.Done(): - return nil, CancelledStatementErr + return nil, ErrCancelled } } diff --git a/sqlmock_go18_test.go b/sqlmock_go18_test.go index 713ed0f..e491fbd 100644 --- a/sqlmock_go18_test.go +++ b/sqlmock_go18_test.go @@ -1,3 +1,335 @@ // +build go1.8 package sqlmock + +import ( + "context" + "database/sql" + "testing" + "time" +) + +func TestContextExecCancel(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectExec("DELETE FROM users"). + WillDelayFor(time.Second). + WillReturnResult(NewResult(1, 1)) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + _, err = db.ExecContext(ctx, "DELETE FROM users") + if err == nil { + t.Error("error was expected, but there was none") + } + + if err != ErrCancelled { + t.Errorf("was expecting cancel error, but got: %v", err) + } + + _, err = db.ExecContext(ctx, "DELETE FROM users") + if err != context.Canceled { + t.Error("error was expected since context was already done, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +} + +func TestContextExecWithNamedArg(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectExec("DELETE FROM users"). + WithArgs(sql.Named("id", 5)). + WillDelayFor(time.Second). + WillReturnResult(NewResult(1, 1)) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + _, err = db.ExecContext(ctx, "DELETE FROM users WHERE id = :id", sql.Named("id", 5)) + if err == nil { + t.Error("error was expected, but there was none") + } + + if err != ErrCancelled { + t.Errorf("was expecting cancel error, but got: %v", err) + } + + _, err = db.ExecContext(ctx, "DELETE FROM users WHERE id = :id", sql.Named("id", 5)) + if err != context.Canceled { + t.Error("error was expected since context was already done, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +} + +func TestContextExec(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectExec("DELETE FROM users"). + WillReturnResult(NewResult(1, 1)) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + res, err := db.ExecContext(ctx, "DELETE FROM users") + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + affected, err := res.RowsAffected() + if affected != 1 { + t.Errorf("expected affected rows 1, but got %v", affected) + } + + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +} + +func TestContextQueryCancel(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + rs := NewRows([]string{"id", "title"}).AddRow(5, "hello world") + + mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). + WithArgs(5). + WillDelayFor(time.Second). + WillReturnRows(rs) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + _, err = db.QueryContext(ctx, "SELECT id, title FROM articles WHERE id = ?", 5) + if err == nil { + t.Error("error was expected, but there was none") + } + + if err != ErrCancelled { + t.Errorf("was expecting cancel error, but got: %v", err) + } + + _, err = db.QueryContext(ctx, "SELECT id, title FROM articles WHERE id = ?", 5) + if err != context.Canceled { + t.Error("error was expected since context was already done, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +} + +func TestContextQuery(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + rs := NewRows([]string{"id", "title"}).AddRow(5, "hello world") + + mock.ExpectQuery("SELECT (.+) FROM articles WHERE id ="). + WithArgs(sql.Named("id", 5)). + WillDelayFor(time.Millisecond * 3). + WillReturnRows(rs) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + rows, err := db.QueryContext(ctx, "SELECT id, title FROM articles WHERE id = :id", sql.Named("id", 5)) + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + if !rows.Next() { + t.Error("expected one row, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +} + +func TestContextBeginCancel(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectBegin().WillDelayFor(time.Second) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + _, err = db.BeginTx(ctx, nil) + if err == nil { + t.Error("error was expected, but there was none") + } + + if err != ErrCancelled { + t.Errorf("was expecting cancel error, but got: %v", err) + } + + _, err = db.BeginTx(ctx, nil) + if err != context.Canceled { + t.Error("error was expected since context was already done, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +} + +func TestContextBegin(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectBegin().WillDelayFor(time.Millisecond * 3) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + if tx == nil { + t.Error("expected tx, but there was nil") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +} + +func TestContextPrepareCancel(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectPrepare("SELECT").WillDelayFor(time.Second) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + _, err = db.PrepareContext(ctx, "SELECT") + if err == nil { + t.Error("error was expected, but there was none") + } + + if err != ErrCancelled { + t.Errorf("was expecting cancel error, but got: %v", err) + } + + _, err = db.PrepareContext(ctx, "SELECT") + if err != context.Canceled { + t.Error("error was expected since context was already done, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +} + +func TestContextPrepare(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectPrepare("SELECT").WillDelayFor(time.Millisecond * 3) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + + stmt, err := db.PrepareContext(ctx, "SELECT") + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + if stmt == nil { + t.Error("expected stmt, but there was nil") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +}