From e671f177c0905fbb0d90149615bf628396afb85c Mon Sep 17 00:00:00 2001 From: Michael Darwish Date: Thu, 20 Dec 2018 13:00:51 -0500 Subject: [PATCH] adds missing lock around e.fulfilled() in ExpectationsWereMet() --- sqlmock.go | 6 +++++- sqlmock_test.go | 50 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/sqlmock.go b/sqlmock.go index 20881f7..609dafd 100644 --- a/sqlmock.go +++ b/sqlmock.go @@ -166,7 +166,11 @@ func (c *sqlmock) Close() error { func (c *sqlmock) ExpectationsWereMet() error { for _, e := range c.expected { - if !e.fulfilled() { + e.Lock() + fulfilled := e.fulfilled() + e.Unlock() + + if !fulfilled { return fmt.Errorf("there is a remaining expectation which was not matched: %s", e) } diff --git a/sqlmock_test.go b/sqlmock_test.go index 6f28072..11fd59a 100644 --- a/sqlmock_test.go +++ b/sqlmock_test.go @@ -1167,3 +1167,53 @@ func TestNewRows(t *testing.T) { t.Errorf("expecting to create a row with columns %v, actual colmns are %v", r.cols, columns) } } + +// This is actually a test of ExpectationsWereMet. Without a lock around e.fulfilled() inside +// ExpectationWereMet, the race detector complains if e.triggered is being read while it is also +// being written by the query running in another goroutine. +func TestQueryWithTimeout(t *testing.T) { + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + rs := NewRows([]string{"id", "title"}).FromCSVString("5,hello world") + + mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). + WillDelayFor(15 * time.Millisecond). // Query will take longer than timeout + WithArgs(5). + WillReturnRows(rs) + + _, err = queryWithTimeout(10*time.Millisecond, db, "SELECT (.+) FROM articles WHERE id = ?", 5) + if err == nil { + t.Errorf("expecting query to time out") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func queryWithTimeout(t time.Duration, db *sql.DB, query string, args ...interface{}) (*sql.Rows, error) { + rowsChan := make(chan *sql.Rows, 1) + errChan := make(chan error, 1) + + go func() { + rows, err := db.Query(query, args...) + if err != nil { + errChan <- err + return + } + rowsChan <- rows + }() + + select { + case rows := <-rowsChan: + return rows, nil + case err := <-errChan: + return nil, err + case <-time.After(t): + return nil, fmt.Errorf("query timed out after %v", t) + } +}