1
0
mirror of https://github.com/DATA-DOG/go-sqlmock.git synced 2025-02-19 19:00:11 +02:00

add QueryMatcher interface for customizing SQL matching

This commit is contained in:
gedi 2018-12-11 14:22:16 +02:00
parent e4e10ddf73
commit 2a15d9c09b
No known key found for this signature in database
GPG Key ID: 56604CDCCC201556
6 changed files with 174 additions and 39 deletions

View File

@ -10,3 +10,13 @@ func ValueConverterOption(converter driver.ValueConverter) func(*sqlmock) error
return nil
}
}
// QueryMatcherOption allows to customize SQL query matcher
// and match SQL query strings in more sophisticated ways.
// The default QueryMatcher is QueryMatcherRegexp.
func QueryMatcherOption(queryMatcher QueryMatcher) func(*sqlmock) error {
return func(s *sqlmock) error {
s.queryMatcher = queryMatcher
return nil
}
}

68
query.go Normal file
View File

@ -0,0 +1,68 @@
package sqlmock
import (
"fmt"
"regexp"
"strings"
)
var re = regexp.MustCompile("\\s+")
// strip out new lines and trim spaces
func stripQuery(q string) (s string) {
return strings.TrimSpace(re.ReplaceAllString(q, " "))
}
// QueryMatcher is an SQL query string matcher interface,
// which can be used to customize validation of SQL query strings.
// As an exaple, external library could be used to build
// and validate SQL ast, columns selected.
//
// sqlmock can be customized to implement a different QueryMatcher
// configured through an option when sqlmock.New or sqlmock.NewWithDSN
// is called, default QueryMatcher is QueryMatcherRegexp.
type QueryMatcher interface {
// Match expected SQL query string without whitespace to
// actual SQL.
Match(expectedSQL, actualSQL string) error
}
// QueryMatcherFunc type is an adapter to allow the use of
// ordinary functions as QueryMatcher. If f is a function
// with the appropriate signature, QueryMatcherFunc(f) is a
// QueryMatcher that calls f.
type QueryMatcherFunc func(expectedSQL, actualSQL string) error
// Match implements the QueryMatcher
func (f QueryMatcherFunc) Match(expectedSQL, actualSQL string) error {
return f(expectedSQL, actualSQL)
}
// QueryMatcherRegexp is the default SQL query matcher
// used by sqlmock. It parses expectedSQL to a regular
// expression and attempts to match actualSQL.
var QueryMatcherRegexp QueryMatcher = QueryMatcherFunc(func(expectedSQL, actualSQL string) error {
expect := stripQuery(expectedSQL)
actual := stripQuery(actualSQL)
re, err := regexp.Compile(expect)
if err != nil {
return err
}
if !re.MatchString(actual) {
return fmt.Errorf(`could not match actual sql: "%s" with expected regexp "%s"`, actual, re.String())
}
return nil
})
// QueryMatcherEqual is the SQL query matcher
// which simply tries a case sensitive match of
// expected and actual SQL strings without whitespace.
var QueryMatcherEqual QueryMatcher = QueryMatcherFunc(func(expectedSQL, actualSQL string) error {
expect := stripQuery(expectedSQL)
actual := stripQuery(actualSQL)
if actual != expect {
return fmt.Errorf(`actual sql: "%s" does not equal to expected "%s"`, actual, expect)
}
return nil
})

87
query_test.go Normal file
View File

@ -0,0 +1,87 @@
package sqlmock
import (
"fmt"
"testing"
)
func TestQueryStringStripping(t *testing.T) {
assert := func(actual, expected string) {
if res := stripQuery(actual); res != expected {
t.Errorf("Expected '%s' to be '%s', but got '%s'", actual, expected, res)
}
}
assert(" SELECT 1", "SELECT 1")
assert("SELECT 1 FROM d", "SELECT 1 FROM d")
assert(`
SELECT c
FROM D
`, "SELECT c FROM D")
assert("UPDATE (.+) SET ", "UPDATE (.+) SET")
}
func TestQueryMatcherRegexp(t *testing.T) {
type testCase struct {
expected string
actual string
err error
}
cases := []testCase{
{"?\\l", "SEL", fmt.Errorf("error parsing regexp: missing argument to repetition operator: `?`")},
{"SELECT (.+) FROM users", "SELECT name, email FROM users WHERE id = ?", nil},
{"Select (.+) FROM users", "SELECT name, email FROM users WHERE id = ?", fmt.Errorf(`could not match actual sql: "SELECT name, email FROM users WHERE id = ?" with expected regexp "Select (.+) FROM users"`)},
{"SELECT (.+) FROM\nusers", "SELECT name, email\n FROM users\n WHERE id = ?", nil},
}
for i, c := range cases {
err := QueryMatcherRegexp.Match(c.expected, c.actual)
if err == nil && c.err != nil {
t.Errorf(`got no error, but expected "%v" at %d case`, c.err, i)
continue
}
if err != nil && c.err == nil {
t.Errorf(`got unexpected error "%v" at %d case`, err, i)
continue
}
if err == nil {
continue
}
if err.Error() != c.err.Error() {
t.Errorf(`expected error "%v", but got "%v" at %d case`, c.err, err, i)
}
}
}
func TestQueryMatcherEqual(t *testing.T) {
type testCase struct {
expected string
actual string
err error
}
cases := []testCase{
{"SELECT name, email FROM users WHERE id = ?", "SELECT name, email\n FROM users\n WHERE id = ?", nil},
{"SELECT", "Select", fmt.Errorf(`actual sql: "Select" does not equal to expected "SELECT"`)},
{"SELECT from users", "SELECT from table", fmt.Errorf(`actual sql: "SELECT from table" does not equal to expected "SELECT from users"`)},
}
for i, c := range cases {
err := QueryMatcherEqual.Match(c.expected, c.actual)
if err == nil && c.err != nil {
t.Errorf(`got no error, but expected "%v" at %d case`, c.err, i)
continue
}
if err != nil && c.err == nil {
t.Errorf(`got unexpected error "%v" at %d case`, err, i)
continue
}
if err == nil {
continue
}
if err.Error() != c.err.Error() {
t.Errorf(`expected error "%v", but got "%v" at %d case`, c.err, err, i)
}
}
}

View File

@ -81,11 +81,12 @@ type Sqlmock interface {
}
type sqlmock struct {
ordered bool
dsn string
opened int
drv *mockDriver
converter driver.ValueConverter
ordered bool
dsn string
opened int
drv *mockDriver
converter driver.ValueConverter
queryMatcher QueryMatcher
expected []expectation
}
@ -104,6 +105,9 @@ func (c *sqlmock) open(options []func(*sqlmock) error) (*sql.DB, Sqlmock, error)
if c.converter == nil {
c.converter = driver.DefaultParameterConverter
}
if c.queryMatcher == nil {
c.queryMatcher = QueryMatcherRegexp
}
return db, c, db.Ping()
}

13
util.go
View File

@ -1,13 +0,0 @@
package sqlmock
import (
"regexp"
"strings"
)
var re = regexp.MustCompile("\\s+")
// strip out new lines and trim spaces
func stripQuery(q string) (s string) {
return strings.TrimSpace(re.ReplaceAllString(q, " "))
}

View File

@ -1,21 +0,0 @@
package sqlmock
import (
"testing"
)
func TestQueryStringStripping(t *testing.T) {
assert := func(actual, expected string) {
if res := stripQuery(actual); res != expected {
t.Errorf("Expected '%s' to be '%s', but got '%s'", actual, expected, res)
}
}
assert(" SELECT 1", "SELECT 1")
assert("SELECT 1 FROM d", "SELECT 1 FROM d")
assert(`
SELECT c
FROM D
`, "SELECT c FROM D")
assert("UPDATE (.+) SET ", "UPDATE (.+) SET")
}