1
0
mirror of https://github.com/DATA-DOG/go-sqlmock.git synced 2025-03-25 21:18:08 +02:00

Merge pull request from nineinchnick/custom-converter

Allow to use a custom converter
This commit is contained in:
Gediminas Morkevicius 2018-09-14 13:42:08 +03:00 committed by GitHub
commit d4b2bccf3e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 161 additions and 31 deletions

@ -35,11 +35,12 @@ func (d *mockDriver) Open(dsn string) (driver.Conn, error) {
return c, nil
}
// New creates sqlmock database connection
// and a mock to manage expectations.
// New creates sqlmock database connection and a mock to manage expectations.
// Accepts options, like ValueConverterOption, to use a ValueConverter from
// a specific driver.
// Pings db so that all expectations could be
// asserted.
func New() (*sql.DB, Sqlmock, error) {
func New(options ...func(*sqlmock) error) (*sql.DB, Sqlmock, error) {
pool.Lock()
dsn := fmt.Sprintf("sqlmock_db_%d", pool.counter)
pool.counter++
@ -48,11 +49,13 @@ func New() (*sql.DB, Sqlmock, error) {
pool.conns[dsn] = smock
pool.Unlock()
return smock.open()
return smock.open(options)
}
// NewWithDSN creates sqlmock database connection
// with a specific DSN and a mock to manage expectations.
// NewWithDSN creates sqlmock database connection with a specific DSN
// and a mock to manage expectations.
// Accepts options, like ValueConverterOption, to use a ValueConverter from
// a specific driver.
// Pings db so that all expectations could be asserted.
//
// This method is introduced because of sql abstraction
@ -64,7 +67,7 @@ func New() (*sql.DB, Sqlmock, error) {
//
// It is not recommended to use this method, unless you
// really need it and there is no other way around.
func NewWithDSN(dsn string) (*sql.DB, Sqlmock, error) {
func NewWithDSN(dsn string, options ...func(*sqlmock) error) (*sql.DB, Sqlmock, error) {
pool.Lock()
if _, ok := pool.conns[dsn]; ok {
pool.Unlock()
@ -74,5 +77,5 @@ func NewWithDSN(dsn string) (*sql.DB, Sqlmock, error) {
pool.conns[dsn] = smock
pool.Unlock()
return smock.open()
return smock.open(options)
}

@ -1,6 +1,8 @@
package sqlmock
import (
"database/sql/driver"
"errors"
"fmt"
"testing"
)
@ -9,6 +11,12 @@ type void struct{}
func (void) Print(...interface{}) {}
type converter struct{}
func (c *converter) ConvertValue(v interface{}) (driver.Value, error) {
return nil, errors.New("converter disabled")
}
func ExampleNew() {
db, mock, err := New()
if err != nil {
@ -90,6 +98,18 @@ func TestTwoOpenConnectionsOnTheSameDSN(t *testing.T) {
}
}
func TestWithOptions(t *testing.T) {
c := &converter{}
_, mock, err := New(ValueConverterOption(c))
if err != nil {
t.Errorf("expected no error, but got: %s", err)
}
smock, _ := mock.(*sqlmock)
if smock.converter.(*converter) != c {
t.Errorf("expected a custom converter to be set")
}
}
func TestWrongDSN(t *testing.T) {
t.Parallel()
db, _, _ := New()

@ -292,6 +292,7 @@ func (e *ExpectedPrepare) WillBeClosed() *ExpectedPrepare {
func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery {
eq := &ExpectedQuery{}
eq.sqlRegex = e.sqlRegex
eq.converter = e.mock.converter
e.mock.expected = append(e.mock.expected, eq)
return eq
}
@ -301,6 +302,7 @@ func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery {
func (e *ExpectedPrepare) ExpectExec() *ExpectedExec {
eq := &ExpectedExec{}
eq.sqlRegex = e.sqlRegex
eq.converter = e.mock.converter
e.mock.expected = append(e.mock.expected, eq)
return eq
}
@ -325,8 +327,9 @@ func (e *ExpectedPrepare) String() string {
// adds a query matching logic
type queryBasedExpectation struct {
commonExpectation
sqlRegex *regexp.Regexp
args []driver.Value
sqlRegex *regexp.Regexp
converter driver.ValueConverter
args []driver.Value
}
func (e *queryBasedExpectation) attemptMatch(sql string, args []namedValue) (err error) {

@ -35,7 +35,7 @@ func (e *queryBasedExpectation) argsMatches(args []namedValue) error {
dval := e.args[k]
// convert to driver converter
darg, err := driver.DefaultParameterConverter.ConvertValue(dval)
darg, err := e.converter.ConvertValue(dval)
if err != nil {
return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err)
}

@ -49,7 +49,7 @@ func (e *queryBasedExpectation) argsMatches(args []namedValue) error {
}
// convert to driver converter
darg, err := driver.DefaultParameterConverter.ConvertValue(dval)
darg, err := e.converter.ConvertValue(dval)
if err != nil {
return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err)
}

@ -9,7 +9,7 @@ import (
)
func TestQueryExpectationNamedArgComparison(t *testing.T) {
e := &queryBasedExpectation{}
e := &queryBasedExpectation{converter: driver.DefaultParameterConverter}
against := []namedValue{{Value: int64(5), Name: "id"}}
if err := e.argsMatches(against); err != nil {
t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err)

@ -9,7 +9,7 @@ import (
)
func TestQueryExpectationArgComparison(t *testing.T) {
e := &queryBasedExpectation{}
e := &queryBasedExpectation{converter: driver.DefaultParameterConverter}
against := []namedValue{{Value: int64(5), Ordinal: 1}}
if err := e.argsMatches(against); err != nil {
t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err)
@ -67,7 +67,7 @@ func TestQueryExpectationArgComparison(t *testing.T) {
func TestQueryExpectationArgComparisonBool(t *testing.T) {
var e *queryBasedExpectation
e = &queryBasedExpectation{args: []driver.Value{true}}
e = &queryBasedExpectation{args: []driver.Value{true}, converter: driver.DefaultParameterConverter}
against := []namedValue{
{Value: true, Ordinal: 1},
}
@ -75,7 +75,7 @@ func TestQueryExpectationArgComparisonBool(t *testing.T) {
t.Error("arguments should match, since arguments are the same")
}
e = &queryBasedExpectation{args: []driver.Value{false}}
e = &queryBasedExpectation{args: []driver.Value{false}, converter: driver.DefaultParameterConverter}
against = []namedValue{
{Value: false, Ordinal: 1},
}
@ -83,7 +83,7 @@ func TestQueryExpectationArgComparisonBool(t *testing.T) {
t.Error("arguments should match, since argument are the same")
}
e = &queryBasedExpectation{args: []driver.Value{true}}
e = &queryBasedExpectation{args: []driver.Value{true}, converter: driver.DefaultParameterConverter}
against = []namedValue{
{Value: false, Ordinal: 1},
}
@ -91,7 +91,7 @@ func TestQueryExpectationArgComparisonBool(t *testing.T) {
t.Error("arguments should not match, since argument is different")
}
e = &queryBasedExpectation{args: []driver.Value{false}}
e = &queryBasedExpectation{args: []driver.Value{false}, converter: driver.DefaultParameterConverter}
against = []namedValue{
{Value: true, Ordinal: 1},
}

12
options.go Normal file

@ -0,0 +1,12 @@
package sqlmock
import "database/sql/driver"
// ValueConverterOption allows to create a sqlmock connection
// with a custom ValueConverter to support drivers with special data types.
func ValueConverterOption(converter driver.ValueConverter) func(*sqlmock) error {
return func(s *sqlmock) error {
s.converter = converter
return nil
}
}

22
rows.go

@ -81,18 +81,24 @@ func (rs *rowSets) empty() bool {
// Rows is a mocked collection of rows to
// return for Query result
type Rows struct {
cols []string
rows [][]driver.Value
pos int
nextErr map[int]error
closeErr error
converter driver.ValueConverter
cols []string
rows [][]driver.Value
pos int
nextErr map[int]error
closeErr error
}
// NewRows allows Rows to be created from a
// sql driver.Value slice or from the CSV string and
// to be used as sql driver.Rows
// to be used as sql driver.Rows.
// Use Sqlmock.NewRows instead if using a custom converter
func NewRows(columns []string) *Rows {
return &Rows{cols: columns, nextErr: make(map[int]error)}
return &Rows{
cols: columns,
nextErr: make(map[int]error),
converter: driver.DefaultParameterConverter,
}
}
// CloseError allows to set an error
@ -129,7 +135,7 @@ func (r *Rows) AddRow(values ...driver.Value) *Rows {
// Convert user-friendly values (such as int or driver.Valuer)
// to database/sql native value (driver.Value such as int64)
var err error
v, err = driver.DefaultParameterConverter.ConvertValue(v)
v, err = r.converter.ConvertValue(v)
if err != nil {
panic(fmt.Errorf(
"row #%d, column #%d (%q) type %T: %s",

@ -73,22 +73,37 @@ type Sqlmock interface {
// in any order. Or otherwise if switched to true, any unmatched
// expectations will be expected in order
MatchExpectationsInOrder(bool)
// NewRows allows Rows to be created from a
// sql driver.Value slice or from the CSV string and
// to be used as sql driver.Rows.
NewRows(columns []string) *Rows
}
type sqlmock struct {
ordered bool
dsn string
opened int
drv *mockDriver
ordered bool
dsn string
opened int
drv *mockDriver
converter driver.ValueConverter
expected []expectation
}
func (c *sqlmock) open() (*sql.DB, Sqlmock, error) {
func (c *sqlmock) open(options []func(*sqlmock) error) (*sql.DB, Sqlmock, error) {
db, err := sql.Open("sqlmock", c.dsn)
if err != nil {
return db, c, err
}
for _, option := range options {
err := option(c)
if err != nil {
return db, c, err
}
}
if c.converter == nil {
c.converter = driver.DefaultParameterConverter
}
return db, c, db.Ping()
}
@ -301,6 +316,7 @@ func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec {
e := &ExpectedExec{}
sqlRegexStr = stripQuery(sqlRegexStr)
e.sqlRegex = regexp.MustCompile(sqlRegexStr)
e.converter = c.converter
c.expected = append(c.expected, e)
return e
}
@ -463,6 +479,7 @@ func (c *sqlmock) ExpectQuery(sqlRegexStr string) *ExpectedQuery {
e := &ExpectedQuery{}
sqlRegexStr = stripQuery(sqlRegexStr)
e.sqlRegex = regexp.MustCompile(sqlRegexStr)
e.converter = c.converter
c.expected = append(c.expected, e)
return e
}
@ -548,3 +565,12 @@ func (c *sqlmock) Rollback() error {
expected.Unlock()
return expected.err
}
// NewRows allows Rows to be created from a
// sql driver.Value slice or from the CSV string and
// to be used as sql driver.Rows.
func (c *sqlmock) NewRows(columns []string) *Rows {
r := NewRows(columns)
r.converter = c.converter
return r
}

@ -111,3 +111,9 @@ func (stmt *statement) QueryContext(ctx context.Context, args []driver.NamedValu
}
// @TODO maybe add ExpectedBegin.WithOptions(driver.TxOptions)
// CheckNamedValue meets https://golang.org/pkg/database/sql/driver/#NamedValueChecker
func (c *sqlmock) CheckNamedValue(nv *driver.NamedValue) (err error) {
nv.Value, err = c.converter.ConvertValue(nv.Value)
return err
}

@ -849,6 +849,32 @@ func TestRollbackThrow(t *testing.T) {
// Output:
}
func TestUnexpectedBegin(t *testing.T) {
// Open new mock database
db, _, err := New()
if err != nil {
fmt.Println("error creating mock database")
return
}
if _, err := db.Begin(); err == nil {
t.Error("an error was expected when calling begin, but got none")
}
}
func TestUnexpectedExec(t *testing.T) {
// Open new mock database
db, mock, err := New()
if err != nil {
fmt.Println("error creating mock database")
return
}
mock.ExpectBegin()
db.Begin()
if _, err := db.Exec("SELECT 1"); err == nil {
t.Error("an error was expected when calling exec, but got none")
}
}
func TestUnexpectedCommit(t *testing.T) {
// Open new mock database
db, mock, err := New()
@ -1113,3 +1139,31 @@ func TestExecExpectationErrorDelay(t *testing.T) {
t.Errorf("expecting a delay of less than %v before error, actual delay was %v", delay, elapsed)
}
}
func TestOptionsFail(t *testing.T) {
t.Parallel()
expected := errors.New("failing option")
option := func(*sqlmock) error {
return expected
}
db, _, err := New(option)
defer db.Close()
if err == nil {
t.Errorf("missing expecting error '%s' when opening a stub database connection", expected)
}
}
func TestNewRows(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
columns := []string{"col1", "col2"}
r := mock.NewRows(columns)
if len(r.cols) != len(columns) || r.cols[0] != columns[0] || r.cols[1] != columns[1] {
t.Errorf("expecting to create a row with columns %v, actual colmns are %v", r.cols, columns)
}
}