From 32675e3ad1c5823798bb9566f042f072d0360acd Mon Sep 17 00:00:00 2001 From: Wachowski Date: Tue, 20 Feb 2018 17:30:20 -0700 Subject: [PATCH] delay for WillDelayFor() duration, before returning error set by WillReturnError() If both WillDelayFor() and WillReturnError() are called, there should be a delay before the error is returned. Applied to Begin, Exec, Query, and Prepare. Also, the context versions of same. And a couple of unit tests. --- sqlmock.go | 20 ++++++++---- sqlmock_go18.go | 76 +++++++++++++++++++++++++------------------- sqlmock_go18_test.go | 50 +++++++++++++++++++++++++++++ sqlmock_test.go | 50 +++++++++++++++++++++++++++++ 4 files changed, 158 insertions(+), 38 deletions(-) diff --git a/sqlmock.go b/sqlmock.go index fa7f624..8fe5cc6 100644 --- a/sqlmock.go +++ b/sqlmock.go @@ -168,11 +168,13 @@ func (c *sqlmock) ExpectationsWereMet() error { // Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface func (c *sqlmock) Begin() (driver.Tx, error) { ex, err := c.begin() + if ex != nil { + time.Sleep(ex.delay) + } if err != nil { return nil, err } - time.Sleep(ex.delay) return c, nil } @@ -228,11 +230,13 @@ func (c *sqlmock) Exec(query string, args []driver.Value) (driver.Result, error) } ex, err := c.exec(query, namedArgs) + if ex != nil { + time.Sleep(ex.delay) + } if err != nil { return nil, err } - time.Sleep(ex.delay) return ex.result, nil } @@ -283,7 +287,7 @@ func (c *sqlmock) exec(query string, args []namedValue) (*ExpectedExec, error) { expected.triggered = true if expected.err != nil { - return nil, expected.err // mocked to return error + return expected, expected.err // mocked to return error } if expected.result == nil { @@ -304,11 +308,13 @@ func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec { // Prepare meets http://golang.org/pkg/database/sql/driver/#Conn interface func (c *sqlmock) Prepare(query string) (driver.Stmt, error) { ex, err := c.prepare(query) + if ex != nil { + time.Sleep(ex.delay) + } if err != nil { return nil, err } - time.Sleep(ex.delay) return &statement{c, ex, query}, nil } @@ -385,11 +391,13 @@ func (c *sqlmock) Query(query string, args []driver.Value) (driver.Rows, error) } ex, err := c.query(query, namedArgs) + if ex != nil { + time.Sleep(ex.delay) + } if err != nil { return nil, err } - time.Sleep(ex.delay) return ex.rows, nil } @@ -442,7 +450,7 @@ func (c *sqlmock) query(query string, args []namedValue) (*ExpectedQuery, error) expected.triggered = true if expected.err != nil { - return nil, expected.err // mocked to return error + return expected, expected.err // mocked to return error } if expected.rows == nil { diff --git a/sqlmock_go18.go b/sqlmock_go18.go index 52b0e0c..49bd4f5 100644 --- a/sqlmock_go18.go +++ b/sqlmock_go18.go @@ -19,16 +19,19 @@ func (c *sqlmock) QueryContext(ctx context.Context, query string, args []driver. } ex, err := c.query(query, namedArgs) - if err != nil { - return nil, err + if ex != nil { + select { + case <-time.After(ex.delay): + if err != nil { + return nil, err + } + return ex.rows, nil + case <-ctx.Done(): + return nil, ErrCancelled + } } - select { - case <-time.After(ex.delay): - return ex.rows, nil - case <-ctx.Done(): - return nil, ErrCancelled - } + return nil, err } // Implement the "ExecerContext" interface @@ -39,46 +42,55 @@ func (c *sqlmock) ExecContext(ctx context.Context, query string, args []driver.N } ex, err := c.exec(query, namedArgs) - if err != nil { - return nil, err + if ex != nil { + select { + case <-time.After(ex.delay): + if err != nil { + return nil, err + } + return ex.result, nil + case <-ctx.Done(): + return nil, ErrCancelled + } } - select { - case <-time.After(ex.delay): - return ex.result, nil - case <-ctx.Done(): - return nil, ErrCancelled - } + return nil, err } // Implement the "ConnBeginTx" interface func (c *sqlmock) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { ex, err := c.begin() - if err != nil { - return nil, err + if ex != nil { + select { + case <-time.After(ex.delay): + if err != nil { + return nil, err + } + return c, nil + case <-ctx.Done(): + return nil, ErrCancelled + } } - select { - case <-time.After(ex.delay): - return c, nil - case <-ctx.Done(): - return nil, ErrCancelled - } + return nil, err } // Implement the "ConnPrepareContext" interface func (c *sqlmock) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { ex, err := c.prepare(query) - if err != nil { - return nil, err + if ex != nil { + select { + case <-time.After(ex.delay): + if err != nil { + return nil, err + } + return &statement{c, ex, query}, nil + case <-ctx.Done(): + return nil, ErrCancelled + } } - select { - case <-time.After(ex.delay): - return &statement{c, ex, query}, nil - case <-ctx.Done(): - return nil, ErrCancelled - } + return nil, err } // Implement the "Pinger" interface diff --git a/sqlmock_go18_test.go b/sqlmock_go18_test.go index abc2452..e53d9c7 100644 --- a/sqlmock_go18_test.go +++ b/sqlmock_go18_test.go @@ -5,6 +5,7 @@ package sqlmock import ( "context" "database/sql" + "errors" "testing" "time" ) @@ -424,3 +425,52 @@ func TestContextPrepare(t *testing.T) { t.Errorf("there were unfulfilled expectations: %s", err) } } + +func TestContextExecErrorDelay(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + // test that return of error is delayed + var delay time.Duration + delay = 100 * time.Millisecond + mock.ExpectExec("^INSERT INTO articles"). + WillReturnError(errors.New("slow fail")). + WillDelayFor(delay) + + start := time.Now() + res, err := db.ExecContext(context.Background(), "INSERT INTO articles (title) VALUES (?)", "hello") + stop := time.Now() + + if res != nil { + t.Errorf("result was not expected, was expecting nil") + } + + if err == nil { + t.Errorf("error was expected, was not expecting nil") + } + + if err.Error() != "slow fail" { + t.Errorf("error '%s' was not expected, was expecting '%s'", err.Error(), "slow fail") + } + + elapsed := stop.Sub(start) + if elapsed < delay { + t.Errorf("expecting a delay of %v before error, actual delay was %v", delay, elapsed) + } + + // also test that return of error is not delayed + mock.ExpectExec("^INSERT INTO articles").WillReturnError(errors.New("fast fail")) + + start = time.Now() + db.ExecContext(context.Background(), "INSERT INTO articles (title) VALUES (?)", "hello") + stop = time.Now() + + elapsed = stop.Sub(start) + if elapsed > delay { + t.Errorf("expecting a delay of less than %v before error, actual delay was %v", delay, elapsed) + } +} diff --git a/sqlmock_test.go b/sqlmock_test.go index ecbd5d1..f43a3e7 100644 --- a/sqlmock_test.go +++ b/sqlmock_test.go @@ -2,6 +2,7 @@ package sqlmock import ( "database/sql" + "errors" "fmt" "strconv" "sync" @@ -1063,3 +1064,52 @@ func TestPreparedStatementCloseExpectation(t *testing.T) { t.Errorf("there were unfulfilled expectations: %s", err) } } + +func TestExecExpectationErrorDelay(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + // test that return of error is delayed + var delay time.Duration + delay = 100 * time.Millisecond + mock.ExpectExec("^INSERT INTO articles"). + WillReturnError(errors.New("slow fail")). + WillDelayFor(delay) + + start := time.Now() + res, err := db.Exec("INSERT INTO articles (title) VALUES (?)", "hello") + stop := time.Now() + + if res != nil { + t.Errorf("result was not expected, was expecting nil") + } + + if err == nil { + t.Errorf("error was expected, was not expecting nil") + } + + if err.Error() != "slow fail" { + t.Errorf("error '%s' was not expected, was expecting '%s'", err.Error(), "slow fail") + } + + elapsed := stop.Sub(start) + if elapsed < delay { + t.Errorf("expecting a delay of %v before error, actual delay was %v", delay, elapsed) + } + + // also test that return of error is not delayed + mock.ExpectExec("^INSERT INTO articles").WillReturnError(errors.New("fast fail")) + + start = time.Now() + db.Exec("INSERT INTO articles (title) VALUES (?)", "hello") + stop = time.Now() + + elapsed = stop.Sub(start) + if elapsed > delay { + t.Errorf("expecting a delay of less than %v before error, actual delay was %v", delay, elapsed) + } +}