1
0
mirror of https://github.com/DATA-DOG/go-sqlmock.git synced 2025-04-21 11:56:50 +02:00

add custom argument matching

This commit is contained in:
Algirdas Matas 2014-05-29 16:43:37 +03:00
parent fa31f407df
commit 27fabfa23a
3 changed files with 196 additions and 171 deletions

View File

@ -1,118 +1,118 @@
package sqlmock package sqlmock
import ( import (
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
) )
type conn struct { type conn struct {
expectations []expectation expectations []expectation
active expectation active expectation
} }
// Close a mock database driver connection. It should // Close a mock database driver connection. It should
// be always called to ensure that all expectations // be always called to ensure that all expectations
// were met successfully. Returns error if there is any // were met successfully. Returns error if there is any
func (c *conn) Close() (err error) { func (c *conn) Close() (err error) {
for _, e := range mock.conn.expectations { for _, e := range mock.conn.expectations {
if !e.fulfilled() { if !e.fulfilled() {
err = fmt.Errorf("there is a remaining expectation %T which was not matched yet", e) err = fmt.Errorf("there is a remaining expectation %T which was not matched yet", e)
break break
} }
} }
mock.conn.expectations = []expectation{} mock.conn.expectations = []expectation{}
mock.conn.active = nil mock.conn.active = nil
return err return err
} }
func (c *conn) Begin() (driver.Tx, error) { func (c *conn) Begin() (driver.Tx, error) {
e := c.next() e := c.next()
if e == nil { if e == nil {
return nil, fmt.Errorf("all expectations were already fulfilled, call to begin transaction was not expected") return nil, fmt.Errorf("all expectations were already fulfilled, call to begin transaction was not expected")
} }
etb, ok := e.(*expectedBegin) etb, ok := e.(*expectedBegin)
if !ok { if !ok {
return nil, fmt.Errorf("call to begin transaction, was not expected, next expectation is %T as %+v", e, e) return nil, fmt.Errorf("call to begin transaction, was not expected, next expectation is %T as %+v", e, e)
} }
etb.triggered = true etb.triggered = true
return &transaction{c}, etb.err return &transaction{c}, etb.err
} }
// get next unfulfilled expectation // get next unfulfilled expectation
func (c *conn) next() (e expectation) { func (c *conn) next() (e expectation) {
for _, e = range c.expectations { for _, e = range c.expectations {
if !e.fulfilled() { if !e.fulfilled() {
return return
} }
} }
return nil // all expectations were fulfilled return nil // all expectations were fulfilled
} }
func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) { func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) {
e := c.next() e := c.next()
query = stripQuery(query) query = stripQuery(query)
if e == nil { if e == nil {
return nil, fmt.Errorf("all expectations were already fulfilled, call to exec '%s' query with args %+v was not expected", query, args) return nil, fmt.Errorf("all expectations were already fulfilled, call to exec '%s' query with args %+v was not expected", query, args)
} }
eq, ok := e.(*expectedExec) eq, ok := e.(*expectedExec)
if !ok { if !ok {
return nil, fmt.Errorf("call to exec query '%s' with args %+v, was not expected, next expectation is %T as %+v", query, args, e, e) return nil, fmt.Errorf("call to exec query '%s' with args %+v, was not expected, next expectation is %T as %+v", query, args, e, e)
} }
eq.triggered = true eq.triggered = true
if eq.err != nil { if eq.err != nil {
return nil, eq.err // mocked to return error return nil, eq.err // mocked to return error
} }
if eq.result == nil { if eq.result == nil {
return nil, fmt.Errorf("exec query '%s' with args %+v, must return a database/sql/driver.result, but it was not set for expectation %T as %+v", query, args, eq, eq) return nil, fmt.Errorf("exec query '%s' with args %+v, must return a database/sql/driver.result, but it was not set for expectation %T as %+v", query, args, eq, eq)
} }
if !eq.queryMatches(query) { if !eq.queryMatches(query) {
return nil, fmt.Errorf("exec query '%s', does not match regex '%s'", query, eq.sqlRegex.String()) return nil, fmt.Errorf("exec query '%s', does not match regex '%s'", query, eq.sqlRegex.String())
} }
if !eq.argsMatches(args) { if !eq.argsMatches(args) {
return nil, fmt.Errorf("exec query '%s', args %+v does not match expected %+v", query, args, eq.args) return nil, fmt.Errorf("exec query '%s', args %+v does not match expected %+v", query, args, eq.args)
} }
return eq.result, nil return eq.result, nil
} }
func (c *conn) Prepare(query string) (driver.Stmt, error) { func (c *conn) Prepare(query string) (driver.Stmt, error) {
return &statement{mock.conn, stripQuery(query)}, nil return &statement{mock.conn, stripQuery(query)}, nil
} }
func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) { func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
e := c.next() e := c.next()
query = stripQuery(query) query = stripQuery(query)
if e == nil { if e == nil {
return nil, fmt.Errorf("all expectations were already fulfilled, call to query '%s' with args %+v was not expected", query, args) return nil, fmt.Errorf("all expectations were already fulfilled, call to query '%s' with args %+v was not expected", query, args)
} }
eq, ok := e.(*expectedQuery) eq, ok := e.(*expectedQuery)
if !ok { if !ok {
return nil, fmt.Errorf("call to query '%s' with args %+v, was not expected, next expectation is %T as %+v", query, args, e, e) return nil, fmt.Errorf("call to query '%s' with args %+v, was not expected, next expectation is %T as %+v", query, args, e, e)
} }
eq.triggered = true eq.triggered = true
if eq.err != nil { if eq.err != nil {
return nil, eq.err // mocked to return error return nil, eq.err // mocked to return error
} }
if eq.rows == nil { if eq.rows == nil {
return nil, fmt.Errorf("query '%s' with args %+v, must return a database/sql/driver.rows, but it was not set for expectation %T as %+v", query, args, eq, eq) return nil, fmt.Errorf("query '%s' with args %+v, must return a database/sql/driver.rows, but it was not set for expectation %T as %+v", query, args, eq, eq)
} }
if !eq.queryMatches(query) { if !eq.queryMatches(query) {
return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, eq.sqlRegex.String()) return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, eq.sqlRegex.String())
} }
if !eq.argsMatches(args) { if !eq.argsMatches(args) {
return nil, fmt.Errorf("query '%s', args %+v does not match expected %+v", query, args, eq.args) return nil, fmt.Errorf("query '%s', args %+v does not match expected %+v", query, args, eq.args)
} }
return eq.rows, nil return eq.rows, nil
} }

View File

@ -1,106 +1,119 @@
package sqlmock package sqlmock
import ( import (
"database/sql/driver" "database/sql/driver"
"reflect" "reflect"
"regexp" "regexp"
) )
// Argument interface allows to match
// any argument in specific way
type Argument interface {
Match(driver.Value) bool
}
// an expectation interface // an expectation interface
type expectation interface { type expectation interface {
fulfilled() bool fulfilled() bool
setError(err error) setError(err error)
} }
// common expectation struct // common expectation struct
// satisfies the expectation interface // satisfies the expectation interface
type commonExpectation struct { type commonExpectation struct {
triggered bool triggered bool
err error err error
} }
func (e *commonExpectation) fulfilled() bool { func (e *commonExpectation) fulfilled() bool {
return e.triggered return e.triggered
} }
func (e *commonExpectation) setError(err error) { func (e *commonExpectation) setError(err error) {
e.err = err e.err = err
} }
// query based expectation // query based expectation
// adds a query matching logic // adds a query matching logic
type queryBasedExpectation struct { type queryBasedExpectation struct {
commonExpectation commonExpectation
sqlRegex *regexp.Regexp sqlRegex *regexp.Regexp
args []driver.Value args []driver.Value
} }
func (e *queryBasedExpectation) queryMatches(sql string) bool { func (e *queryBasedExpectation) queryMatches(sql string) bool {
return e.sqlRegex.MatchString(sql) return e.sqlRegex.MatchString(sql)
} }
func (e *queryBasedExpectation) argsMatches(args []driver.Value) bool { func (e *queryBasedExpectation) argsMatches(args []driver.Value) bool {
if nil == e.args { if nil == e.args {
return true return true
} }
if len(args) != len(e.args) { if len(args) != len(e.args) {
return false return false
} }
for k, v := range args { for k, v := range args {
vi := reflect.ValueOf(v) matcher, ok := e.args[k].(Argument)
ai := reflect.ValueOf(e.args[k]) if ok {
switch vi.Kind() { if !matcher.Match(v) {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: return false
if vi.Int() != ai.Int() { }
return false continue
} }
case reflect.Float32, reflect.Float64: vi := reflect.ValueOf(v)
if vi.Float() != ai.Float() { ai := reflect.ValueOf(e.args[k])
return false switch vi.Kind() {
} case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: if vi.Int() != ai.Int() {
if vi.Uint() != ai.Uint() { return false
return false }
} case reflect.Float32, reflect.Float64:
case reflect.String: if vi.Float() != ai.Float() {
if vi.String() != ai.String() { return false
return false }
} case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
default: if vi.Uint() != ai.Uint() {
// compare types like time.Time based on type only return false
if vi.Kind() != ai.Kind() { }
return false case reflect.String:
} if vi.String() != ai.String() {
} return false
} }
return true default:
// compare types like time.Time based on type only
if vi.Kind() != ai.Kind() {
return false
}
}
}
return true
} }
// begin transaction // begin transaction
type expectedBegin struct { type expectedBegin struct {
commonExpectation commonExpectation
} }
// tx commit // tx commit
type expectedCommit struct { type expectedCommit struct {
commonExpectation commonExpectation
} }
// tx rollback // tx rollback
type expectedRollback struct { type expectedRollback struct {
commonExpectation commonExpectation
} }
// query expectation // query expectation
type expectedQuery struct { type expectedQuery struct {
queryBasedExpectation queryBasedExpectation
rows driver.Rows rows driver.Rows
} }
// exec query expectation // exec query expectation
type expectedExec struct { type expectedExec struct {
queryBasedExpectation queryBasedExpectation
result driver.Result result driver.Result
} }

View File

@ -1,47 +1,59 @@
package sqlmock package sqlmock
import ( import (
"database/sql/driver" "database/sql/driver"
"testing" "testing"
"time" "time"
) )
func TestQueryExpectationArgComparison(t *testing.T) { type matcher struct {
e := &queryBasedExpectation{} }
against := []driver.Value{5}
if !e.argsMatches(against) { func (m matcher) Match(driver.Value) bool {
t.Error("arguments should match, since the no expectation was set") return true
} }
e.args = []driver.Value{5, "str"} func TestQueryExpectationArgComparison(t *testing.T) {
e := &queryBasedExpectation{}
against = []driver.Value{5} against := []driver.Value{5}
if e.argsMatches(against) { if !e.argsMatches(against) {
t.Error("arguments should not match, since the size is not the same") t.Error("arguments should match, since the no expectation was set")
} }
against = []driver.Value{3, "str"} e.args = []driver.Value{5, "str"}
if e.argsMatches(against) {
t.Error("arguments should not match, since the first argument (int value) is different") against = []driver.Value{5}
} if e.argsMatches(against) {
t.Error("arguments should not match, since the size is not the same")
against = []driver.Value{5, "st"} }
if e.argsMatches(against) {
t.Error("arguments should not match, since the second argument (string value) is different") against = []driver.Value{3, "str"}
} if e.argsMatches(against) {
t.Error("arguments should not match, since the first argument (int value) is different")
against = []driver.Value{5, "str"} }
if !e.argsMatches(against) {
t.Error("arguments should match, but it did not") against = []driver.Value{5, "st"}
} if e.argsMatches(against) {
t.Error("arguments should not match, since the second argument (string value) is different")
e.args = []driver.Value{5, time.Now()} }
const longForm = "Jan 2, 2006 at 3:04pm (MST)" against = []driver.Value{5, "str"}
tm, _ := time.Parse(longForm, "Feb 3, 2013 at 7:54pm (PST)") if !e.argsMatches(against) {
t.Error("arguments should match, but it did not")
against = []driver.Value{5, tm} }
if !e.argsMatches(against) {
t.Error("arguments should match (time will be compared only by type), but it did not") e.args = []driver.Value{5, time.Now()}
}
const longForm = "Jan 2, 2006 at 3:04pm (MST)"
tm, _ := time.Parse(longForm, "Feb 3, 2013 at 7:54pm (PST)")
against = []driver.Value{5, tm}
if !e.argsMatches(against) {
t.Error("arguments should match (time will be compared only by type), but it did not")
}
against = []driver.Value{5, matcher{}}
if !e.argsMatches(against) {
t.Error("arguments should match, but it did not")
}
} }