From f2bc8f904e0cf99f8c395246496c56353b0bc4e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Wa=C5=9B?= Date: Mon, 6 Aug 2018 22:29:24 +0200 Subject: [PATCH 1/6] allow to use a custom converter --- driver.go | 17 +++++++++++---- driver_test.go | 20 ++++++++++++++++++ expectations.go | 7 +++++-- expectations_before_go18.go | 2 +- expectations_go18.go | 2 +- expectations_go18_test.go | 2 +- expectations_test.go | 10 ++++----- rows.go | 22 ++++++++++++-------- sqlmock.go | 41 ++++++++++++++++++++++++++++++++----- 9 files changed, 96 insertions(+), 27 deletions(-) diff --git a/driver.go b/driver.go index 2a480fe..876f7a2 100644 --- a/driver.go +++ b/driver.go @@ -39,7 +39,7 @@ func (d *mockDriver) Open(dsn string) (driver.Conn, error) { // and a mock to manage expectations. // Pings db so that all expectations could be // asserted. -func New() (*sql.DB, Sqlmock, error) { +func New(options ...func(*sqlmock) error) (*sql.DB, Sqlmock, error) { pool.Lock() dsn := fmt.Sprintf("sqlmock_db_%d", pool.counter) pool.counter++ @@ -48,7 +48,7 @@ func New() (*sql.DB, Sqlmock, error) { pool.conns[dsn] = smock pool.Unlock() - return smock.open() + return smock.open(options) } // NewWithDSN creates sqlmock database connection @@ -64,7 +64,7 @@ func New() (*sql.DB, Sqlmock, error) { // // It is not recommended to use this method, unless you // really need it and there is no other way around. -func NewWithDSN(dsn string) (*sql.DB, Sqlmock, error) { +func NewWithDSN(dsn string, options ...func(*sqlmock) error) (*sql.DB, Sqlmock, error) { pool.Lock() if _, ok := pool.conns[dsn]; ok { pool.Unlock() @@ -74,5 +74,14 @@ func NewWithDSN(dsn string) (*sql.DB, Sqlmock, error) { pool.conns[dsn] = smock pool.Unlock() - return smock.open() + return smock.open(options) +} + +// WithValueConverter allows to create a sqlmock connection +// with a custom ValueConverter to support drivers with special data types. +func WithValueConverter(converter driver.ValueConverter) func(*sqlmock) error { + return func(s *sqlmock) error { + s.converter = converter + return nil + } } diff --git a/driver_test.go b/driver_test.go index 4f805ba..a8acc69 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1,6 +1,8 @@ package sqlmock import ( + "database/sql/driver" + "errors" "fmt" "testing" ) @@ -9,6 +11,12 @@ type void struct{} func (void) Print(...interface{}) {} +type converter struct{} + +func (c *converter) ConvertValue(v interface{}) (driver.Value, error) { + return nil, errors.New("converter disabled") +} + func ExampleNew() { db, mock, err := New() if err != nil { @@ -90,6 +98,18 @@ func TestTwoOpenConnectionsOnTheSameDSN(t *testing.T) { } } +func TestWithOptions(t *testing.T) { + c := &converter{} + _, mock, err := New(WithValueConverter(c)) + if err != nil { + t.Errorf("expected no error, but got: %s", err) + } + smock, _ := mock.(*sqlmock) + if smock.converter.(*converter) != c { + t.Errorf("expected a custom converter to be set") + } +} + func TestWrongDSN(t *testing.T) { t.Parallel() db, _, _ := New() diff --git a/expectations.go b/expectations.go index 6ff9a65..9f54967 100644 --- a/expectations.go +++ b/expectations.go @@ -292,6 +292,7 @@ func (e *ExpectedPrepare) WillBeClosed() *ExpectedPrepare { func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery { eq := &ExpectedQuery{} eq.sqlRegex = e.sqlRegex + eq.converter = e.mock.converter e.mock.expected = append(e.mock.expected, eq) return eq } @@ -301,6 +302,7 @@ func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery { func (e *ExpectedPrepare) ExpectExec() *ExpectedExec { eq := &ExpectedExec{} eq.sqlRegex = e.sqlRegex + eq.converter = e.mock.converter e.mock.expected = append(e.mock.expected, eq) return eq } @@ -325,8 +327,9 @@ func (e *ExpectedPrepare) String() string { // adds a query matching logic type queryBasedExpectation struct { commonExpectation - sqlRegex *regexp.Regexp - args []driver.Value + sqlRegex *regexp.Regexp + converter driver.ValueConverter + args []driver.Value } func (e *queryBasedExpectation) attemptMatch(sql string, args []namedValue) (err error) { diff --git a/expectations_before_go18.go b/expectations_before_go18.go index 146f240..888df88 100644 --- a/expectations_before_go18.go +++ b/expectations_before_go18.go @@ -35,7 +35,7 @@ func (e *queryBasedExpectation) argsMatches(args []namedValue) error { dval := e.args[k] // convert to driver converter - darg, err := driver.DefaultParameterConverter.ConvertValue(dval) + darg, err := e.converter.ConvertValue(dval) if err != nil { return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err) } diff --git a/expectations_go18.go b/expectations_go18.go index 2b4b44e..2f0b64d 100644 --- a/expectations_go18.go +++ b/expectations_go18.go @@ -49,7 +49,7 @@ func (e *queryBasedExpectation) argsMatches(args []namedValue) error { } // convert to driver converter - darg, err := driver.DefaultParameterConverter.ConvertValue(dval) + darg, err := e.converter.ConvertValue(dval) if err != nil { return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err) } diff --git a/expectations_go18_test.go b/expectations_go18_test.go index 5f30d2f..2b85db3 100644 --- a/expectations_go18_test.go +++ b/expectations_go18_test.go @@ -9,7 +9,7 @@ import ( ) func TestQueryExpectationNamedArgComparison(t *testing.T) { - e := &queryBasedExpectation{} + e := &queryBasedExpectation{converter: driver.DefaultParameterConverter} against := []namedValue{{Value: int64(5), Name: "id"}} if err := e.argsMatches(against); err != nil { t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err) diff --git a/expectations_test.go b/expectations_test.go index 2e3c097..90e3f1f 100644 --- a/expectations_test.go +++ b/expectations_test.go @@ -9,7 +9,7 @@ import ( ) func TestQueryExpectationArgComparison(t *testing.T) { - e := &queryBasedExpectation{} + e := &queryBasedExpectation{converter: driver.DefaultParameterConverter} against := []namedValue{{Value: int64(5), Ordinal: 1}} if err := e.argsMatches(against); err != nil { t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err) @@ -67,7 +67,7 @@ func TestQueryExpectationArgComparison(t *testing.T) { func TestQueryExpectationArgComparisonBool(t *testing.T) { var e *queryBasedExpectation - e = &queryBasedExpectation{args: []driver.Value{true}} + e = &queryBasedExpectation{args: []driver.Value{true}, converter: driver.DefaultParameterConverter} against := []namedValue{ {Value: true, Ordinal: 1}, } @@ -75,7 +75,7 @@ func TestQueryExpectationArgComparisonBool(t *testing.T) { t.Error("arguments should match, since arguments are the same") } - e = &queryBasedExpectation{args: []driver.Value{false}} + e = &queryBasedExpectation{args: []driver.Value{false}, converter: driver.DefaultParameterConverter} against = []namedValue{ {Value: false, Ordinal: 1}, } @@ -83,7 +83,7 @@ func TestQueryExpectationArgComparisonBool(t *testing.T) { t.Error("arguments should match, since argument are the same") } - e = &queryBasedExpectation{args: []driver.Value{true}} + e = &queryBasedExpectation{args: []driver.Value{true}, converter: driver.DefaultParameterConverter} against = []namedValue{ {Value: false, Ordinal: 1}, } @@ -91,7 +91,7 @@ func TestQueryExpectationArgComparisonBool(t *testing.T) { t.Error("arguments should not match, since argument is different") } - e = &queryBasedExpectation{args: []driver.Value{false}} + e = &queryBasedExpectation{args: []driver.Value{false}, converter: driver.DefaultParameterConverter} against = []namedValue{ {Value: true, Ordinal: 1}, } diff --git a/rows.go b/rows.go index 6477ed6..ea35644 100644 --- a/rows.go +++ b/rows.go @@ -81,18 +81,24 @@ func (rs *rowSets) empty() bool { // Rows is a mocked collection of rows to // return for Query result type Rows struct { - cols []string - rows [][]driver.Value - pos int - nextErr map[int]error - closeErr error + converter driver.ValueConverter + cols []string + rows [][]driver.Value + pos int + nextErr map[int]error + closeErr error } // NewRows allows Rows to be created from a // sql driver.Value slice or from the CSV string and -// to be used as sql driver.Rows +// to be used as sql driver.Rows. +// Use Sqlmock.NewRows instead if using a custom converter func NewRows(columns []string) *Rows { - return &Rows{cols: columns, nextErr: make(map[int]error)} + return &Rows{ + cols: columns, + nextErr: make(map[int]error), + converter: driver.DefaultParameterConverter, + } } // CloseError allows to set an error @@ -129,7 +135,7 @@ func (r *Rows) AddRow(values ...driver.Value) *Rows { // Convert user-friendly values (such as int or driver.Valuer) // to database/sql native value (driver.Value such as int64) var err error - v, err = driver.DefaultParameterConverter.ConvertValue(v) + v, err = r.converter.ConvertValue(v) if err != nil { panic(fmt.Errorf( "row #%d, column #%d (%q) type %T: %s", diff --git a/sqlmock.go b/sqlmock.go index 8fe5cc6..17cfd6c 100644 --- a/sqlmock.go +++ b/sqlmock.go @@ -73,22 +73,37 @@ type Sqlmock interface { // in any order. Or otherwise if switched to true, any unmatched // expectations will be expected in order MatchExpectationsInOrder(bool) + + // NewRows allows Rows to be created from a + // sql driver.Value slice or from the CSV string and + // to be used as sql driver.Rows. + NewRows(columns []string) *Rows } type sqlmock struct { - ordered bool - dsn string - opened int - drv *mockDriver + ordered bool + dsn string + opened int + drv *mockDriver + converter driver.ValueConverter expected []expectation } -func (c *sqlmock) open() (*sql.DB, Sqlmock, error) { +func (c *sqlmock) open(options []func(*sqlmock) error) (*sql.DB, Sqlmock, error) { db, err := sql.Open("sqlmock", c.dsn) if err != nil { return db, c, err } + for _, option := range options { + err := option(c) + if err != nil { + return db, c, err + } + } + if c.converter == nil { + c.converter = driver.DefaultParameterConverter + } return db, c, db.Ping() } @@ -165,6 +180,11 @@ func (c *sqlmock) ExpectationsWereMet() error { return nil } +func (c *sqlmock) CheckNamedValue(nv *driver.NamedValue) (err error) { + nv.Value, err = c.converter.ConvertValue(nv.Value) + return err +} + // Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface func (c *sqlmock) Begin() (driver.Tx, error) { ex, err := c.begin() @@ -301,6 +321,7 @@ func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec { e := &ExpectedExec{} sqlRegexStr = stripQuery(sqlRegexStr) e.sqlRegex = regexp.MustCompile(sqlRegexStr) + e.converter = c.converter c.expected = append(c.expected, e) return e } @@ -463,6 +484,7 @@ func (c *sqlmock) ExpectQuery(sqlRegexStr string) *ExpectedQuery { e := &ExpectedQuery{} sqlRegexStr = stripQuery(sqlRegexStr) e.sqlRegex = regexp.MustCompile(sqlRegexStr) + e.converter = c.converter c.expected = append(c.expected, e) return e } @@ -548,3 +570,12 @@ func (c *sqlmock) Rollback() error { expected.Unlock() return expected.err } + +// NewRows allows Rows to be created from a +// sql driver.Value slice or from the CSV string and +// to be used as sql driver.Rows. +func (c *sqlmock) NewRows(columns []string) *Rows { + r := NewRows(columns) + r.converter = c.converter + return r +} From 3d314830140d16ea2f84b8339aa05361498f8859 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Wa=C5=9B?= Date: Mon, 6 Aug 2018 22:41:15 +0200 Subject: [PATCH 2/6] add comment to CheckNamedValue --- sqlmock.go | 1 + 1 file changed, 1 insertion(+) diff --git a/sqlmock.go b/sqlmock.go index 17cfd6c..01745f1 100644 --- a/sqlmock.go +++ b/sqlmock.go @@ -180,6 +180,7 @@ func (c *sqlmock) ExpectationsWereMet() error { return nil } +// CheckNamedValue meets https://golang.org/pkg/database/sql/driver/#NamedValueChecker func (c *sqlmock) CheckNamedValue(nv *driver.NamedValue) (err error) { nv.Value, err = c.converter.ConvertValue(nv.Value) return err From aaceb21fbdfc0bc260dbfd622ffa4ab591e0a0a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Wa=C5=9B?= Date: Mon, 6 Aug 2018 22:49:41 +0200 Subject: [PATCH 3/6] use custom converter only in newer go --- sqlmock.go | 6 ------ sqlmock_go18.go | 6 ++++++ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sqlmock.go b/sqlmock.go index 01745f1..d0e79c1 100644 --- a/sqlmock.go +++ b/sqlmock.go @@ -180,12 +180,6 @@ func (c *sqlmock) ExpectationsWereMet() error { return nil } -// CheckNamedValue meets https://golang.org/pkg/database/sql/driver/#NamedValueChecker -func (c *sqlmock) CheckNamedValue(nv *driver.NamedValue) (err error) { - nv.Value, err = c.converter.ConvertValue(nv.Value) - return err -} - // Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface func (c *sqlmock) Begin() (driver.Tx, error) { ex, err := c.begin() diff --git a/sqlmock_go18.go b/sqlmock_go18.go index 49bd4f5..b8c76f8 100644 --- a/sqlmock_go18.go +++ b/sqlmock_go18.go @@ -111,3 +111,9 @@ func (stmt *statement) QueryContext(ctx context.Context, args []driver.NamedValu } // @TODO maybe add ExpectedBegin.WithOptions(driver.TxOptions) + +// CheckNamedValue meets https://golang.org/pkg/database/sql/driver/#NamedValueChecker +func (c *sqlmock) CheckNamedValue(nv *driver.NamedValue) (err error) { + nv.Value, err = c.converter.ConvertValue(nv.Value) + return err +} From 298bfde310e2d7a083dd057bd44c5b69a2005a64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Wa=C5=9B?= Date: Mon, 6 Aug 2018 22:59:44 +0200 Subject: [PATCH 4/6] add tests --- sqlmock_test.go | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/sqlmock_test.go b/sqlmock_test.go index f43a3e7..6cc56ae 100644 --- a/sqlmock_test.go +++ b/sqlmock_test.go @@ -1113,3 +1113,31 @@ func TestExecExpectationErrorDelay(t *testing.T) { t.Errorf("expecting a delay of less than %v before error, actual delay was %v", delay, elapsed) } } + +func TestOptionsFail(t *testing.T) { + t.Parallel() + expected := errors.New("failing option") + option := func(*sqlmock) error { + return expected + } + db, _, err := New(option) + defer db.Close() + if err == nil { + t.Errorf("missing expecting error '%s' when opening a stub database connection", expected) + } +} + +func TestNewRows(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + columns := []string{"col1", "col2"} + + r := mock.NewRows(columns) + if len(r.cols) != len(columns) || r.cols[0] != columns[0] || r.cols[1] != columns[1] { + t.Errorf("expecting to create a row with columns %v, actual colmns are %v", r.cols, columns) + } +} From 3cbf32d5e7413a59c37f2851e3c6e26f9005832f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Wa=C5=9B?= Date: Mon, 6 Aug 2018 23:07:17 +0200 Subject: [PATCH 5/6] more tests --- sqlmock_test.go | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/sqlmock_test.go b/sqlmock_test.go index 6cc56ae..e56ec03 100644 --- a/sqlmock_test.go +++ b/sqlmock_test.go @@ -849,6 +849,32 @@ func TestRollbackThrow(t *testing.T) { // Output: } +func TestUnexpectedBegin(t *testing.T) { + // Open new mock database + db, _, err := New() + if err != nil { + fmt.Println("error creating mock database") + return + } + if _, err := db.Begin(); err == nil { + t.Error("an error was expected when calling begin, but got none") + } +} + +func TestUnexpectedExec(t *testing.T) { + // Open new mock database + db, mock, err := New() + if err != nil { + fmt.Println("error creating mock database") + return + } + mock.ExpectBegin() + db.Begin() + if _, err := db.Exec("SELECT 1"); err == nil { + t.Error("an error was expected when calling exec, but got none") + } +} + func TestUnexpectedCommit(t *testing.T) { // Open new mock database db, mock, err := New() From 168056e96a1eb8977e289712fdf73f2852a09665 Mon Sep 17 00:00:00 2001 From: Jan Was Date: Fri, 14 Sep 2018 11:03:24 +0200 Subject: [PATCH 6/6] move options to a separate file --- driver.go | 20 +++++++------------- driver_test.go | 2 +- options.go | 12 ++++++++++++ 3 files changed, 20 insertions(+), 14 deletions(-) create mode 100644 options.go diff --git a/driver.go b/driver.go index 876f7a2..802f8fb 100644 --- a/driver.go +++ b/driver.go @@ -35,8 +35,9 @@ func (d *mockDriver) Open(dsn string) (driver.Conn, error) { return c, nil } -// New creates sqlmock database connection -// and a mock to manage expectations. +// New creates sqlmock database connection and a mock to manage expectations. +// Accepts options, like ValueConverterOption, to use a ValueConverter from +// a specific driver. // Pings db so that all expectations could be // asserted. func New(options ...func(*sqlmock) error) (*sql.DB, Sqlmock, error) { @@ -51,8 +52,10 @@ func New(options ...func(*sqlmock) error) (*sql.DB, Sqlmock, error) { return smock.open(options) } -// NewWithDSN creates sqlmock database connection -// with a specific DSN and a mock to manage expectations. +// NewWithDSN creates sqlmock database connection with a specific DSN +// and a mock to manage expectations. +// Accepts options, like ValueConverterOption, to use a ValueConverter from +// a specific driver. // Pings db so that all expectations could be asserted. // // This method is introduced because of sql abstraction @@ -76,12 +79,3 @@ func NewWithDSN(dsn string, options ...func(*sqlmock) error) (*sql.DB, Sqlmock, return smock.open(options) } - -// WithValueConverter allows to create a sqlmock connection -// with a custom ValueConverter to support drivers with special data types. -func WithValueConverter(converter driver.ValueConverter) func(*sqlmock) error { - return func(s *sqlmock) error { - s.converter = converter - return nil - } -} diff --git a/driver_test.go b/driver_test.go index a8acc69..bbd7293 100644 --- a/driver_test.go +++ b/driver_test.go @@ -100,7 +100,7 @@ func TestTwoOpenConnectionsOnTheSameDSN(t *testing.T) { func TestWithOptions(t *testing.T) { c := &converter{} - _, mock, err := New(WithValueConverter(c)) + _, mock, err := New(ValueConverterOption(c)) if err != nil { t.Errorf("expected no error, but got: %s", err) } diff --git a/options.go b/options.go new file mode 100644 index 0000000..05c09d6 --- /dev/null +++ b/options.go @@ -0,0 +1,12 @@ +package sqlmock + +import "database/sql/driver" + +// ValueConverterOption allows to create a sqlmock connection +// with a custom ValueConverter to support drivers with special data types. +func ValueConverterOption(converter driver.ValueConverter) func(*sqlmock) error { + return func(s *sqlmock) error { + s.converter = converter + return nil + } +}