From dd0fe2afd6d6564f90badb06635fe9aa125b6d7e Mon Sep 17 00:00:00 2001 From: Matthew Huxtable Date: Mon, 4 Nov 2019 18:02:04 +0000 Subject: [PATCH] Add ExpectPing to watch for db Ping calls --- expectations.go | 29 +++++++ options.go | 16 ++++ sqlmock.go | 22 ++++- sqlmock_before_go18.go | 10 +++ sqlmock_before_go18_test.go | 26 ++++++ sqlmock_go18.go | 66 ++++++++++++++- sqlmock_go18_test.go | 165 ++++++++++++++++++++++++++++++++++++ 7 files changed, 329 insertions(+), 5 deletions(-) create mode 100644 sqlmock_before_go18.go create mode 100644 sqlmock_before_go18_test.go diff --git a/expectations.go b/expectations.go index 38c0d17..ae2a47f 100644 --- a/expectations.go +++ b/expectations.go @@ -353,3 +353,32 @@ func (e *queryBasedExpectation) attemptArgMatch(args []namedValue) (err error) { err = e.argsMatches(args) return } + +// ExpectedPing is used to manage *sql.DB.Ping expectations. +// Returned by *Sqlmock.ExpectPing. +type ExpectedPing struct { + commonExpectation + delay time.Duration +} + +// WillDelayFor allows to specify duration for which it will delay result. May +// be used together with Context. +func (e *ExpectedPing) WillDelayFor(duration time.Duration) *ExpectedPing { + e.delay = duration + return e +} + +// WillReturnError allows to set an error for expected database ping +func (e *ExpectedPing) WillReturnError(err error) *ExpectedPing { + e.err = err + return e +} + +// String returns string representation +func (e *ExpectedPing) String() string { + msg := "ExpectedPing => expecting database Ping" + if e.err != nil { + msg += fmt.Sprintf(", which should return error: %s", e.err) + } + return msg +} diff --git a/options.go b/options.go index 29053ee..00c9837 100644 --- a/options.go +++ b/options.go @@ -20,3 +20,19 @@ func QueryMatcherOption(queryMatcher QueryMatcher) func(*sqlmock) error { return nil } } + +// MonitorPingsOption determines whether calls to Ping on the driver should be +// observed and mocked. +// +// If true is passed, we will check these calls were expected. Expectations can +// be registered using the ExpectPing() method on the mock. +// +// If false is passed or this option is omitted, calls to Ping will not be +// considered when determining expectations and calls to ExpectPing will have +// no effect. +func MonitorPingsOption(monitorPings bool) func(*sqlmock) error { + return func(s *sqlmock) error { + s.monitorPings = monitorPings + return nil + } +} diff --git a/sqlmock.go b/sqlmock.go index 4896307..9431d0e 100644 --- a/sqlmock.go +++ b/sqlmock.go @@ -21,7 +21,6 @@ import ( // for any kind of database action in order to mock // and test real database behavior. type Sqlmock interface { - // ExpectClose queues an expectation for this database // action to be triggered. the *ExpectedClose allows // to mock database response @@ -57,6 +56,17 @@ type Sqlmock interface { // the *ExpectedRollback allows to mock database response ExpectRollback() *ExpectedRollback + // ExpectPing expected *sql.DB.Ping to be called. + // the *ExpectedPing allows to mock database response + // + // Ping support only exists in the SQL library in Go 1.8 and above. + // ExpectPing in Go <=1.7 will return an ExpectedPing but not register + // any expectations. + // + // You must enable pings using MonitorPingsOption for this to register + // any expectations. + ExpectPing() *ExpectedPing + // MatchExpectationsInOrder gives an option whether to match all // expectations in the order they were set or not. // @@ -83,6 +93,7 @@ type sqlmock struct { drv *mockDriver converter driver.ValueConverter queryMatcher QueryMatcher + monitorPings bool expected []expectation } @@ -104,6 +115,15 @@ func (c *sqlmock) open(options []func(*sqlmock) error) (*sql.DB, Sqlmock, error) if c.queryMatcher == nil { c.queryMatcher = QueryMatcherRegexp } + + if c.monitorPings { + // We call Ping on the driver shortly to verify startup assertions by + // driving internal behaviour of the sql standard library. We don't + // want this call to ping to be monitored for expectation purposes so + // temporarily disable. + c.monitorPings = false + defer func() { c.monitorPings = true }() + } return db, c, db.Ping() } diff --git a/sqlmock_before_go18.go b/sqlmock_before_go18.go new file mode 100644 index 0000000..88b7aa0 --- /dev/null +++ b/sqlmock_before_go18.go @@ -0,0 +1,10 @@ +// +build !go1.8 + +package sqlmock + +import "log" + +func (c *sqlmock) ExpectPing() *ExpectedPing { + log.Println("ExpectPing has no effect on Go 1.7 or below") + return &ExpectedPing{} +} diff --git a/sqlmock_before_go18_test.go b/sqlmock_before_go18_test.go new file mode 100644 index 0000000..3c510a3 --- /dev/null +++ b/sqlmock_before_go18_test.go @@ -0,0 +1,26 @@ +// +build !go1.8 + +package sqlmock + +import ( + "fmt" + "testing" + "time" +) + +func TestSqlmockExpectPingHasNoEffect(t *testing.T) { + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + e := mock.ExpectPing() + + // Methods on the expectation can be called + e.WillDelayFor(time.Hour).WillReturnError(fmt.Errorf("an error")) + + if err = mock.ExpectationsWereMet(); err != nil { + t.Errorf("expected no error to be returned, but got '%s'", err) + } +} diff --git a/sqlmock_go18.go b/sqlmock_go18.go index 8fe9c1d..43fbb5d 100644 --- a/sqlmock_go18.go +++ b/sqlmock_go18.go @@ -6,6 +6,8 @@ import ( "context" "database/sql/driver" "errors" + "fmt" + "log" "time" ) @@ -95,11 +97,57 @@ func (c *sqlmock) PrepareContext(ctx context.Context, query string) (driver.Stmt return nil, err } -// Implement the "Pinger" interface -// for now we do not have a Ping expectation -// may be something for the future +// Implement the "Pinger" interface - the explicit DB driver ping was only added to database/sql in Go 1.8 func (c *sqlmock) Ping(ctx context.Context) error { - return nil + if !c.monitorPings { + return nil + } + + ex, err := c.ping() + if ex != nil { + select { + case <-ctx.Done(): + return ErrCancelled + case <-time.After(ex.delay): + } + } + + return err +} + +func (c *sqlmock) ping() (*ExpectedPing, error) { + var expected *ExpectedPing + var fulfilled int + var ok bool + for _, next := range c.expected { + next.Lock() + if next.fulfilled() { + next.Unlock() + fulfilled++ + continue + } + + if expected, ok = next.(*ExpectedPing); ok { + break + } + + next.Unlock() + if c.ordered { + return nil, fmt.Errorf("call to database Ping, was not expected, next expectation is: %s", next) + } + } + + if expected == nil { + msg := "call to database Ping was not expected" + if fulfilled == len(c.expected) { + msg = "all expectations were already fulfilled, " + msg + } + return nil, fmt.Errorf(msg) + } + + expected.triggered = true + expected.Unlock() + return expected, expected.err } // Implement the "StmtExecContext" interface @@ -112,4 +160,14 @@ func (stmt *statement) QueryContext(ctx context.Context, args []driver.NamedValu return stmt.conn.QueryContext(ctx, stmt.query, args) } +func (c *sqlmock) ExpectPing() *ExpectedPing { + if !c.monitorPings { + log.Println("ExpectPing will have no effect as monitoring pings is disabled. Use MonitorPingsOption to enable.") + return nil + } + e := &ExpectedPing{} + c.expected = append(c.expected, e) + return e +} + // @TODO maybe add ExpectedBegin.WithOptions(driver.TxOptions) diff --git a/sqlmock_go18_test.go b/sqlmock_go18_test.go index e53d9c7..223e076 100644 --- a/sqlmock_go18_test.go +++ b/sqlmock_go18_test.go @@ -474,3 +474,168 @@ func TestContextExecErrorDelay(t *testing.T) { t.Errorf("expecting a delay of less than %v before error, actual delay was %v", delay, elapsed) } } + +// TestMonitorPingsDisabled verifies backwards-compatibility with behaviour of the library in which +// calls to Ping are not mocked out. It verifies this persists when the user does not enable the new +// behaviour. +func TestMonitorPingsDisabled(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + // When monitoring of pings is not enabled in the mock, calling Ping should have no effect. + err = db.Ping() + if err != nil { + t.Errorf("monitoring of pings is not enabled so did not expect error from Ping, got '%s'", err) + } + + // Calling ExpectPing should also not register any expectations in the mock. The return from + // ExpectPing should be nil. + expectation := mock.ExpectPing() + if expectation != nil { + t.Errorf("expected ExpectPing to return a nil pointer when monitoring of pings is not enabled") + } + + err = mock.ExpectationsWereMet() + if err != nil { + t.Errorf("monitoring of pings is not enabled so ExpectPing should not register an expectation, got '%s'", err) + } +} + +func TestPingExpectations(t *testing.T) { + t.Parallel() + db, mock, err := New(MonitorPingsOption(true)) + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectPing() + if err := db.Ping(); err != nil { + t.Fatal(err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestPingExpectationsErrorDelay(t *testing.T) { + t.Parallel() + db, mock, err := New(MonitorPingsOption(true)) + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + var delay time.Duration + delay = 100 * time.Millisecond + mock.ExpectPing(). + WillReturnError(errors.New("slow fail")). + WillDelayFor(delay) + + start := time.Now() + err = db.Ping() + stop := time.Now() + + if err == nil { + t.Errorf("result was not expected, was not expecting nil error") + } + + if err.Error() != "slow fail" { + t.Errorf("error '%s' was not expected, was expecting '%s'", err.Error(), "slow fail") + } + + elapsed := stop.Sub(start) + if elapsed < delay { + t.Errorf("expecting a delay of %v before error, actual delay was %v", delay, elapsed) + } + + mock.ExpectPing().WillReturnError(errors.New("fast fail")) + + start = time.Now() + db.Ping() + stop = time.Now() + + elapsed = stop.Sub(start) + if elapsed > delay { + t.Errorf("expecting a delay of less than %v before error, actual delay was %v", delay, elapsed) + } +} + +func TestPingExpectationsMissingPing(t *testing.T) { + t.Parallel() + db, mock, err := New(MonitorPingsOption(true)) + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectPing() + + if err = mock.ExpectationsWereMet(); err == nil { + t.Fatalf("was expecting an error, but there wasn't one") + } +} + +func TestPingExpectationsUnexpectedPing(t *testing.T) { + t.Parallel() + db, _, err := New(MonitorPingsOption(true)) + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + if err = db.Ping(); err == nil { + t.Fatalf("was expecting an error, but there wasn't any") + } +} + +func TestPingOrderedWrongOrder(t *testing.T) { + t.Parallel() + db, mock, err := New(MonitorPingsOption(true)) + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectBegin() + mock.ExpectPing() + mock.MatchExpectationsInOrder(true) + + if err = db.Ping(); err == nil { + t.Fatalf("was expecting an error, but there wasn't any") + } +} + +func TestPingExpectationsContextTimeout(t *testing.T) { + t.Parallel() + db, mock, err := New(MonitorPingsOption(true)) + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectPing().WillDelayFor(time.Hour) + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + doneCh := make(chan struct{}) + go func() { + err = db.PingContext(ctx) + close(doneCh) + }() + + select { + case <-doneCh: + if err != ErrCancelled { + t.Errorf("expected error '%s' to be returned from Ping, but got '%s'", ErrCancelled, err) + } + case <-time.After(time.Second): + t.Errorf("expected Ping to return after context timeout, but it did not in a timely fashion") + } +}