1
0
mirror of https://github.com/zhashkevych/go-sqlxmock.git synced 2024-11-24 08:12:13 +02:00

allow to use a custom converter

This commit is contained in:
Jan Waś 2018-08-06 22:29:24 +02:00
parent c8e01dc244
commit f2bc8f904e
9 changed files with 96 additions and 27 deletions

View File

@ -39,7 +39,7 @@ func (d *mockDriver) Open(dsn string) (driver.Conn, error) {
// and a mock to manage expectations.
// Pings db so that all expectations could be
// asserted.
func New() (*sql.DB, Sqlmock, error) {
func New(options ...func(*sqlmock) error) (*sql.DB, Sqlmock, error) {
pool.Lock()
dsn := fmt.Sprintf("sqlmock_db_%d", pool.counter)
pool.counter++
@ -48,7 +48,7 @@ func New() (*sql.DB, Sqlmock, error) {
pool.conns[dsn] = smock
pool.Unlock()
return smock.open()
return smock.open(options)
}
// NewWithDSN creates sqlmock database connection
@ -64,7 +64,7 @@ func New() (*sql.DB, Sqlmock, error) {
//
// It is not recommended to use this method, unless you
// really need it and there is no other way around.
func NewWithDSN(dsn string) (*sql.DB, Sqlmock, error) {
func NewWithDSN(dsn string, options ...func(*sqlmock) error) (*sql.DB, Sqlmock, error) {
pool.Lock()
if _, ok := pool.conns[dsn]; ok {
pool.Unlock()
@ -74,5 +74,14 @@ func NewWithDSN(dsn string) (*sql.DB, Sqlmock, error) {
pool.conns[dsn] = smock
pool.Unlock()
return smock.open()
return smock.open(options)
}
// WithValueConverter allows to create a sqlmock connection
// with a custom ValueConverter to support drivers with special data types.
func WithValueConverter(converter driver.ValueConverter) func(*sqlmock) error {
return func(s *sqlmock) error {
s.converter = converter
return nil
}
}

View File

@ -1,6 +1,8 @@
package sqlmock
import (
"database/sql/driver"
"errors"
"fmt"
"testing"
)
@ -9,6 +11,12 @@ type void struct{}
func (void) Print(...interface{}) {}
type converter struct{}
func (c *converter) ConvertValue(v interface{}) (driver.Value, error) {
return nil, errors.New("converter disabled")
}
func ExampleNew() {
db, mock, err := New()
if err != nil {
@ -90,6 +98,18 @@ func TestTwoOpenConnectionsOnTheSameDSN(t *testing.T) {
}
}
func TestWithOptions(t *testing.T) {
c := &converter{}
_, mock, err := New(WithValueConverter(c))
if err != nil {
t.Errorf("expected no error, but got: %s", err)
}
smock, _ := mock.(*sqlmock)
if smock.converter.(*converter) != c {
t.Errorf("expected a custom converter to be set")
}
}
func TestWrongDSN(t *testing.T) {
t.Parallel()
db, _, _ := New()

View File

@ -292,6 +292,7 @@ func (e *ExpectedPrepare) WillBeClosed() *ExpectedPrepare {
func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery {
eq := &ExpectedQuery{}
eq.sqlRegex = e.sqlRegex
eq.converter = e.mock.converter
e.mock.expected = append(e.mock.expected, eq)
return eq
}
@ -301,6 +302,7 @@ func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery {
func (e *ExpectedPrepare) ExpectExec() *ExpectedExec {
eq := &ExpectedExec{}
eq.sqlRegex = e.sqlRegex
eq.converter = e.mock.converter
e.mock.expected = append(e.mock.expected, eq)
return eq
}
@ -326,6 +328,7 @@ func (e *ExpectedPrepare) String() string {
type queryBasedExpectation struct {
commonExpectation
sqlRegex *regexp.Regexp
converter driver.ValueConverter
args []driver.Value
}

View File

@ -35,7 +35,7 @@ func (e *queryBasedExpectation) argsMatches(args []namedValue) error {
dval := e.args[k]
// convert to driver converter
darg, err := driver.DefaultParameterConverter.ConvertValue(dval)
darg, err := e.converter.ConvertValue(dval)
if err != nil {
return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err)
}

View File

@ -49,7 +49,7 @@ func (e *queryBasedExpectation) argsMatches(args []namedValue) error {
}
// convert to driver converter
darg, err := driver.DefaultParameterConverter.ConvertValue(dval)
darg, err := e.converter.ConvertValue(dval)
if err != nil {
return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err)
}

View File

@ -9,7 +9,7 @@ import (
)
func TestQueryExpectationNamedArgComparison(t *testing.T) {
e := &queryBasedExpectation{}
e := &queryBasedExpectation{converter: driver.DefaultParameterConverter}
against := []namedValue{{Value: int64(5), Name: "id"}}
if err := e.argsMatches(against); err != nil {
t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err)

View File

@ -9,7 +9,7 @@ import (
)
func TestQueryExpectationArgComparison(t *testing.T) {
e := &queryBasedExpectation{}
e := &queryBasedExpectation{converter: driver.DefaultParameterConverter}
against := []namedValue{{Value: int64(5), Ordinal: 1}}
if err := e.argsMatches(against); err != nil {
t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err)
@ -67,7 +67,7 @@ func TestQueryExpectationArgComparison(t *testing.T) {
func TestQueryExpectationArgComparisonBool(t *testing.T) {
var e *queryBasedExpectation
e = &queryBasedExpectation{args: []driver.Value{true}}
e = &queryBasedExpectation{args: []driver.Value{true}, converter: driver.DefaultParameterConverter}
against := []namedValue{
{Value: true, Ordinal: 1},
}
@ -75,7 +75,7 @@ func TestQueryExpectationArgComparisonBool(t *testing.T) {
t.Error("arguments should match, since arguments are the same")
}
e = &queryBasedExpectation{args: []driver.Value{false}}
e = &queryBasedExpectation{args: []driver.Value{false}, converter: driver.DefaultParameterConverter}
against = []namedValue{
{Value: false, Ordinal: 1},
}
@ -83,7 +83,7 @@ func TestQueryExpectationArgComparisonBool(t *testing.T) {
t.Error("arguments should match, since argument are the same")
}
e = &queryBasedExpectation{args: []driver.Value{true}}
e = &queryBasedExpectation{args: []driver.Value{true}, converter: driver.DefaultParameterConverter}
against = []namedValue{
{Value: false, Ordinal: 1},
}
@ -91,7 +91,7 @@ func TestQueryExpectationArgComparisonBool(t *testing.T) {
t.Error("arguments should not match, since argument is different")
}
e = &queryBasedExpectation{args: []driver.Value{false}}
e = &queryBasedExpectation{args: []driver.Value{false}, converter: driver.DefaultParameterConverter}
against = []namedValue{
{Value: true, Ordinal: 1},
}

12
rows.go
View File

@ -81,6 +81,7 @@ func (rs *rowSets) empty() bool {
// Rows is a mocked collection of rows to
// return for Query result
type Rows struct {
converter driver.ValueConverter
cols []string
rows [][]driver.Value
pos int
@ -90,9 +91,14 @@ type Rows struct {
// NewRows allows Rows to be created from a
// sql driver.Value slice or from the CSV string and
// to be used as sql driver.Rows
// to be used as sql driver.Rows.
// Use Sqlmock.NewRows instead if using a custom converter
func NewRows(columns []string) *Rows {
return &Rows{cols: columns, nextErr: make(map[int]error)}
return &Rows{
cols: columns,
nextErr: make(map[int]error),
converter: driver.DefaultParameterConverter,
}
}
// CloseError allows to set an error
@ -129,7 +135,7 @@ func (r *Rows) AddRow(values ...driver.Value) *Rows {
// Convert user-friendly values (such as int or driver.Valuer)
// to database/sql native value (driver.Value such as int64)
var err error
v, err = driver.DefaultParameterConverter.ConvertValue(v)
v, err = r.converter.ConvertValue(v)
if err != nil {
panic(fmt.Errorf(
"row #%d, column #%d (%q) type %T: %s",

View File

@ -73,6 +73,11 @@ type Sqlmock interface {
// in any order. Or otherwise if switched to true, any unmatched
// expectations will be expected in order
MatchExpectationsInOrder(bool)
// NewRows allows Rows to be created from a
// sql driver.Value slice or from the CSV string and
// to be used as sql driver.Rows.
NewRows(columns []string) *Rows
}
type sqlmock struct {
@ -80,15 +85,25 @@ type sqlmock struct {
dsn string
opened int
drv *mockDriver
converter driver.ValueConverter
expected []expectation
}
func (c *sqlmock) open() (*sql.DB, Sqlmock, error) {
func (c *sqlmock) open(options []func(*sqlmock) error) (*sql.DB, Sqlmock, error) {
db, err := sql.Open("sqlmock", c.dsn)
if err != nil {
return db, c, err
}
for _, option := range options {
err := option(c)
if err != nil {
return db, c, err
}
}
if c.converter == nil {
c.converter = driver.DefaultParameterConverter
}
return db, c, db.Ping()
}
@ -165,6 +180,11 @@ func (c *sqlmock) ExpectationsWereMet() error {
return nil
}
func (c *sqlmock) CheckNamedValue(nv *driver.NamedValue) (err error) {
nv.Value, err = c.converter.ConvertValue(nv.Value)
return err
}
// Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface
func (c *sqlmock) Begin() (driver.Tx, error) {
ex, err := c.begin()
@ -301,6 +321,7 @@ func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec {
e := &ExpectedExec{}
sqlRegexStr = stripQuery(sqlRegexStr)
e.sqlRegex = regexp.MustCompile(sqlRegexStr)
e.converter = c.converter
c.expected = append(c.expected, e)
return e
}
@ -463,6 +484,7 @@ func (c *sqlmock) ExpectQuery(sqlRegexStr string) *ExpectedQuery {
e := &ExpectedQuery{}
sqlRegexStr = stripQuery(sqlRegexStr)
e.sqlRegex = regexp.MustCompile(sqlRegexStr)
e.converter = c.converter
c.expected = append(c.expected, e)
return e
}
@ -548,3 +570,12 @@ func (c *sqlmock) Rollback() error {
expected.Unlock()
return expected.err
}
// NewRows allows Rows to be created from a
// sql driver.Value slice or from the CSV string and
// to be used as sql driver.Rows.
func (c *sqlmock) NewRows(columns []string) *Rows {
r := NewRows(columns)
r.converter = c.converter
return r
}