diff --git a/README.md b/README.md index ddd98d2..27cd72b 100644 --- a/README.md +++ b/README.md @@ -145,6 +145,28 @@ func TestShouldRollbackStatUpdatesOnFailure(t *testing.T) { } ``` +## Customize SQL query matching + +There were plenty of requests from users regarding SQL query string validation or different matching option. +We have now implemented the `QueryMatcher` interface, which can be passed through an option when calling +`sqlmock.New` or `sqlmock.NewWithDSN`. + +This now allows to include some library, which would allow for example to parse and validate `mysql` SQL AST. +And create a custom QueryMatcher in order to validate SQL in sophisticated ways. + +By default, **sqlmock** is preserving backward compatibility and default query matcher is `sqlmock.QueryMatcherRegexp` +which uses expected SQL string as a regular expression to match incoming query string. There is an equality matcher: +`QueryMatcherEqual` which will do a full case sensitive match. + +In order to customize the QueryMatcher, use the following: + +``` go + db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) +``` + +The query matcher can be fully customized based on user needs. **sqlmock** will not +provide a standard sql parsing matchers, since various drivers may not follow the same SQL standard. + ## Matching arguments like time.Time There may be arguments which are of `struct` type and cannot be compared easily by value like `time.Time`. In this case @@ -191,6 +213,7 @@ It only asserts that argument is of `time.Time` type. ## Change Log +- **2018-12-11** - introduced an option to provide **QueryMatcher** in order to customize SQL query matching. - **2017-09-01** - it is now possible to expect that prepared statement will be closed, using **ExpectedPrepare.WillBeClosed**. - **2017-02-09** - implemented support for **go1.8** features. **Rows** interface was changed to struct diff --git a/expectations.go b/expectations.go index 9f54967..b1ec004 100644 --- a/expectations.go +++ b/expectations.go @@ -3,7 +3,6 @@ package sqlmock import ( "database/sql/driver" "fmt" - "regexp" "strings" "sync" "time" @@ -154,7 +153,7 @@ func (e *ExpectedQuery) WillDelayFor(duration time.Duration) *ExpectedQuery { // String returns string representation func (e *ExpectedQuery) String() string { msg := "ExpectedQuery => expecting Query, QueryContext or QueryRow which:" - msg += "\n - matches sql: '" + e.sqlRegex.String() + "'" + msg += "\n - matches sql: '" + e.expectSQL + "'" if len(e.args) == 0 { msg += "\n - is without arguments" @@ -209,7 +208,7 @@ func (e *ExpectedExec) WillDelayFor(duration time.Duration) *ExpectedExec { // String returns string representation func (e *ExpectedExec) String() string { msg := "ExpectedExec => expecting Exec or ExecContext which:" - msg += "\n - matches sql: '" + e.sqlRegex.String() + "'" + msg += "\n - matches sql: '" + e.expectSQL + "'" if len(e.args) == 0 { msg += "\n - is without arguments" @@ -253,7 +252,7 @@ func (e *ExpectedExec) WillReturnResult(result driver.Result) *ExpectedExec { type ExpectedPrepare struct { commonExpectation mock *sqlmock - sqlRegex *regexp.Regexp + expectSQL string statement driver.Stmt closeErr error mustBeClosed bool @@ -291,7 +290,7 @@ func (e *ExpectedPrepare) WillBeClosed() *ExpectedPrepare { // this method is convenient in order to prevent duplicating sql query string matching. func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery { eq := &ExpectedQuery{} - eq.sqlRegex = e.sqlRegex + eq.expectSQL = e.expectSQL eq.converter = e.mock.converter e.mock.expected = append(e.mock.expected, eq) return eq @@ -301,7 +300,7 @@ func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery { // this method is convenient in order to prevent duplicating sql query string matching. func (e *ExpectedPrepare) ExpectExec() *ExpectedExec { eq := &ExpectedExec{} - eq.sqlRegex = e.sqlRegex + eq.expectSQL = e.expectSQL eq.converter = e.mock.converter e.mock.expected = append(e.mock.expected, eq) return eq @@ -310,7 +309,7 @@ func (e *ExpectedPrepare) ExpectExec() *ExpectedExec { // String returns string representation func (e *ExpectedPrepare) String() string { msg := "ExpectedPrepare => expecting Prepare statement which:" - msg += "\n - matches sql: '" + e.sqlRegex.String() + "'" + msg += "\n - matches sql: '" + e.expectSQL + "'" if e.err != nil { msg += fmt.Sprintf("\n - should return error: %s", e.err) @@ -327,16 +326,12 @@ func (e *ExpectedPrepare) String() string { // adds a query matching logic type queryBasedExpectation struct { commonExpectation - sqlRegex *regexp.Regexp + expectSQL string converter driver.ValueConverter args []driver.Value } -func (e *queryBasedExpectation) attemptMatch(sql string, args []namedValue) (err error) { - if !e.queryMatches(sql) { - return fmt.Errorf(`could not match sql: "%s" with expected regexp "%s"`, sql, e.sqlRegex.String()) - } - +func (e *queryBasedExpectation) attemptArgMatch(args []namedValue) (err error) { // catch panic defer func() { if e := recover(); e != nil { @@ -350,7 +345,3 @@ func (e *queryBasedExpectation) attemptMatch(sql string, args []namedValue) (err err = e.argsMatches(args) return } - -func (e *queryBasedExpectation) queryMatches(sql string) bool { - return e.sqlRegex.MatchString(sql) -} diff --git a/expectations_test.go b/expectations_test.go index 90e3f1f..8d2f6d7 100644 --- a/expectations_test.go +++ b/expectations_test.go @@ -3,7 +3,6 @@ package sqlmock import ( "database/sql/driver" "fmt" - "regexp" "testing" "time" ) @@ -100,20 +99,6 @@ func TestQueryExpectationArgComparisonBool(t *testing.T) { } } -func TestQueryExpectationSqlMatch(t *testing.T) { - e := &ExpectedExec{} - - e.sqlRegex = regexp.MustCompile("SELECT x FROM") - if !e.queryMatches("SELECT x FROM someting") { - t.Errorf("Sql must have matched the query") - } - - e.sqlRegex = regexp.MustCompile("SELECT COUNT\\(x\\) FROM") - if !e.queryMatches("SELECT COUNT(x) FROM someting") { - t.Errorf("Sql must have matched the query") - } -} - func ExampleExpectedExec() { db, mock, _ := New() result := NewErrorResult(fmt.Errorf("some error")) diff --git a/query_test.go b/query_test.go index f9546ca..0ba7bdc 100644 --- a/query_test.go +++ b/query_test.go @@ -5,6 +5,42 @@ import ( "testing" ) +func ExampleQueryMatcher() { + // configure to use case sensitive SQL query matcher + // instead of default regular expression matcher + db, mock, err := New(QueryMatcherOption(QueryMatcherEqual)) + if err != nil { + fmt.Println("failed to open sqlmock database:", err) + } + defer db.Close() + + rows := NewRows([]string{"id", "title"}). + AddRow(1, "one"). + AddRow(2, "two") + + mock.ExpectQuery("SELECT * FROM users").WillReturnRows(rows) + + rs, err := db.Query("SELECT * FROM users") + if err != nil { + fmt.Println("failed to match expected query") + return + } + defer rs.Close() + + for rs.Next() { + var id int + var title string + rs.Scan(&id, &title) + fmt.Println("scanned id:", id, "and title:", title) + } + + if rs.Err() != nil { + fmt.Println("got rows error:", rs.Err()) + } + // Output: scanned id: 1 and title: one + // scanned id: 2 and title: two +} + func TestQueryStringStripping(t *testing.T) { assert := func(actual, expected string) { if res := stripQuery(actual); res != expected { diff --git a/sqlmock.go b/sqlmock.go index 113ecd8..bf274d5 100644 --- a/sqlmock.go +++ b/sqlmock.go @@ -14,7 +14,6 @@ import ( "database/sql" "database/sql/driver" "fmt" - "regexp" "time" ) @@ -32,22 +31,19 @@ type Sqlmock interface { // were met in order. If any of them was not met - an error is returned. ExpectationsWereMet() error - // ExpectPrepare expects Prepare() to be called with sql query - // which match sqlRegexStr given regexp. + // ExpectPrepare expects Prepare() to be called with expectedSQL query. // the *ExpectedPrepare allows to mock database response. // Note that you may expect Query() or Exec() on the *ExpectedPrepare - // statement to prevent repeating sqlRegexStr - ExpectPrepare(sqlRegexStr string) *ExpectedPrepare + // statement to prevent repeating expectedSQL + ExpectPrepare(expectedSQL string) *ExpectedPrepare - // ExpectQuery expects Query() or QueryRow() to be called with sql query - // which match sqlRegexStr given regexp. + // ExpectQuery expects Query() or QueryRow() to be called with expectedSQL query. // the *ExpectedQuery allows to mock database response. - ExpectQuery(sqlRegexStr string) *ExpectedQuery + ExpectQuery(expectedSQL string) *ExpectedQuery - // ExpectExec expects Exec() to be called with sql query - // which match sqlRegexStr given regexp. + // ExpectExec expects Exec() to be called with expectedSQL query. // the *ExpectedExec allows to mock database response - ExpectExec(sqlRegexStr string) *ExpectedExec + ExpectExec(expectedSQL string) *ExpectedExec // ExpectBegin expects *sql.DB.Begin to be called. // the *ExpectedBegin allows to mock database response @@ -260,7 +256,6 @@ func (c *sqlmock) Exec(query string, args []driver.Value) (driver.Result, error) } func (c *sqlmock) exec(query string, args []namedValue) (*ExpectedExec, error) { - query = stripQuery(query) var expected *ExpectedExec var fulfilled int var ok bool @@ -280,7 +275,12 @@ func (c *sqlmock) exec(query string, args []namedValue) (*ExpectedExec, error) { return nil, fmt.Errorf("call to ExecQuery '%s' with args %+v, was not expected, next expectation is: %s", query, args, next) } if exec, ok := next.(*ExpectedExec); ok { - if err := exec.attemptMatch(query, args); err == nil { + if err := c.queryMatcher.Match(exec.expectSQL, query); err != nil { + next.Unlock() + continue + } + + if err := exec.attemptArgMatch(args); err == nil { expected = exec break } @@ -296,8 +296,8 @@ func (c *sqlmock) exec(query string, args []namedValue) (*ExpectedExec, error) { } defer expected.Unlock() - if !expected.queryMatches(query) { - return nil, fmt.Errorf("ExecQuery '%s', does not match regex '%s'", query, expected.sqlRegex.String()) + if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil { + return nil, fmt.Errorf("ExecQuery: %v", err) } if err := expected.argsMatches(args); err != nil { @@ -316,10 +316,9 @@ func (c *sqlmock) exec(query string, args []namedValue) (*ExpectedExec, error) { return expected, nil } -func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec { +func (c *sqlmock) ExpectExec(expectedSQL string) *ExpectedExec { e := &ExpectedExec{} - sqlRegexStr = stripQuery(sqlRegexStr) - e.sqlRegex = regexp.MustCompile(sqlRegexStr) + e.expectSQL = expectedSQL e.converter = c.converter c.expected = append(c.expected, e) return e @@ -343,8 +342,6 @@ func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) { var fulfilled int var ok bool - query = stripQuery(query) - for _, next := range c.expected { next.Lock() if next.fulfilled() { @@ -363,7 +360,7 @@ func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) { } if pr, ok := next.(*ExpectedPrepare); ok { - if pr.sqlRegex.MatchString(query) { + if err := c.queryMatcher.Match(pr.expectSQL, query); err == nil { expected = pr break } @@ -379,17 +376,16 @@ func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) { return nil, fmt.Errorf(msg, query) } defer expected.Unlock() - if !expected.sqlRegex.MatchString(query) { - return nil, fmt.Errorf("Prepare query string '%s', does not match regex [%s]", query, expected.sqlRegex.String()) + if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil { + return nil, fmt.Errorf("Prepare: %v", err) } expected.triggered = true return expected, expected.err } -func (c *sqlmock) ExpectPrepare(sqlRegexStr string) *ExpectedPrepare { - sqlRegexStr = stripQuery(sqlRegexStr) - e := &ExpectedPrepare{sqlRegex: regexp.MustCompile(sqlRegexStr), mock: c} +func (c *sqlmock) ExpectPrepare(expectedSQL string) *ExpectedPrepare { + e := &ExpectedPrepare{expectSQL: expectedSQL, mock: c} c.expected = append(c.expected, e) return e } @@ -422,7 +418,6 @@ func (c *sqlmock) Query(query string, args []driver.Value) (driver.Rows, error) } func (c *sqlmock) query(query string, args []namedValue) (*ExpectedQuery, error) { - query = stripQuery(query) var expected *ExpectedQuery var fulfilled int var ok bool @@ -442,7 +437,11 @@ func (c *sqlmock) query(query string, args []namedValue) (*ExpectedQuery, error) return nil, fmt.Errorf("call to Query '%s' with args %+v, was not expected, next expectation is: %s", query, args, next) } if qr, ok := next.(*ExpectedQuery); ok { - if err := qr.attemptMatch(query, args); err == nil { + if err := c.queryMatcher.Match(qr.expectSQL, query); err != nil { + next.Unlock() + continue + } + if err := qr.attemptArgMatch(args); err == nil { expected = qr break } @@ -460,8 +459,8 @@ func (c *sqlmock) query(query string, args []namedValue) (*ExpectedQuery, error) defer expected.Unlock() - if !expected.queryMatches(query) { - return nil, fmt.Errorf("Query '%s', does not match regex [%s]", query, expected.sqlRegex.String()) + if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil { + return nil, fmt.Errorf("Query: %v", err) } if err := expected.argsMatches(args); err != nil { @@ -479,10 +478,9 @@ func (c *sqlmock) query(query string, args []namedValue) (*ExpectedQuery, error) return expected, nil } -func (c *sqlmock) ExpectQuery(sqlRegexStr string) *ExpectedQuery { +func (c *sqlmock) ExpectQuery(expectedSQL string) *ExpectedQuery { e := &ExpectedQuery{} - sqlRegexStr = stripQuery(sqlRegexStr) - e.sqlRegex = regexp.MustCompile(sqlRegexStr) + e.expectSQL = expectedSQL e.converter = c.converter c.expected = append(c.expected, e) return e diff --git a/sqlmock_go18.go b/sqlmock_go18.go index b8c76f8..0afb296 100644 --- a/sqlmock_go18.go +++ b/sqlmock_go18.go @@ -9,6 +9,8 @@ import ( "time" ) +// ErrCancelled defines an error value, which can be expected in case of +// such cancellation error. var ErrCancelled = errors.New("canceling query due to user request") // Implement the "QueryerContext" interface