1
0
mirror of https://github.com/DATA-DOG/go-sqlmock.git synced 2025-04-15 11:36:45 +02:00

allow to use a custom converter

This commit is contained in:
Jan Waś 2018-08-06 22:29:24 +02:00
parent c8e01dc244
commit f2bc8f904e
9 changed files with 96 additions and 27 deletions

View File

@ -39,7 +39,7 @@ func (d *mockDriver) Open(dsn string) (driver.Conn, error) {
// and a mock to manage expectations. // and a mock to manage expectations.
// Pings db so that all expectations could be // Pings db so that all expectations could be
// asserted. // asserted.
func New() (*sql.DB, Sqlmock, error) { func New(options ...func(*sqlmock) error) (*sql.DB, Sqlmock, error) {
pool.Lock() pool.Lock()
dsn := fmt.Sprintf("sqlmock_db_%d", pool.counter) dsn := fmt.Sprintf("sqlmock_db_%d", pool.counter)
pool.counter++ pool.counter++
@ -48,7 +48,7 @@ func New() (*sql.DB, Sqlmock, error) {
pool.conns[dsn] = smock pool.conns[dsn] = smock
pool.Unlock() pool.Unlock()
return smock.open() return smock.open(options)
} }
// NewWithDSN creates sqlmock database connection // NewWithDSN creates sqlmock database connection
@ -64,7 +64,7 @@ func New() (*sql.DB, Sqlmock, error) {
// //
// It is not recommended to use this method, unless you // It is not recommended to use this method, unless you
// really need it and there is no other way around. // really need it and there is no other way around.
func NewWithDSN(dsn string) (*sql.DB, Sqlmock, error) { func NewWithDSN(dsn string, options ...func(*sqlmock) error) (*sql.DB, Sqlmock, error) {
pool.Lock() pool.Lock()
if _, ok := pool.conns[dsn]; ok { if _, ok := pool.conns[dsn]; ok {
pool.Unlock() pool.Unlock()
@ -74,5 +74,14 @@ func NewWithDSN(dsn string) (*sql.DB, Sqlmock, error) {
pool.conns[dsn] = smock pool.conns[dsn] = smock
pool.Unlock() pool.Unlock()
return smock.open() return smock.open(options)
}
// WithValueConverter allows to create a sqlmock connection
// with a custom ValueConverter to support drivers with special data types.
func WithValueConverter(converter driver.ValueConverter) func(*sqlmock) error {
return func(s *sqlmock) error {
s.converter = converter
return nil
}
} }

View File

@ -1,6 +1,8 @@
package sqlmock package sqlmock
import ( import (
"database/sql/driver"
"errors"
"fmt" "fmt"
"testing" "testing"
) )
@ -9,6 +11,12 @@ type void struct{}
func (void) Print(...interface{}) {} func (void) Print(...interface{}) {}
type converter struct{}
func (c *converter) ConvertValue(v interface{}) (driver.Value, error) {
return nil, errors.New("converter disabled")
}
func ExampleNew() { func ExampleNew() {
db, mock, err := New() db, mock, err := New()
if err != nil { if err != nil {
@ -90,6 +98,18 @@ func TestTwoOpenConnectionsOnTheSameDSN(t *testing.T) {
} }
} }
func TestWithOptions(t *testing.T) {
c := &converter{}
_, mock, err := New(WithValueConverter(c))
if err != nil {
t.Errorf("expected no error, but got: %s", err)
}
smock, _ := mock.(*sqlmock)
if smock.converter.(*converter) != c {
t.Errorf("expected a custom converter to be set")
}
}
func TestWrongDSN(t *testing.T) { func TestWrongDSN(t *testing.T) {
t.Parallel() t.Parallel()
db, _, _ := New() db, _, _ := New()

View File

@ -292,6 +292,7 @@ func (e *ExpectedPrepare) WillBeClosed() *ExpectedPrepare {
func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery { func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery {
eq := &ExpectedQuery{} eq := &ExpectedQuery{}
eq.sqlRegex = e.sqlRegex eq.sqlRegex = e.sqlRegex
eq.converter = e.mock.converter
e.mock.expected = append(e.mock.expected, eq) e.mock.expected = append(e.mock.expected, eq)
return eq return eq
} }
@ -301,6 +302,7 @@ func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery {
func (e *ExpectedPrepare) ExpectExec() *ExpectedExec { func (e *ExpectedPrepare) ExpectExec() *ExpectedExec {
eq := &ExpectedExec{} eq := &ExpectedExec{}
eq.sqlRegex = e.sqlRegex eq.sqlRegex = e.sqlRegex
eq.converter = e.mock.converter
e.mock.expected = append(e.mock.expected, eq) e.mock.expected = append(e.mock.expected, eq)
return eq return eq
} }
@ -325,8 +327,9 @@ func (e *ExpectedPrepare) String() string {
// adds a query matching logic // adds a query matching logic
type queryBasedExpectation struct { type queryBasedExpectation struct {
commonExpectation commonExpectation
sqlRegex *regexp.Regexp sqlRegex *regexp.Regexp
args []driver.Value converter driver.ValueConverter
args []driver.Value
} }
func (e *queryBasedExpectation) attemptMatch(sql string, args []namedValue) (err error) { func (e *queryBasedExpectation) attemptMatch(sql string, args []namedValue) (err error) {

View File

@ -35,7 +35,7 @@ func (e *queryBasedExpectation) argsMatches(args []namedValue) error {
dval := e.args[k] dval := e.args[k]
// convert to driver converter // convert to driver converter
darg, err := driver.DefaultParameterConverter.ConvertValue(dval) darg, err := e.converter.ConvertValue(dval)
if err != nil { if err != nil {
return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err) return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err)
} }

View File

@ -49,7 +49,7 @@ func (e *queryBasedExpectation) argsMatches(args []namedValue) error {
} }
// convert to driver converter // convert to driver converter
darg, err := driver.DefaultParameterConverter.ConvertValue(dval) darg, err := e.converter.ConvertValue(dval)
if err != nil { if err != nil {
return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err) return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err)
} }

View File

@ -9,7 +9,7 @@ import (
) )
func TestQueryExpectationNamedArgComparison(t *testing.T) { func TestQueryExpectationNamedArgComparison(t *testing.T) {
e := &queryBasedExpectation{} e := &queryBasedExpectation{converter: driver.DefaultParameterConverter}
against := []namedValue{{Value: int64(5), Name: "id"}} against := []namedValue{{Value: int64(5), Name: "id"}}
if err := e.argsMatches(against); err != nil { if err := e.argsMatches(against); err != nil {
t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err) t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err)

View File

@ -9,7 +9,7 @@ import (
) )
func TestQueryExpectationArgComparison(t *testing.T) { func TestQueryExpectationArgComparison(t *testing.T) {
e := &queryBasedExpectation{} e := &queryBasedExpectation{converter: driver.DefaultParameterConverter}
against := []namedValue{{Value: int64(5), Ordinal: 1}} against := []namedValue{{Value: int64(5), Ordinal: 1}}
if err := e.argsMatches(against); err != nil { if err := e.argsMatches(against); err != nil {
t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err) t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err)
@ -67,7 +67,7 @@ func TestQueryExpectationArgComparison(t *testing.T) {
func TestQueryExpectationArgComparisonBool(t *testing.T) { func TestQueryExpectationArgComparisonBool(t *testing.T) {
var e *queryBasedExpectation var e *queryBasedExpectation
e = &queryBasedExpectation{args: []driver.Value{true}} e = &queryBasedExpectation{args: []driver.Value{true}, converter: driver.DefaultParameterConverter}
against := []namedValue{ against := []namedValue{
{Value: true, Ordinal: 1}, {Value: true, Ordinal: 1},
} }
@ -75,7 +75,7 @@ func TestQueryExpectationArgComparisonBool(t *testing.T) {
t.Error("arguments should match, since arguments are the same") t.Error("arguments should match, since arguments are the same")
} }
e = &queryBasedExpectation{args: []driver.Value{false}} e = &queryBasedExpectation{args: []driver.Value{false}, converter: driver.DefaultParameterConverter}
against = []namedValue{ against = []namedValue{
{Value: false, Ordinal: 1}, {Value: false, Ordinal: 1},
} }
@ -83,7 +83,7 @@ func TestQueryExpectationArgComparisonBool(t *testing.T) {
t.Error("arguments should match, since argument are the same") t.Error("arguments should match, since argument are the same")
} }
e = &queryBasedExpectation{args: []driver.Value{true}} e = &queryBasedExpectation{args: []driver.Value{true}, converter: driver.DefaultParameterConverter}
against = []namedValue{ against = []namedValue{
{Value: false, Ordinal: 1}, {Value: false, Ordinal: 1},
} }
@ -91,7 +91,7 @@ func TestQueryExpectationArgComparisonBool(t *testing.T) {
t.Error("arguments should not match, since argument is different") t.Error("arguments should not match, since argument is different")
} }
e = &queryBasedExpectation{args: []driver.Value{false}} e = &queryBasedExpectation{args: []driver.Value{false}, converter: driver.DefaultParameterConverter}
against = []namedValue{ against = []namedValue{
{Value: true, Ordinal: 1}, {Value: true, Ordinal: 1},
} }

22
rows.go
View File

@ -81,18 +81,24 @@ func (rs *rowSets) empty() bool {
// Rows is a mocked collection of rows to // Rows is a mocked collection of rows to
// return for Query result // return for Query result
type Rows struct { type Rows struct {
cols []string converter driver.ValueConverter
rows [][]driver.Value cols []string
pos int rows [][]driver.Value
nextErr map[int]error pos int
closeErr error nextErr map[int]error
closeErr error
} }
// NewRows allows Rows to be created from a // NewRows allows Rows to be created from a
// sql driver.Value slice or from the CSV string and // sql driver.Value slice or from the CSV string and
// to be used as sql driver.Rows // to be used as sql driver.Rows.
// Use Sqlmock.NewRows instead if using a custom converter
func NewRows(columns []string) *Rows { func NewRows(columns []string) *Rows {
return &Rows{cols: columns, nextErr: make(map[int]error)} return &Rows{
cols: columns,
nextErr: make(map[int]error),
converter: driver.DefaultParameterConverter,
}
} }
// CloseError allows to set an error // CloseError allows to set an error
@ -129,7 +135,7 @@ func (r *Rows) AddRow(values ...driver.Value) *Rows {
// Convert user-friendly values (such as int or driver.Valuer) // Convert user-friendly values (such as int or driver.Valuer)
// to database/sql native value (driver.Value such as int64) // to database/sql native value (driver.Value such as int64)
var err error var err error
v, err = driver.DefaultParameterConverter.ConvertValue(v) v, err = r.converter.ConvertValue(v)
if err != nil { if err != nil {
panic(fmt.Errorf( panic(fmt.Errorf(
"row #%d, column #%d (%q) type %T: %s", "row #%d, column #%d (%q) type %T: %s",

View File

@ -73,22 +73,37 @@ type Sqlmock interface {
// in any order. Or otherwise if switched to true, any unmatched // in any order. Or otherwise if switched to true, any unmatched
// expectations will be expected in order // expectations will be expected in order
MatchExpectationsInOrder(bool) MatchExpectationsInOrder(bool)
// NewRows allows Rows to be created from a
// sql driver.Value slice or from the CSV string and
// to be used as sql driver.Rows.
NewRows(columns []string) *Rows
} }
type sqlmock struct { type sqlmock struct {
ordered bool ordered bool
dsn string dsn string
opened int opened int
drv *mockDriver drv *mockDriver
converter driver.ValueConverter
expected []expectation expected []expectation
} }
func (c *sqlmock) open() (*sql.DB, Sqlmock, error) { func (c *sqlmock) open(options []func(*sqlmock) error) (*sql.DB, Sqlmock, error) {
db, err := sql.Open("sqlmock", c.dsn) db, err := sql.Open("sqlmock", c.dsn)
if err != nil { if err != nil {
return db, c, err return db, c, err
} }
for _, option := range options {
err := option(c)
if err != nil {
return db, c, err
}
}
if c.converter == nil {
c.converter = driver.DefaultParameterConverter
}
return db, c, db.Ping() return db, c, db.Ping()
} }
@ -165,6 +180,11 @@ func (c *sqlmock) ExpectationsWereMet() error {
return nil return nil
} }
func (c *sqlmock) CheckNamedValue(nv *driver.NamedValue) (err error) {
nv.Value, err = c.converter.ConvertValue(nv.Value)
return err
}
// Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface // Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface
func (c *sqlmock) Begin() (driver.Tx, error) { func (c *sqlmock) Begin() (driver.Tx, error) {
ex, err := c.begin() ex, err := c.begin()
@ -301,6 +321,7 @@ func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec {
e := &ExpectedExec{} e := &ExpectedExec{}
sqlRegexStr = stripQuery(sqlRegexStr) sqlRegexStr = stripQuery(sqlRegexStr)
e.sqlRegex = regexp.MustCompile(sqlRegexStr) e.sqlRegex = regexp.MustCompile(sqlRegexStr)
e.converter = c.converter
c.expected = append(c.expected, e) c.expected = append(c.expected, e)
return e return e
} }
@ -463,6 +484,7 @@ func (c *sqlmock) ExpectQuery(sqlRegexStr string) *ExpectedQuery {
e := &ExpectedQuery{} e := &ExpectedQuery{}
sqlRegexStr = stripQuery(sqlRegexStr) sqlRegexStr = stripQuery(sqlRegexStr)
e.sqlRegex = regexp.MustCompile(sqlRegexStr) e.sqlRegex = regexp.MustCompile(sqlRegexStr)
e.converter = c.converter
c.expected = append(c.expected, e) c.expected = append(c.expected, e)
return e return e
} }
@ -548,3 +570,12 @@ func (c *sqlmock) Rollback() error {
expected.Unlock() expected.Unlock()
return expected.err return expected.err
} }
// NewRows allows Rows to be created from a
// sql driver.Value slice or from the CSV string and
// to be used as sql driver.Rows.
func (c *sqlmock) NewRows(columns []string) *Rows {
r := NewRows(columns)
r.converter = c.converter
return r
}