1
0
mirror of https://github.com/DATA-DOG/go-sqlmock.git synced 2025-05-19 22:23:05 +02:00

use configured QueryMatcher in order to match expected SQL to actual, closes #70

This commit is contained in:
gedi 2018-12-11 14:56:33 +02:00
parent 2a15d9c09b
commit a6e6646ad9
No known key found for this signature in database
GPG Key ID: 56604CDCCC201556
6 changed files with 100 additions and 65 deletions

View File

@ -145,6 +145,28 @@ func TestShouldRollbackStatUpdatesOnFailure(t *testing.T) {
} }
``` ```
## Customize SQL query matching
There were plenty of requests from users regarding SQL query string validation or different matching option.
We have now implemented the `QueryMatcher` interface, which can be passed through an option when calling
`sqlmock.New` or `sqlmock.NewWithDSN`.
This now allows to include some library, which would allow for example to parse and validate `mysql` SQL AST.
And create a custom QueryMatcher in order to validate SQL in sophisticated ways.
By default, **sqlmock** is preserving backward compatibility and default query matcher is `sqlmock.QueryMatcherRegexp`
which uses expected SQL string as a regular expression to match incoming query string. There is an equality matcher:
`QueryMatcherEqual` which will do a full case sensitive match.
In order to customize the QueryMatcher, use the following:
``` go
db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
```
The query matcher can be fully customized based on user needs. **sqlmock** will not
provide a standard sql parsing matchers, since various drivers may not follow the same SQL standard.
## Matching arguments like time.Time ## Matching arguments like time.Time
There may be arguments which are of `struct` type and cannot be compared easily by value like `time.Time`. In this case There may be arguments which are of `struct` type and cannot be compared easily by value like `time.Time`. In this case
@ -191,6 +213,7 @@ It only asserts that argument is of `time.Time` type.
## Change Log ## Change Log
- **2018-12-11** - introduced an option to provide **QueryMatcher** in order to customize SQL query matching.
- **2017-09-01** - it is now possible to expect that prepared statement will be closed, - **2017-09-01** - it is now possible to expect that prepared statement will be closed,
using **ExpectedPrepare.WillBeClosed**. using **ExpectedPrepare.WillBeClosed**.
- **2017-02-09** - implemented support for **go1.8** features. **Rows** interface was changed to struct - **2017-02-09** - implemented support for **go1.8** features. **Rows** interface was changed to struct

View File

@ -3,7 +3,6 @@ package sqlmock
import ( import (
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"regexp"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -154,7 +153,7 @@ func (e *ExpectedQuery) WillDelayFor(duration time.Duration) *ExpectedQuery {
// String returns string representation // String returns string representation
func (e *ExpectedQuery) String() string { func (e *ExpectedQuery) String() string {
msg := "ExpectedQuery => expecting Query, QueryContext or QueryRow which:" msg := "ExpectedQuery => expecting Query, QueryContext or QueryRow which:"
msg += "\n - matches sql: '" + e.sqlRegex.String() + "'" msg += "\n - matches sql: '" + e.expectSQL + "'"
if len(e.args) == 0 { if len(e.args) == 0 {
msg += "\n - is without arguments" msg += "\n - is without arguments"
@ -209,7 +208,7 @@ func (e *ExpectedExec) WillDelayFor(duration time.Duration) *ExpectedExec {
// String returns string representation // String returns string representation
func (e *ExpectedExec) String() string { func (e *ExpectedExec) String() string {
msg := "ExpectedExec => expecting Exec or ExecContext which:" msg := "ExpectedExec => expecting Exec or ExecContext which:"
msg += "\n - matches sql: '" + e.sqlRegex.String() + "'" msg += "\n - matches sql: '" + e.expectSQL + "'"
if len(e.args) == 0 { if len(e.args) == 0 {
msg += "\n - is without arguments" msg += "\n - is without arguments"
@ -253,7 +252,7 @@ func (e *ExpectedExec) WillReturnResult(result driver.Result) *ExpectedExec {
type ExpectedPrepare struct { type ExpectedPrepare struct {
commonExpectation commonExpectation
mock *sqlmock mock *sqlmock
sqlRegex *regexp.Regexp expectSQL string
statement driver.Stmt statement driver.Stmt
closeErr error closeErr error
mustBeClosed bool mustBeClosed bool
@ -291,7 +290,7 @@ func (e *ExpectedPrepare) WillBeClosed() *ExpectedPrepare {
// this method is convenient in order to prevent duplicating sql query string matching. // this method is convenient in order to prevent duplicating sql query string matching.
func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery { func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery {
eq := &ExpectedQuery{} eq := &ExpectedQuery{}
eq.sqlRegex = e.sqlRegex eq.expectSQL = e.expectSQL
eq.converter = e.mock.converter eq.converter = e.mock.converter
e.mock.expected = append(e.mock.expected, eq) e.mock.expected = append(e.mock.expected, eq)
return eq return eq
@ -301,7 +300,7 @@ func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery {
// this method is convenient in order to prevent duplicating sql query string matching. // this method is convenient in order to prevent duplicating sql query string matching.
func (e *ExpectedPrepare) ExpectExec() *ExpectedExec { func (e *ExpectedPrepare) ExpectExec() *ExpectedExec {
eq := &ExpectedExec{} eq := &ExpectedExec{}
eq.sqlRegex = e.sqlRegex eq.expectSQL = e.expectSQL
eq.converter = e.mock.converter eq.converter = e.mock.converter
e.mock.expected = append(e.mock.expected, eq) e.mock.expected = append(e.mock.expected, eq)
return eq return eq
@ -310,7 +309,7 @@ func (e *ExpectedPrepare) ExpectExec() *ExpectedExec {
// String returns string representation // String returns string representation
func (e *ExpectedPrepare) String() string { func (e *ExpectedPrepare) String() string {
msg := "ExpectedPrepare => expecting Prepare statement which:" msg := "ExpectedPrepare => expecting Prepare statement which:"
msg += "\n - matches sql: '" + e.sqlRegex.String() + "'" msg += "\n - matches sql: '" + e.expectSQL + "'"
if e.err != nil { if e.err != nil {
msg += fmt.Sprintf("\n - should return error: %s", e.err) msg += fmt.Sprintf("\n - should return error: %s", e.err)
@ -327,16 +326,12 @@ func (e *ExpectedPrepare) String() string {
// adds a query matching logic // adds a query matching logic
type queryBasedExpectation struct { type queryBasedExpectation struct {
commonExpectation commonExpectation
sqlRegex *regexp.Regexp expectSQL string
converter driver.ValueConverter converter driver.ValueConverter
args []driver.Value args []driver.Value
} }
func (e *queryBasedExpectation) attemptMatch(sql string, args []namedValue) (err error) { func (e *queryBasedExpectation) attemptArgMatch(args []namedValue) (err error) {
if !e.queryMatches(sql) {
return fmt.Errorf(`could not match sql: "%s" with expected regexp "%s"`, sql, e.sqlRegex.String())
}
// catch panic // catch panic
defer func() { defer func() {
if e := recover(); e != nil { if e := recover(); e != nil {
@ -350,7 +345,3 @@ func (e *queryBasedExpectation) attemptMatch(sql string, args []namedValue) (err
err = e.argsMatches(args) err = e.argsMatches(args)
return return
} }
func (e *queryBasedExpectation) queryMatches(sql string) bool {
return e.sqlRegex.MatchString(sql)
}

View File

@ -3,7 +3,6 @@ package sqlmock
import ( import (
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"regexp"
"testing" "testing"
"time" "time"
) )
@ -100,20 +99,6 @@ func TestQueryExpectationArgComparisonBool(t *testing.T) {
} }
} }
func TestQueryExpectationSqlMatch(t *testing.T) {
e := &ExpectedExec{}
e.sqlRegex = regexp.MustCompile("SELECT x FROM")
if !e.queryMatches("SELECT x FROM someting") {
t.Errorf("Sql must have matched the query")
}
e.sqlRegex = regexp.MustCompile("SELECT COUNT\\(x\\) FROM")
if !e.queryMatches("SELECT COUNT(x) FROM someting") {
t.Errorf("Sql must have matched the query")
}
}
func ExampleExpectedExec() { func ExampleExpectedExec() {
db, mock, _ := New() db, mock, _ := New()
result := NewErrorResult(fmt.Errorf("some error")) result := NewErrorResult(fmt.Errorf("some error"))

View File

@ -5,6 +5,42 @@ import (
"testing" "testing"
) )
func ExampleQueryMatcher() {
// configure to use case sensitive SQL query matcher
// instead of default regular expression matcher
db, mock, err := New(QueryMatcherOption(QueryMatcherEqual))
if err != nil {
fmt.Println("failed to open sqlmock database:", err)
}
defer db.Close()
rows := NewRows([]string{"id", "title"}).
AddRow(1, "one").
AddRow(2, "two")
mock.ExpectQuery("SELECT * FROM users").WillReturnRows(rows)
rs, err := db.Query("SELECT * FROM users")
if err != nil {
fmt.Println("failed to match expected query")
return
}
defer rs.Close()
for rs.Next() {
var id int
var title string
rs.Scan(&id, &title)
fmt.Println("scanned id:", id, "and title:", title)
}
if rs.Err() != nil {
fmt.Println("got rows error:", rs.Err())
}
// Output: scanned id: 1 and title: one
// scanned id: 2 and title: two
}
func TestQueryStringStripping(t *testing.T) { func TestQueryStringStripping(t *testing.T) {
assert := func(actual, expected string) { assert := func(actual, expected string) {
if res := stripQuery(actual); res != expected { if res := stripQuery(actual); res != expected {

View File

@ -14,7 +14,6 @@ import (
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"regexp"
"time" "time"
) )
@ -32,22 +31,19 @@ type Sqlmock interface {
// were met in order. If any of them was not met - an error is returned. // were met in order. If any of them was not met - an error is returned.
ExpectationsWereMet() error ExpectationsWereMet() error
// ExpectPrepare expects Prepare() to be called with sql query // ExpectPrepare expects Prepare() to be called with expectedSQL query.
// which match sqlRegexStr given regexp.
// the *ExpectedPrepare allows to mock database response. // the *ExpectedPrepare allows to mock database response.
// Note that you may expect Query() or Exec() on the *ExpectedPrepare // Note that you may expect Query() or Exec() on the *ExpectedPrepare
// statement to prevent repeating sqlRegexStr // statement to prevent repeating expectedSQL
ExpectPrepare(sqlRegexStr string) *ExpectedPrepare ExpectPrepare(expectedSQL string) *ExpectedPrepare
// ExpectQuery expects Query() or QueryRow() to be called with sql query // ExpectQuery expects Query() or QueryRow() to be called with expectedSQL query.
// which match sqlRegexStr given regexp.
// the *ExpectedQuery allows to mock database response. // the *ExpectedQuery allows to mock database response.
ExpectQuery(sqlRegexStr string) *ExpectedQuery ExpectQuery(expectedSQL string) *ExpectedQuery
// ExpectExec expects Exec() to be called with sql query // ExpectExec expects Exec() to be called with expectedSQL query.
// which match sqlRegexStr given regexp.
// the *ExpectedExec allows to mock database response // the *ExpectedExec allows to mock database response
ExpectExec(sqlRegexStr string) *ExpectedExec ExpectExec(expectedSQL string) *ExpectedExec
// ExpectBegin expects *sql.DB.Begin to be called. // ExpectBegin expects *sql.DB.Begin to be called.
// the *ExpectedBegin allows to mock database response // the *ExpectedBegin allows to mock database response
@ -260,7 +256,6 @@ func (c *sqlmock) Exec(query string, args []driver.Value) (driver.Result, error)
} }
func (c *sqlmock) exec(query string, args []namedValue) (*ExpectedExec, error) { func (c *sqlmock) exec(query string, args []namedValue) (*ExpectedExec, error) {
query = stripQuery(query)
var expected *ExpectedExec var expected *ExpectedExec
var fulfilled int var fulfilled int
var ok bool var ok bool
@ -280,7 +275,12 @@ func (c *sqlmock) exec(query string, args []namedValue) (*ExpectedExec, error) {
return nil, fmt.Errorf("call to ExecQuery '%s' with args %+v, was not expected, next expectation is: %s", query, args, next) return nil, fmt.Errorf("call to ExecQuery '%s' with args %+v, was not expected, next expectation is: %s", query, args, next)
} }
if exec, ok := next.(*ExpectedExec); ok { if exec, ok := next.(*ExpectedExec); ok {
if err := exec.attemptMatch(query, args); err == nil { if err := c.queryMatcher.Match(exec.expectSQL, query); err != nil {
next.Unlock()
continue
}
if err := exec.attemptArgMatch(args); err == nil {
expected = exec expected = exec
break break
} }
@ -296,8 +296,8 @@ func (c *sqlmock) exec(query string, args []namedValue) (*ExpectedExec, error) {
} }
defer expected.Unlock() defer expected.Unlock()
if !expected.queryMatches(query) { if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil {
return nil, fmt.Errorf("ExecQuery '%s', does not match regex '%s'", query, expected.sqlRegex.String()) return nil, fmt.Errorf("ExecQuery: %v", err)
} }
if err := expected.argsMatches(args); err != nil { if err := expected.argsMatches(args); err != nil {
@ -316,10 +316,9 @@ func (c *sqlmock) exec(query string, args []namedValue) (*ExpectedExec, error) {
return expected, nil return expected, nil
} }
func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec { func (c *sqlmock) ExpectExec(expectedSQL string) *ExpectedExec {
e := &ExpectedExec{} e := &ExpectedExec{}
sqlRegexStr = stripQuery(sqlRegexStr) e.expectSQL = expectedSQL
e.sqlRegex = regexp.MustCompile(sqlRegexStr)
e.converter = c.converter e.converter = c.converter
c.expected = append(c.expected, e) c.expected = append(c.expected, e)
return e return e
@ -343,8 +342,6 @@ func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) {
var fulfilled int var fulfilled int
var ok bool var ok bool
query = stripQuery(query)
for _, next := range c.expected { for _, next := range c.expected {
next.Lock() next.Lock()
if next.fulfilled() { if next.fulfilled() {
@ -363,7 +360,7 @@ func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) {
} }
if pr, ok := next.(*ExpectedPrepare); ok { if pr, ok := next.(*ExpectedPrepare); ok {
if pr.sqlRegex.MatchString(query) { if err := c.queryMatcher.Match(pr.expectSQL, query); err == nil {
expected = pr expected = pr
break break
} }
@ -379,17 +376,16 @@ func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) {
return nil, fmt.Errorf(msg, query) return nil, fmt.Errorf(msg, query)
} }
defer expected.Unlock() defer expected.Unlock()
if !expected.sqlRegex.MatchString(query) { if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil {
return nil, fmt.Errorf("Prepare query string '%s', does not match regex [%s]", query, expected.sqlRegex.String()) return nil, fmt.Errorf("Prepare: %v", err)
} }
expected.triggered = true expected.triggered = true
return expected, expected.err return expected, expected.err
} }
func (c *sqlmock) ExpectPrepare(sqlRegexStr string) *ExpectedPrepare { func (c *sqlmock) ExpectPrepare(expectedSQL string) *ExpectedPrepare {
sqlRegexStr = stripQuery(sqlRegexStr) e := &ExpectedPrepare{expectSQL: expectedSQL, mock: c}
e := &ExpectedPrepare{sqlRegex: regexp.MustCompile(sqlRegexStr), mock: c}
c.expected = append(c.expected, e) c.expected = append(c.expected, e)
return e return e
} }
@ -422,7 +418,6 @@ func (c *sqlmock) Query(query string, args []driver.Value) (driver.Rows, error)
} }
func (c *sqlmock) query(query string, args []namedValue) (*ExpectedQuery, error) { func (c *sqlmock) query(query string, args []namedValue) (*ExpectedQuery, error) {
query = stripQuery(query)
var expected *ExpectedQuery var expected *ExpectedQuery
var fulfilled int var fulfilled int
var ok bool var ok bool
@ -442,7 +437,11 @@ func (c *sqlmock) query(query string, args []namedValue) (*ExpectedQuery, error)
return nil, fmt.Errorf("call to Query '%s' with args %+v, was not expected, next expectation is: %s", query, args, next) return nil, fmt.Errorf("call to Query '%s' with args %+v, was not expected, next expectation is: %s", query, args, next)
} }
if qr, ok := next.(*ExpectedQuery); ok { if qr, ok := next.(*ExpectedQuery); ok {
if err := qr.attemptMatch(query, args); err == nil { if err := c.queryMatcher.Match(qr.expectSQL, query); err != nil {
next.Unlock()
continue
}
if err := qr.attemptArgMatch(args); err == nil {
expected = qr expected = qr
break break
} }
@ -460,8 +459,8 @@ func (c *sqlmock) query(query string, args []namedValue) (*ExpectedQuery, error)
defer expected.Unlock() defer expected.Unlock()
if !expected.queryMatches(query) { if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil {
return nil, fmt.Errorf("Query '%s', does not match regex [%s]", query, expected.sqlRegex.String()) return nil, fmt.Errorf("Query: %v", err)
} }
if err := expected.argsMatches(args); err != nil { if err := expected.argsMatches(args); err != nil {
@ -479,10 +478,9 @@ func (c *sqlmock) query(query string, args []namedValue) (*ExpectedQuery, error)
return expected, nil return expected, nil
} }
func (c *sqlmock) ExpectQuery(sqlRegexStr string) *ExpectedQuery { func (c *sqlmock) ExpectQuery(expectedSQL string) *ExpectedQuery {
e := &ExpectedQuery{} e := &ExpectedQuery{}
sqlRegexStr = stripQuery(sqlRegexStr) e.expectSQL = expectedSQL
e.sqlRegex = regexp.MustCompile(sqlRegexStr)
e.converter = c.converter e.converter = c.converter
c.expected = append(c.expected, e) c.expected = append(c.expected, e)
return e return e

View File

@ -9,6 +9,8 @@ import (
"time" "time"
) )
// ErrCancelled defines an error value, which can be expected in case of
// such cancellation error.
var ErrCancelled = errors.New("canceling query due to user request") var ErrCancelled = errors.New("canceling query due to user request")
// Implement the "QueryerContext" interface // Implement the "QueryerContext" interface