1
0
mirror of https://github.com/DATA-DOG/go-sqlmock.git synced 2025-06-23 00:17:47 +02:00
This commit is contained in:
gedi
2016-02-23 11:14:34 +02:00
parent 808cdc9973
commit de514b7bf0
4 changed files with 64 additions and 93 deletions

View File

@ -3,7 +3,6 @@ package sqlmock
import (
"database/sql/driver"
"fmt"
"reflect"
"regexp"
"strings"
"sync"
@ -307,16 +306,22 @@ type queryBasedExpectation struct {
args []driver.Value
}
func (e *queryBasedExpectation) attemptMatch(sql string, args []driver.Value) (ret bool) {
func (e *queryBasedExpectation) attemptMatch(sql string, args []driver.Value) (err error) {
if !e.queryMatches(sql) {
return
return fmt.Errorf(`could not match sql: "%s" with expected regexp "%s"`, sql, e.sqlRegex.String())
}
defer recover() // ignore panic since we attempt a match
// catch panic
defer func() {
if e := recover(); e != nil {
_, ok := e.(error)
if !ok {
err = fmt.Errorf(e.(string))
}
}
}()
if e.argsMatches(args) {
return true
}
err = e.argsMatches(args)
return
}
@ -324,50 +329,36 @@ func (e *queryBasedExpectation) queryMatches(sql string) bool {
return e.sqlRegex.MatchString(sql)
}
func (e *queryBasedExpectation) argsMatches(args []driver.Value) bool {
func (e *queryBasedExpectation) argsMatches(args []driver.Value) error {
if nil == e.args {
return true
return nil
}
if len(args) != len(e.args) {
return false
return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args))
}
for k, v := range args {
// custom argument matcher
matcher, ok := e.args[k].(Argument)
if ok {
if !matcher.Match(v) {
return false
return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k])
}
continue
}
vi := reflect.ValueOf(v)
ai := reflect.ValueOf(e.args[k])
switch vi.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if vi.Int() != ai.Int() {
return false
}
case reflect.Float32, reflect.Float64:
if vi.Float() != ai.Float() {
return false
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if vi.Uint() != ai.Uint() {
return false
}
case reflect.String:
if vi.String() != ai.String() {
return false
}
case reflect.Bool:
if vi.Bool() != ai.Bool() {
return false
}
default:
// compare types like time.Time based on type only
if vi.Kind() != ai.Kind() {
return false
}
// convert to driver converter
darg, err := driver.DefaultParameterConverter.ConvertValue(e.args[k])
if err != nil {
return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err)
}
if !driver.IsValue(darg) {
return fmt.Errorf("argument %d: non-subset type %T returned from Value", k, darg)
}
if darg != args[k] {
return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, args[k], args[k])
}
}
return true
return nil
}