mirror of
https://github.com/DATA-DOG/go-sqlmock.git
synced 2024-11-21 17:17:08 +02:00
Add WithTXOption expectation to ExpectBegin
This commit is contained in:
parent
6bed17cdbe
commit
4a27a756c1
@ -1,6 +1,7 @@
|
||||
package sqlmock
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"strings"
|
||||
@ -53,7 +54,8 @@ func (e *ExpectedClose) String() string {
|
||||
// returned by *Sqlmock.ExpectBegin.
|
||||
type ExpectedBegin struct {
|
||||
commonExpectation
|
||||
delay time.Duration
|
||||
delay time.Duration
|
||||
txOpts *driver.TxOptions
|
||||
}
|
||||
|
||||
// WillReturnError allows to set an error for *sql.DB.Begin action
|
||||
@ -65,6 +67,9 @@ func (e *ExpectedBegin) WillReturnError(err error) *ExpectedBegin {
|
||||
// String returns string representation
|
||||
func (e *ExpectedBegin) String() string {
|
||||
msg := "ExpectedBegin => expecting database transaction Begin"
|
||||
if e.txOpts != nil {
|
||||
msg += fmt.Sprintf(", with tx options: %+v", e.txOpts)
|
||||
}
|
||||
if e.err != nil {
|
||||
msg += fmt.Sprintf(", which should return error: %s", e.err)
|
||||
}
|
||||
@ -78,6 +83,15 @@ func (e *ExpectedBegin) WillDelayFor(duration time.Duration) *ExpectedBegin {
|
||||
return e
|
||||
}
|
||||
|
||||
// WithTxOptions allows to set transaction options for *sql.DB.Begin action
|
||||
func (e *ExpectedBegin) WithTxOptions(opts sql.TxOptions) *ExpectedBegin {
|
||||
e.txOpts = &driver.TxOptions{
|
||||
Isolation: driver.IsolationLevel(opts.Isolation),
|
||||
ReadOnly: opts.ReadOnly,
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
// ExpectedCommit is used to manage *sql.Tx.Commit expectation
|
||||
// returned by *Sqlmock.ExpectCommit.
|
||||
type ExpectedCommit struct {
|
||||
|
11
sqlmock.go
11
sqlmock.go
@ -213,7 +213,7 @@ func (c *sqlmock) ExpectationsWereMet() error {
|
||||
|
||||
// Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface
|
||||
func (c *sqlmock) Begin() (driver.Tx, error) {
|
||||
ex, err := c.begin()
|
||||
ex, err := c.begin(driver.TxOptions{})
|
||||
if ex != nil {
|
||||
time.Sleep(ex.delay)
|
||||
}
|
||||
@ -224,7 +224,7 @@ func (c *sqlmock) Begin() (driver.Tx, error) {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *sqlmock) begin() (*ExpectedBegin, error) {
|
||||
func (c *sqlmock) begin(opts driver.TxOptions) (*ExpectedBegin, error) {
|
||||
var expected *ExpectedBegin
|
||||
var ok bool
|
||||
var fulfilled int
|
||||
@ -252,9 +252,14 @@ func (c *sqlmock) begin() (*ExpectedBegin, error) {
|
||||
}
|
||||
return nil, fmt.Errorf(msg)
|
||||
}
|
||||
defer expected.Unlock()
|
||||
if expected.txOpts != nil &&
|
||||
expected.txOpts.Isolation != opts.Isolation &&
|
||||
expected.txOpts.ReadOnly != opts.ReadOnly {
|
||||
return nil, fmt.Errorf("expected transaction options do not match: %+v, got: %+v", expected.txOpts, opts)
|
||||
}
|
||||
|
||||
expected.triggered = true
|
||||
expected.Unlock()
|
||||
|
||||
return expected, expected.err
|
||||
}
|
||||
|
@ -1,3 +1,4 @@
|
||||
//go:build go1.8
|
||||
// +build go1.8
|
||||
|
||||
package sqlmock
|
||||
@ -66,7 +67,7 @@ func (c *sqlmock) ExecContext(ctx context.Context, query string, args []driver.N
|
||||
|
||||
// Implement the "ConnBeginTx" interface
|
||||
func (c *sqlmock) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
|
||||
ex, err := c.begin()
|
||||
ex, err := c.begin(opts)
|
||||
if ex != nil {
|
||||
select {
|
||||
case <-time.After(ex.delay):
|
||||
|
@ -360,6 +360,66 @@ func TestContextBegin(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextBeginWithTxOptions(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
mock.ExpectBegin().WithTxOptions(sql.TxOptions{
|
||||
Isolation: sql.LevelReadCommitted,
|
||||
ReadOnly: true,
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
go func() {
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
_, err = db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted, ReadOnly: false})
|
||||
if err != nil {
|
||||
t.Errorf("error was not expected, but got: %v", err)
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("there were unfulfilled expectations: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextBeginWithTxOptionsMismatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
mock.ExpectBegin().WithTxOptions(sql.TxOptions{
|
||||
Isolation: sql.LevelReadCommitted,
|
||||
ReadOnly: true,
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
go func() {
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
_, err = db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelDefault, ReadOnly: false})
|
||||
if err == nil {
|
||||
t.Error("error was expected, but there was none")
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err == nil {
|
||||
t.Errorf("was expecting an error, as the tx options did not match, but there wasn't one")
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextPrepareCancel(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, mock, err := New()
|
||||
|
Loading…
Reference in New Issue
Block a user