From 2a15d9c09b0fef7e2722e30109a2efe0b398a2df Mon Sep 17 00:00:00 2001 From: gedi Date: Tue, 11 Dec 2018 14:22:16 +0200 Subject: [PATCH] add QueryMatcher interface for customizing SQL matching --- options.go | 10 ++++++ query.go | 68 ++++++++++++++++++++++++++++++++++++++++ query_test.go | 87 +++++++++++++++++++++++++++++++++++++++++++++++++++ sqlmock.go | 14 ++++++--- util.go | 13 -------- util_test.go | 21 ------------- 6 files changed, 174 insertions(+), 39 deletions(-) create mode 100644 query.go create mode 100644 query_test.go delete mode 100644 util.go delete mode 100644 util_test.go diff --git a/options.go b/options.go index 05c09d6..29053ee 100644 --- a/options.go +++ b/options.go @@ -10,3 +10,13 @@ func ValueConverterOption(converter driver.ValueConverter) func(*sqlmock) error return nil } } + +// QueryMatcherOption allows to customize SQL query matcher +// and match SQL query strings in more sophisticated ways. +// The default QueryMatcher is QueryMatcherRegexp. +func QueryMatcherOption(queryMatcher QueryMatcher) func(*sqlmock) error { + return func(s *sqlmock) error { + s.queryMatcher = queryMatcher + return nil + } +} diff --git a/query.go b/query.go new file mode 100644 index 0000000..8e05848 --- /dev/null +++ b/query.go @@ -0,0 +1,68 @@ +package sqlmock + +import ( + "fmt" + "regexp" + "strings" +) + +var re = regexp.MustCompile("\\s+") + +// strip out new lines and trim spaces +func stripQuery(q string) (s string) { + return strings.TrimSpace(re.ReplaceAllString(q, " ")) +} + +// QueryMatcher is an SQL query string matcher interface, +// which can be used to customize validation of SQL query strings. +// As an exaple, external library could be used to build +// and validate SQL ast, columns selected. +// +// sqlmock can be customized to implement a different QueryMatcher +// configured through an option when sqlmock.New or sqlmock.NewWithDSN +// is called, default QueryMatcher is QueryMatcherRegexp. +type QueryMatcher interface { + + // Match expected SQL query string without whitespace to + // actual SQL. + Match(expectedSQL, actualSQL string) error +} + +// QueryMatcherFunc type is an adapter to allow the use of +// ordinary functions as QueryMatcher. If f is a function +// with the appropriate signature, QueryMatcherFunc(f) is a +// QueryMatcher that calls f. +type QueryMatcherFunc func(expectedSQL, actualSQL string) error + +// Match implements the QueryMatcher +func (f QueryMatcherFunc) Match(expectedSQL, actualSQL string) error { + return f(expectedSQL, actualSQL) +} + +// QueryMatcherRegexp is the default SQL query matcher +// used by sqlmock. It parses expectedSQL to a regular +// expression and attempts to match actualSQL. +var QueryMatcherRegexp QueryMatcher = QueryMatcherFunc(func(expectedSQL, actualSQL string) error { + expect := stripQuery(expectedSQL) + actual := stripQuery(actualSQL) + re, err := regexp.Compile(expect) + if err != nil { + return err + } + if !re.MatchString(actual) { + return fmt.Errorf(`could not match actual sql: "%s" with expected regexp "%s"`, actual, re.String()) + } + return nil +}) + +// QueryMatcherEqual is the SQL query matcher +// which simply tries a case sensitive match of +// expected and actual SQL strings without whitespace. +var QueryMatcherEqual QueryMatcher = QueryMatcherFunc(func(expectedSQL, actualSQL string) error { + expect := stripQuery(expectedSQL) + actual := stripQuery(actualSQL) + if actual != expect { + return fmt.Errorf(`actual sql: "%s" does not equal to expected "%s"`, actual, expect) + } + return nil +}) diff --git a/query_test.go b/query_test.go new file mode 100644 index 0000000..f9546ca --- /dev/null +++ b/query_test.go @@ -0,0 +1,87 @@ +package sqlmock + +import ( + "fmt" + "testing" +) + +func TestQueryStringStripping(t *testing.T) { + assert := func(actual, expected string) { + if res := stripQuery(actual); res != expected { + t.Errorf("Expected '%s' to be '%s', but got '%s'", actual, expected, res) + } + } + + assert(" SELECT 1", "SELECT 1") + assert("SELECT 1 FROM d", "SELECT 1 FROM d") + assert(` + SELECT c + FROM D +`, "SELECT c FROM D") + assert("UPDATE (.+) SET ", "UPDATE (.+) SET") +} + +func TestQueryMatcherRegexp(t *testing.T) { + type testCase struct { + expected string + actual string + err error + } + + cases := []testCase{ + {"?\\l", "SEL", fmt.Errorf("error parsing regexp: missing argument to repetition operator: `?`")}, + {"SELECT (.+) FROM users", "SELECT name, email FROM users WHERE id = ?", nil}, + {"Select (.+) FROM users", "SELECT name, email FROM users WHERE id = ?", fmt.Errorf(`could not match actual sql: "SELECT name, email FROM users WHERE id = ?" with expected regexp "Select (.+) FROM users"`)}, + {"SELECT (.+) FROM\nusers", "SELECT name, email\n FROM users\n WHERE id = ?", nil}, + } + + for i, c := range cases { + err := QueryMatcherRegexp.Match(c.expected, c.actual) + if err == nil && c.err != nil { + t.Errorf(`got no error, but expected "%v" at %d case`, c.err, i) + continue + } + if err != nil && c.err == nil { + t.Errorf(`got unexpected error "%v" at %d case`, err, i) + continue + } + if err == nil { + continue + } + if err.Error() != c.err.Error() { + t.Errorf(`expected error "%v", but got "%v" at %d case`, c.err, err, i) + } + } +} + +func TestQueryMatcherEqual(t *testing.T) { + type testCase struct { + expected string + actual string + err error + } + + cases := []testCase{ + {"SELECT name, email FROM users WHERE id = ?", "SELECT name, email\n FROM users\n WHERE id = ?", nil}, + {"SELECT", "Select", fmt.Errorf(`actual sql: "Select" does not equal to expected "SELECT"`)}, + {"SELECT from users", "SELECT from table", fmt.Errorf(`actual sql: "SELECT from table" does not equal to expected "SELECT from users"`)}, + } + + for i, c := range cases { + err := QueryMatcherEqual.Match(c.expected, c.actual) + if err == nil && c.err != nil { + t.Errorf(`got no error, but expected "%v" at %d case`, c.err, i) + continue + } + if err != nil && c.err == nil { + t.Errorf(`got unexpected error "%v" at %d case`, err, i) + continue + } + if err == nil { + continue + } + if err.Error() != c.err.Error() { + t.Errorf(`expected error "%v", but got "%v" at %d case`, c.err, err, i) + } + } +} diff --git a/sqlmock.go b/sqlmock.go index d0e79c1..113ecd8 100644 --- a/sqlmock.go +++ b/sqlmock.go @@ -81,11 +81,12 @@ type Sqlmock interface { } type sqlmock struct { - ordered bool - dsn string - opened int - drv *mockDriver - converter driver.ValueConverter + ordered bool + dsn string + opened int + drv *mockDriver + converter driver.ValueConverter + queryMatcher QueryMatcher expected []expectation } @@ -104,6 +105,9 @@ func (c *sqlmock) open(options []func(*sqlmock) error) (*sql.DB, Sqlmock, error) if c.converter == nil { c.converter = driver.DefaultParameterConverter } + if c.queryMatcher == nil { + c.queryMatcher = QueryMatcherRegexp + } return db, c, db.Ping() } diff --git a/util.go b/util.go deleted file mode 100644 index 072e380..0000000 --- a/util.go +++ /dev/null @@ -1,13 +0,0 @@ -package sqlmock - -import ( - "regexp" - "strings" -) - -var re = regexp.MustCompile("\\s+") - -// strip out new lines and trim spaces -func stripQuery(q string) (s string) { - return strings.TrimSpace(re.ReplaceAllString(q, " ")) -} diff --git a/util_test.go b/util_test.go deleted file mode 100644 index c4b3974..0000000 --- a/util_test.go +++ /dev/null @@ -1,21 +0,0 @@ -package sqlmock - -import ( - "testing" -) - -func TestQueryStringStripping(t *testing.T) { - assert := func(actual, expected string) { - if res := stripQuery(actual); res != expected { - t.Errorf("Expected '%s' to be '%s', but got '%s'", actual, expected, res) - } - } - - assert(" SELECT 1", "SELECT 1") - assert("SELECT 1 FROM d", "SELECT 1 FROM d") - assert(` - SELECT c - FROM D -`, "SELECT c FROM D") - assert("UPDATE (.+) SET ", "UPDATE (.+) SET") -}