From a071483cba0979a2ec3fe2c908cf48be90994aea Mon Sep 17 00:00:00 2001 From: gedi Date: Fri, 17 Jul 2015 13:14:30 +0300 Subject: [PATCH] concurrency support, closes #20 and closes #9 and closes #15 * c600769 do not require a connection name, unique dsn is generated * 1b20b9c update travis * 1097b6a add comments for godoc documentation * c142a95 fix golint reported issues --- .travis.yml | 6 +- LICENSE | 2 +- README.md | 15 +- connection.go | 151 --------------- connection_test.go | 378 -------------------------------------- driver.go | 56 ++++++ driver_test.go | 83 +++++++++ expectations.go | 185 +++++++++++++++---- expectations_test.go | 2 +- result_test.go | 14 ++ rows.go | 49 ++--- sqlmock.go | 424 ++++++++++++++++++++++++++----------------- sqlmock_test.go | 223 ++++++++++++++--------- statement.go | 5 +- transaction.go | 37 ---- util.go | 6 +- 16 files changed, 725 insertions(+), 911 deletions(-) delete mode 100644 connection.go delete mode 100644 connection_test.go create mode 100644 driver.go create mode 100644 driver_test.go delete mode 100644 transaction.go diff --git a/.travis.yml b/.travis.yml index 464bcb0..76f71f0 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,10 +7,6 @@ go: - tip script: - - go get github.com/kisielk/errcheck - - go get ./... - - go test -v ./... - go test -race ./... - - errcheck github.com/DATA-DOG/go-sqlmock - + # linter will follow diff --git a/LICENSE b/LICENSE index d0a2e8f..dd431fa 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ The three clause BSD license (http://en.wikipedia.org/wiki/BSD_licenses) -Copyright (c) 2013, DataDog.lt team +Copyright (c) 2013-2015, DataDog.lt team All rights reserved. Redistribution and use in source and binary forms, with or without diff --git a/README.md b/README.md index 6b0769a..c43b8ca 100644 --- a/README.md +++ b/README.md @@ -5,12 +5,25 @@ This is a **mock** driver as **database/sql/driver** which is very flexible and pragmatic to manage and mock expected queries. All the expectations should be met and all queries and actions -triggered should be mocked in order to pass a test. +triggered should be mocked in order to pass a test. The package has no 3rd party dependencies. + +**NOTE:** regarding major issues #20 and #9 the api has changed to support concurrency and more than +one database connection. + +If you need an old version, checkout **go-sqlmock** at gopkg.in: + + go get gopkg.in/DATA-DOG/go-sqlmock.v0 + +Otherwise use the **v1** branch from master. ## Install go get github.com/DATA-DOG/go-sqlmock +Or take an older version: + + go get gopkg.in/DATA-DOG/go-sqlmock.v0 + ## Use it with pleasure An example of some database interaction which you may want to test: diff --git a/connection.go b/connection.go deleted file mode 100644 index ed43f06..0000000 --- a/connection.go +++ /dev/null @@ -1,151 +0,0 @@ -package sqlmock - -import ( - "database/sql/driver" - "fmt" - "reflect" -) - -type conn struct { - expectations []expectation - active expectation -} - -// Close a mock database driver connection. It should -// be always called to ensure that all expectations -// were met successfully. Returns error if there is any -func (c *conn) Close() (err error) { - for _, e := range mock.conn.expectations { - if !e.fulfilled() { - err = fmt.Errorf("there is a remaining expectation %T which was not matched yet", e) - break - } - } - mock.conn.expectations = []expectation{} - mock.conn.active = nil - return err -} - -func (c *conn) Begin() (driver.Tx, error) { - e := c.next() - if e == nil { - return nil, fmt.Errorf("all expectations were already fulfilled, call to begin transaction was not expected") - } - - etb, ok := e.(*expectedBegin) - if !ok { - return nil, fmt.Errorf("call to begin transaction, was not expected, next expectation is %T as %+v", e, e) - } - etb.triggered = true - return &transaction{c}, etb.err -} - -// get next unfulfilled expectation -func (c *conn) next() (e expectation) { - for _, e = range c.expectations { - if !e.fulfilled() { - return - } - } - return nil // all expectations were fulfilled -} - -func (c *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) { - e := c.next() - query = stripQuery(query) - if e == nil { - return nil, fmt.Errorf("all expectations were already fulfilled, call to exec '%s' query with args %+v was not expected", query, args) - } - - eq, ok := e.(*expectedExec) - if !ok { - return nil, fmt.Errorf("call to exec query '%s' with args %+v, was not expected, next expectation is %T as %+v", query, args, e, e) - } - - eq.triggered = true - - defer argMatcherErrorHandler(&err) // converts panic to error in case of reflect value type mismatch - - if !eq.queryMatches(query) { - return nil, fmt.Errorf("exec query '%s', does not match regex '%s'", query, eq.sqlRegex.String()) - } - - if !eq.argsMatches(args) { - return nil, fmt.Errorf("exec query '%s', args %+v does not match expected %+v", query, args, eq.args) - } - - if eq.err != nil { - return nil, eq.err // mocked to return error - } - - if eq.result == nil { - return nil, fmt.Errorf("exec query '%s' with args %+v, must return a database/sql/driver.result, but it was not set for expectation %T as %+v", query, args, eq, eq) - } - - return eq.result, err -} - -func (c *conn) Prepare(query string) (driver.Stmt, error) { - e := c.next() - - // for backwards compatibility, ignore when Prepare not expected - if e == nil { - return &statement{mock.conn, stripQuery(query)}, nil - } - eq, ok := e.(*expectedPrepare) - if !ok { - return &statement{mock.conn, stripQuery(query)}, nil - } - - eq.triggered = true - if eq.err != nil { - return nil, eq.err // mocked to return error - } - - return &statement{mock.conn, stripQuery(query)}, nil -} - -func (c *conn) Query(query string, args []driver.Value) (rw driver.Rows, err error) { - e := c.next() - query = stripQuery(query) - if e == nil { - return nil, fmt.Errorf("all expectations were already fulfilled, call to query '%s' with args %+v was not expected", query, args) - } - - eq, ok := e.(*expectedQuery) - if !ok { - return nil, fmt.Errorf("call to query '%s' with args %+v, was not expected, next expectation is %T as %+v", query, args, e, e) - } - - eq.triggered = true - - defer argMatcherErrorHandler(&err) // converts panic to error in case of reflect value type mismatch - - if !eq.queryMatches(query) { - return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, eq.sqlRegex.String()) - } - - if !eq.argsMatches(args) { - return nil, fmt.Errorf("query '%s', args %+v does not match expected %+v", query, args, eq.args) - } - - if eq.err != nil { - return nil, eq.err // mocked to return error - } - - if eq.rows == nil { - return nil, fmt.Errorf("query '%s' with args %+v, must return a database/sql/driver.rows, but it was not set for expectation %T as %+v", query, args, eq, eq) - } - - return eq.rows, err -} - -func argMatcherErrorHandler(errp *error) { - if e := recover(); e != nil { - if se, ok := e.(*reflect.ValueError); ok { // catch reflect error, failed type conversion - *errp = fmt.Errorf("Failed to compare query arguments: %s", se) - } else { - panic(e) // overwise panic - } - } -} diff --git a/connection_test.go b/connection_test.go deleted file mode 100644 index b4aabe9..0000000 --- a/connection_test.go +++ /dev/null @@ -1,378 +0,0 @@ -package sqlmock - -import ( - "database/sql/driver" - "errors" - "regexp" - "testing" -) - -func TestExecNoExpectations(t *testing.T) { - c := &conn{ - expectations: []expectation{ - &expectedExec{ - queryBasedExpectation: queryBasedExpectation{ - commonExpectation: commonExpectation{ - triggered: true, - err: errors.New("WillReturnError"), - }, - sqlRegex: regexp.MustCompile(regexp.QuoteMeta("otherquery")), - args: []driver.Value{456}, - }, - }, - }, - } - res, err := c.Exec("query", []driver.Value{123}) - if res != nil { - t.Error("Result should be nil") - } - if err == nil { - t.Error("error should not be nil") - } - pattern := regexp.MustCompile(regexp.QuoteMeta("all expectations were already fulfilled, call to exec")) - if !pattern.MatchString(err.Error()) { - t.Errorf("error should match expected error message (actual: %s)", err.Error()) - } -} - -func TestExecExpectationMismatch(t *testing.T) { - c := &conn{ - expectations: []expectation{ - &expectedQuery{ - queryBasedExpectation: queryBasedExpectation{ - commonExpectation: commonExpectation{ - err: errors.New("WillReturnError"), - }, - sqlRegex: regexp.MustCompile(regexp.QuoteMeta("otherquery")), - args: []driver.Value{456}, - }, - }, - }, - } - res, err := c.Exec("query", []driver.Value{123}) - if res != nil { - t.Error("Result should be nil") - } - if err == nil { - t.Error("error should not be nil") - } - pattern := regexp.MustCompile(regexp.QuoteMeta("was not expected, next expectation is")) - if !pattern.MatchString(err.Error()) { - t.Errorf("error should match expected error message (actual: %s)", err.Error()) - } -} - -func TestExecQueryMismatch(t *testing.T) { - c := &conn{ - expectations: []expectation{ - &expectedExec{ - queryBasedExpectation: queryBasedExpectation{ - commonExpectation: commonExpectation{ - err: errors.New("WillReturnError"), - }, - sqlRegex: regexp.MustCompile(regexp.QuoteMeta("otherquery")), - args: []driver.Value{456}, - }, - }, - }, - } - res, err := c.Exec("query", []driver.Value{123}) - if res != nil { - t.Error("Result should be nil") - } - if err == nil { - t.Error("error should not be nil") - } - pattern := regexp.MustCompile(regexp.QuoteMeta("does not match regex")) - if !pattern.MatchString(err.Error()) { - t.Errorf("error should match expected error message (actual: %s)", err.Error()) - } -} - -func TestExecArgsMismatch(t *testing.T) { - c := &conn{ - expectations: []expectation{ - &expectedExec{ - queryBasedExpectation: queryBasedExpectation{ - commonExpectation: commonExpectation{ - err: errors.New("WillReturnError"), - }, - sqlRegex: regexp.MustCompile(regexp.QuoteMeta("query")), - args: []driver.Value{456}, - }, - }, - }, - } - res, err := c.Exec("query", []driver.Value{123}) - if res != nil { - t.Error("Result should be nil") - } - if err == nil { - t.Error("error should not be nil") - } - pattern := regexp.MustCompile(regexp.QuoteMeta("does not match expected")) - if !pattern.MatchString(err.Error()) { - t.Errorf("error should match expected error message (actual: %s)", err.Error()) - } -} - -func TestExecWillReturnError(t *testing.T) { - c := &conn{ - expectations: []expectation{ - &expectedExec{ - queryBasedExpectation: queryBasedExpectation{ - commonExpectation: commonExpectation{ - err: errors.New("WillReturnError"), - }, - sqlRegex: regexp.MustCompile(regexp.QuoteMeta("query")), - }, - }, - }, - } - res, err := c.Exec("query", []driver.Value{123}) - if res != nil { - t.Error("Result should be nil") - } - if err == nil { - t.Error("error should not be nil") - } - if err.Error() != "WillReturnError" { - t.Errorf("error should match expected error message (actual: %s)", err.Error()) - } -} - -func TestExecMissingResult(t *testing.T) { - c := &conn{ - expectations: []expectation{ - &expectedExec{ - queryBasedExpectation: queryBasedExpectation{ - commonExpectation: commonExpectation{}, - sqlRegex: regexp.MustCompile(regexp.QuoteMeta("query")), - args: []driver.Value{123}, - }, - }, - }, - } - res, err := c.Exec("query", []driver.Value{123}) - if res != nil { - t.Error("Result should be nil") - } - if err == nil { - t.Error("error should not be nil") - } - pattern := regexp.MustCompile(regexp.QuoteMeta("must return a database/sql/driver.result, but it was not set for expectation")) - if !pattern.MatchString(err.Error()) { - t.Errorf("error should match expected error message (actual: %s)", err.Error()) - } -} - -func TestExec(t *testing.T) { - expectedResult := driver.Result(&result{}) - c := &conn{ - expectations: []expectation{ - &expectedExec{ - queryBasedExpectation: queryBasedExpectation{ - commonExpectation: commonExpectation{}, - sqlRegex: regexp.MustCompile(regexp.QuoteMeta("query")), - args: []driver.Value{123}, - }, - result: expectedResult, - }, - }, - } - res, err := c.Exec("query", []driver.Value{123}) - if res == nil { - t.Error("Result should not be nil") - } - if res != expectedResult { - t.Errorf("Result should match expected Result (actual %+v)", res) - } - if err != nil { - t.Errorf("error should be nil (actual %s)", err.Error()) - } -} - -func TestQueryNoExpectations(t *testing.T) { - c := &conn{ - expectations: []expectation{ - &expectedQuery{ - queryBasedExpectation: queryBasedExpectation{ - commonExpectation: commonExpectation{ - triggered: true, - err: errors.New("WillReturnError"), - }, - sqlRegex: regexp.MustCompile(regexp.QuoteMeta("otherquery")), - args: []driver.Value{456}, - }, - }, - }, - } - res, err := c.Query("query", []driver.Value{123}) - if res != nil { - t.Error("Rows should be nil") - } - if err == nil { - t.Error("error should not be nil") - } - pattern := regexp.MustCompile(regexp.QuoteMeta("all expectations were already fulfilled, call to query")) - if !pattern.MatchString(err.Error()) { - t.Errorf("error should match expected error message (actual: %s)", err.Error()) - } -} - -func TestQueryExpectationMismatch(t *testing.T) { - c := &conn{ - expectations: []expectation{ - &expectedExec{ - queryBasedExpectation: queryBasedExpectation{ - commonExpectation: commonExpectation{ - err: errors.New("WillReturnError"), - }, - sqlRegex: regexp.MustCompile(regexp.QuoteMeta("otherquery")), - args: []driver.Value{456}, - }, - }, - }, - } - res, err := c.Query("query", []driver.Value{123}) - if res != nil { - t.Error("Rows should be nil") - } - if err == nil { - t.Error("error should not be nil") - } - pattern := regexp.MustCompile(regexp.QuoteMeta("was not expected, next expectation is")) - if !pattern.MatchString(err.Error()) { - t.Errorf("error should match expected error message (actual: %s)", err.Error()) - } -} - -func TestQueryQueryMismatch(t *testing.T) { - c := &conn{ - expectations: []expectation{ - &expectedQuery{ - queryBasedExpectation: queryBasedExpectation{ - commonExpectation: commonExpectation{ - err: errors.New("WillReturnError"), - }, - sqlRegex: regexp.MustCompile(regexp.QuoteMeta("otherquery")), - args: []driver.Value{456}, - }, - }, - }, - } - res, err := c.Query("query", []driver.Value{123}) - if res != nil { - t.Error("Rows should be nil") - } - if err == nil { - t.Error("error should not be nil") - } - pattern := regexp.MustCompile(regexp.QuoteMeta("does not match regex")) - if !pattern.MatchString(err.Error()) { - t.Errorf("error should match expected error message (actual: %s)", err.Error()) - } -} - -func TestQueryArgsMismatch(t *testing.T) { - c := &conn{ - expectations: []expectation{ - &expectedQuery{ - queryBasedExpectation: queryBasedExpectation{ - commonExpectation: commonExpectation{ - err: errors.New("WillReturnError"), - }, - sqlRegex: regexp.MustCompile(regexp.QuoteMeta("query")), - args: []driver.Value{456}, - }, - }, - }, - } - res, err := c.Query("query", []driver.Value{123}) - if res != nil { - t.Error("Rows should be nil") - } - if err == nil { - t.Error("error should not be nil") - } - pattern := regexp.MustCompile(regexp.QuoteMeta("does not match expected")) - if !pattern.MatchString(err.Error()) { - t.Errorf("error should match expected error message (actual: %s)", err.Error()) - } -} - -func TestQueryWillReturnError(t *testing.T) { - c := &conn{ - expectations: []expectation{ - &expectedQuery{ - queryBasedExpectation: queryBasedExpectation{ - commonExpectation: commonExpectation{ - err: errors.New("WillReturnError"), - }, - sqlRegex: regexp.MustCompile(regexp.QuoteMeta("query")), - }, - }, - }, - } - res, err := c.Query("query", []driver.Value{123}) - if res != nil { - t.Error("Rows should be nil") - } - if err == nil { - t.Error("error should not be nil") - } - if err.Error() != "WillReturnError" { - t.Errorf("error should match expected error message (actual: %s)", err.Error()) - } -} - -func TestQueryMissingRows(t *testing.T) { - c := &conn{ - expectations: []expectation{ - &expectedQuery{ - queryBasedExpectation: queryBasedExpectation{ - commonExpectation: commonExpectation{}, - sqlRegex: regexp.MustCompile(regexp.QuoteMeta("query")), - args: []driver.Value{123}, - }, - }, - }, - } - res, err := c.Query("query", []driver.Value{123}) - if res != nil { - t.Error("Rows should be nil") - } - if err == nil { - t.Error("error should not be nil") - } - pattern := regexp.MustCompile(regexp.QuoteMeta("must return a database/sql/driver.rows, but it was not set for expectation")) - if !pattern.MatchString(err.Error()) { - t.Errorf("error should match expected error message (actual: %s)", err.Error()) - } -} - -func TestQuery(t *testing.T) { - expectedRows := driver.Rows(&rows{}) - c := &conn{ - expectations: []expectation{ - &expectedQuery{ - queryBasedExpectation: queryBasedExpectation{ - commonExpectation: commonExpectation{}, - sqlRegex: regexp.MustCompile(regexp.QuoteMeta("query")), - args: []driver.Value{123}, - }, - rows: expectedRows, - }, - }, - } - rows, err := c.Query("query", []driver.Value{123}) - if rows == nil { - t.Error("Rows should not be nil") - } - if rows != expectedRows { - t.Errorf("Rows should match expected Rows (actual %+v)", rows) - } - if err != nil { - t.Errorf("error should be nil (actual %s)", err.Error()) - } -} diff --git a/driver.go b/driver.go new file mode 100644 index 0000000..105a2af --- /dev/null +++ b/driver.go @@ -0,0 +1,56 @@ +package sqlmock + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "sync" +) + +var pool *mockDriver + +func init() { + pool = &mockDriver{ + conns: make(map[string]*Sqlmock), + } + sql.Register("sqlmock", pool) +} + +type mockDriver struct { + sync.Mutex + counter int + conns map[string]*Sqlmock +} + +func (d *mockDriver) Open(dsn string) (driver.Conn, error) { + d.Lock() + defer d.Unlock() + + c, ok := d.conns[dsn] + if !ok { + return c, fmt.Errorf("expected a connection to be available, but it is not") + } + + c.opened++ + return c, nil +} + +// New creates sqlmock database connection +// and a mock to manage expectations. +// Pings db so that all expectations could be +// asserted. +func New() (db *sql.DB, mock *Sqlmock, err error) { + pool.Lock() + dsn := fmt.Sprintf("sqlmock_db_%d", pool.counter) + pool.counter++ + + mock = &Sqlmock{dsn: dsn, drv: pool} + pool.conns[dsn] = mock + pool.Unlock() + + db, err = sql.Open("sqlmock", dsn) + if err != nil { + return + } + return db, mock, db.Ping() +} diff --git a/driver_test.go b/driver_test.go new file mode 100644 index 0000000..0554a22 --- /dev/null +++ b/driver_test.go @@ -0,0 +1,83 @@ +package sqlmock + +import ( + "fmt" + "testing" +) + +func ExampleNew() { + db, mock, err := New() + if err != nil { + fmt.Println("expected no error, but got:", err) + return + } + defer db.Close() + // now we can expect operations performed on db + mock.ExpectBegin().WillReturnError(fmt.Errorf("an error will occur on db.Begin() call")) +} + +func TestShouldOpenConnectionIssue15(t *testing.T) { + db, mock, err := New() + if err != nil { + t.Errorf("expected no error, but got: %s", err) + } + if len(pool.conns) != 1 { + t.Errorf("expected 1 connection in pool, but there is: %d", len(pool.conns)) + } + + if mock.opened != 1 { + t.Errorf("expected 1 connection on mock to be opened, but there is: %d", mock.opened) + } + + // defer so the rows gets closed first + defer func() { + if mock.opened != 0 { + t.Errorf("expected no connections on mock to be opened, but there is: %d", mock.opened) + } + }() + + mock.ExpectQuery("SELECT").WillReturnRows(NewRows([]string{"one", "two"}).AddRow("val1", "val2")) + rows, err := db.Query("SELECT") + if err != nil { + t.Errorf("unexpected error: %s", err) + } + defer rows.Close() + + mock.ExpectExec("UPDATE").WillReturnResult(NewResult(1, 1)) + if _, err = db.Exec("UPDATE"); err != nil { + t.Errorf("unexpected error: %s", err) + } + + // now there should be two connections open + if mock.opened != 2 { + t.Errorf("expected 2 connection on mock to be opened, but there is: %d", mock.opened) + } + + mock.ExpectClose() + if err = db.Close(); err != nil { + t.Errorf("expected no error on close, but got: %s", err) + } + + // one is still reserved for rows + if mock.opened != 1 { + t.Errorf("expected 1 connection on mock to be still reserved for rows, but there is: %d", mock.opened) + } +} + +func TestTwoOpenConnectionsOnTheSameDSN(t *testing.T) { + db, mock, err := New() + if err != nil { + t.Errorf("expected no error, but got: %s", err) + } + db2, mock2, err := New() + if len(pool.conns) != 2 { + t.Errorf("expected 2 connection in pool, but there is: %d", len(pool.conns)) + } + + if db == db2 { + t.Errorf("expected not the same database instance, but it is the same") + } + if mock == mock2 { + t.Errorf("expected not the same mock instance, but it is the same") + } +} diff --git a/expectations.go b/expectations.go index d778afd..909c499 100644 --- a/expectations.go +++ b/expectations.go @@ -7,7 +7,8 @@ import ( ) // Argument interface allows to match -// any argument in specific way +// any argument in specific way when used with +// ExpectedQuery and ExpectedExec expectations. type Argument interface { Match(driver.Value) bool } @@ -15,7 +16,6 @@ type Argument interface { // an expectation interface type expectation interface { fulfilled() bool - setError(err error) } // common expectation struct @@ -29,8 +29,151 @@ func (e *commonExpectation) fulfilled() bool { return e.triggered } -func (e *commonExpectation) setError(err error) { +// ExpectedClose is used to manage *sql.DB.Close expectation +// returned by *Sqlmock.ExpectClose. +type ExpectedClose struct { + commonExpectation +} + +// WillReturnError allows to set an error for *sql.DB.Close action +func (e *ExpectedClose) WillReturnError(err error) *ExpectedClose { e.err = err + return e +} + +// ExpectedBegin is used to manage *sql.DB.Begin expectation +// returned by *Sqlmock.ExpectBegin. +type ExpectedBegin struct { + commonExpectation +} + +// WillReturnError allows to set an error for *sql.DB.Begin action +func (e *ExpectedBegin) WillReturnError(err error) *ExpectedBegin { + e.err = err + return e +} + +// ExpectedCommit is used to manage *sql.Tx.Commit expectation +// returned by *Sqlmock.ExpectCommit. +type ExpectedCommit struct { + commonExpectation +} + +// WillReturnError allows to set an error for *sql.Tx.Close action +func (e *ExpectedCommit) WillReturnError(err error) *ExpectedCommit { + e.err = err + return e +} + +// ExpectedRollback is used to manage *sql.Tx.Rollback expectation +// returned by *Sqlmock.ExpectRollback. +type ExpectedRollback struct { + commonExpectation +} + +// WillReturnError allows to set an error for *sql.Tx.Rollback action +func (e *ExpectedRollback) WillReturnError(err error) *ExpectedRollback { + e.err = err + return e +} + +// ExpectedQuery is used to manage *sql.DB.Query, *dql.DB.QueryRow, *sql.Tx.Query, +// *sql.Tx.QueryRow, *sql.Stmt.Query or *sql.Stmt.QueryRow expectations. +// Returned by *Sqlmock.ExpectQuery. +type ExpectedQuery struct { + queryBasedExpectation + rows driver.Rows +} + +// WithArgs will match given expected args to actual database query arguments. +// if at least one argument does not match, it will return an error. For specific +// arguments an sqlmock.Argument interface can be used to match an argument. +func (e *ExpectedQuery) WithArgs(args ...driver.Value) *ExpectedQuery { + e.args = args + return e +} + +// WillReturnError allows to set an error for expected database query +func (e *ExpectedQuery) WillReturnError(err error) *ExpectedQuery { + e.err = err + return e +} + +// WillReturnRows specifies the set of resulting rows that will be returned +// by the triggered query +func (e *ExpectedQuery) WillReturnRows(rows driver.Rows) *ExpectedQuery { + e.rows = rows + return e +} + +// ExpectedExec is used to manage *sql.DB.Exec, *sql.Tx.Exec or *sql.Stmt.Exec expectations. +// Returned by *Sqlmock.ExpectExec. +type ExpectedExec struct { + queryBasedExpectation + result driver.Result +} + +// WithArgs will match given expected args to actual database exec operation arguments. +// if at least one argument does not match, it will return an error. For specific +// arguments an sqlmock.Argument interface can be used to match an argument. +func (e *ExpectedExec) WithArgs(args ...driver.Value) *ExpectedExec { + e.args = args + return e +} + +// WillReturnError allows to set an error for expected database exec action +func (e *ExpectedExec) WillReturnError(err error) *ExpectedExec { + e.err = err + return e +} + +// WillReturnResult arranges for an expected Exec() to return a particular +// result, there is sqlmock.NewResult(lastInsertID int64, affectedRows int64) method +// to build a corresponding result. Or if actions needs to be tested against errors +// sqlmock.NewErrorResult(err error) to return a given error. +func (e *ExpectedExec) WillReturnResult(result driver.Result) *ExpectedExec { + e.result = result + return e +} + +// ExpectedPrepare is used to manage *sql.DB.Prepare or *sql.Tx.Prepare expectations. +// Returned by *Sqlmock.ExpectPrepare. +type ExpectedPrepare struct { + commonExpectation + mock *Sqlmock + sqlRegex *regexp.Regexp + statement driver.Stmt + closeErr error +} + +// WillReturnError allows to set an error for the expected *sql.DB.Prepare or *sql.Tx.Prepare action. +func (e *ExpectedPrepare) WillReturnError(err error) *ExpectedPrepare { + e.err = err + return e +} + +// WillReturnCloseError allows to set an error for this prapared statement Close action +func (e *ExpectedPrepare) WillReturnCloseError(err error) *ExpectedPrepare { + e.closeErr = err + return e +} + +// ExpectQuery allows to expect Query() or QueryRow() on this prepared statement. +// this method is convenient in order to prevent duplicating sql query string matching. +func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery { + eq := &ExpectedQuery{} + eq.sqlRegex = e.sqlRegex + e.mock.expected = append(e.mock.expected, eq) + return eq +} + +// ExpectExec allows to expect Exec() on this prepared statement. +// this method is convenient in order to prevent duplicating sql query string matching. +func (e *ExpectedPrepare) ExpectExec() *ExpectedExec { + eq := &ExpectedExec{} + eq.sqlRegex = e.sqlRegex + e.mock.expected = append(e.mock.expected, eq) + return eq } // query based expectation @@ -88,39 +231,3 @@ func (e *queryBasedExpectation) argsMatches(args []driver.Value) bool { } return true } - -// begin transaction -type expectedBegin struct { - commonExpectation -} - -// tx commit -type expectedCommit struct { - commonExpectation -} - -// tx rollback -type expectedRollback struct { - commonExpectation -} - -// query expectation -type expectedQuery struct { - queryBasedExpectation - - rows driver.Rows -} - -// exec query expectation -type expectedExec struct { - queryBasedExpectation - - result driver.Result -} - -// Prepare expectation -type expectedPrepare struct { - commonExpectation - - statement driver.Stmt -} diff --git a/expectations_test.go b/expectations_test.go index 5bcccd3..cbb13a9 100644 --- a/expectations_test.go +++ b/expectations_test.go @@ -60,7 +60,7 @@ func TestQueryExpectationArgComparison(t *testing.T) { } func TestQueryExpectationSqlMatch(t *testing.T) { - e := &expectedExec{} + e := &ExpectedExec{} e.sqlRegex = regexp.MustCompile("SELECT x FROM") if !e.queryMatches("SELECT x FROM someting") { t.Errorf("Sql must have matched the query") diff --git a/result_test.go b/result_test.go index c594f6d..3f5de21 100644 --- a/result_test.go +++ b/result_test.go @@ -5,6 +5,20 @@ import ( "testing" ) +// used for examples +var mock = &Sqlmock{} + +func ExampleNewErrorResult() { + result := NewErrorResult(fmt.Errorf("some error")) + mock.ExpectExec("^INSERT (.+)").WillReturnResult(result) +} + +func ExampleNewResult() { + var lastInsertID, affected int64 + result := NewResult(lastInsertID, affected) + mock.ExpectExec("^INSERT (.+)").WillReturnResult(result) +} + func TestShouldReturnValidSqlDriverResult(t *testing.T) { result := NewResult(1, 2) id, err := result.LastInsertId() diff --git a/rows.go b/rows.go index 04522a2..3e58a89 100644 --- a/rows.go +++ b/rows.go @@ -10,12 +10,22 @@ import ( // Rows interface allows to construct rows // which also satisfies database/sql/driver.Rows interface type Rows interface { - driver.Rows // composed interface, supports sql driver.Rows + // composed interface, supports sql driver.Rows + driver.Rows + + // AddRow composed from database driver.Value slice + // return the same instance to perform subsequent actions. + // Note that the number of values must match the number + // of columns AddRow(...driver.Value) Rows + + // FromCSVString build rows from csv string. + // return the same instance to perform subsequent actions. + // Note that the number of values must match the number + // of columns FromCSVString(s string) Rows } -// a struct which implements database/sql/driver.Rows type rows struct { cols []string rows [][]driver.Value @@ -48,16 +58,13 @@ func (r *rows) Next(dest []driver.Value) error { return nil } -// NewRows allows Rows to be created from a group of -// sql driver.Value or from the CSV string and +// NewRows allows Rows to be created from a +// sql driver.Value slice or from the CSV string and // to be used as sql driver.Rows func NewRows(columns []string) Rows { return &rows{cols: columns} } -// AddRow adds a row which is built from arguments -// in the same column order, returns sql driver.Rows -// compatible interface func (r *rows) AddRow(values ...driver.Value) Rows { if len(values) != len(r.cols) { panic("Expected number of values to match number of columns") @@ -72,8 +79,6 @@ func (r *rows) AddRow(values ...driver.Value) Rows { return r } -// FromCSVString adds rows from CSV string. -// Returns sql driver.Rows compatible interface func (r *rows) FromCSVString(s string) Rows { res := strings.NewReader(strings.TrimSpace(s)) csvReader := csv.NewReader(res) @@ -92,29 +97,3 @@ func (r *rows) FromCSVString(s string) Rows { } return r } - -// RowsFromCSVString creates Rows from CSV string -// to be used for mocked queries. Returns sql driver Rows interface -// ** DEPRECATED ** will be removed in the future, use Rows.FromCSVString -func RowsFromCSVString(columns []string, s string) driver.Rows { - rs := &rows{} - rs.cols = columns - - r := strings.NewReader(strings.TrimSpace(s)) - csvReader := csv.NewReader(r) - - for { - r, err := csvReader.Read() - if err != nil || r == nil { - break - } - - row := make([]driver.Value, len(columns)) - for i, v := range r { - v := strings.TrimSpace(v) - row[i] = []byte(v) - } - rs.rows = append(rs.rows, row) - } - return rs -} diff --git a/sqlmock.go b/sqlmock.go index b73ab02..4a44673 100644 --- a/sqlmock.go +++ b/sqlmock.go @@ -6,190 +6,280 @@ It hooks into Go standard library's database/sql package. The package provides convenient methods to mock database queries, transactions and expect the right execution flow, compare query arguments or even return error instead to simulate failures. See the example bellow, which illustrates how convenient it is -to work with: - - - package main - - import ( - "database/sql" - "github.com/DATA-DOG/go-sqlmock" - "testing" - "fmt" - ) - - // will test that order with a different status, cannot be cancelled - func TestShouldNotCancelOrderWithNonPendingStatus(t *testing.T) { - // open database stub - db, err := sql.Open("mock", "") - if err != nil { - t.Errorf("An error '%s' was not expected when opening a stub database connection", err) - } - - // columns to be used for result - columns := []string{"id", "status"} - // expect transaction begin - sqlmock.ExpectBegin() - // expect query to fetch order, match it with regexp - sqlmock.ExpectQuery("SELECT (.+) FROM orders (.+) FOR UPDATE"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows(columns).FromCSVString("1,1")) - // expect transaction rollback, since order status is "cancelled" - sqlmock.ExpectRollback() - - // run the cancel order function - someOrderId := 1 - // call a function which executes expected database operations - err = cancelOrder(someOrderId, db) - if err != nil { - t.Errorf("Expected no error, but got %s instead", err) - } - // db.Close() ensures that all expectations have been met - if err = db.Close(); err != nil { - t.Errorf("Error '%s' was not expected while closing the database", err) - } - } - +to work with. */ package sqlmock import ( - "database/sql" "database/sql/driver" "fmt" + "reflect" "regexp" ) -var mock *mockDriver +// Sqlmock type satisfies required sql.driver interfaces +// to simulate actual database and also serves to +// create expectations for any kind of database action +// in order to mock and test real database behavior. +type Sqlmock struct { + dsn string + opened int + drv *mockDriver -// Mock interface defines a mock which is returned -// by any expectation and can be detailed further -// with the methods this interface provides -type Mock interface { - WithArgs(...driver.Value) Mock - WillReturnError(error) Mock - WillReturnRows(driver.Rows) Mock - WillReturnResult(driver.Result) Mock + expected []expectation } -type mockDriver struct { - conn *conn -} - -func (d *mockDriver) Open(dsn string) (driver.Conn, error) { - return mock.conn, nil -} - -func init() { - mock = &mockDriver{&conn{}} - sql.Register("mock", mock) -} - -// New creates sqlmock database connection -// and pings it so that all expectations could be -// asserted on Close. -func New() (db *sql.DB, err error) { - db, err = sql.Open("mock", "") - if err != nil { - return - } - // ensure open connection, otherwise Close does not assert expectations - return db, db.Ping() -} - -// ExpectBegin expects transaction to be started -func ExpectBegin() Mock { - e := &expectedBegin{} - mock.conn.expectations = append(mock.conn.expectations, e) - mock.conn.active = e - return mock.conn -} - -// ExpectCommit expects transaction to be commited -func ExpectCommit() Mock { - e := &expectedCommit{} - mock.conn.expectations = append(mock.conn.expectations, e) - mock.conn.active = e - return mock.conn -} - -// ExpectRollback expects transaction to be rolled back -func ExpectRollback() Mock { - e := &expectedRollback{} - mock.conn.expectations = append(mock.conn.expectations, e) - mock.conn.active = e - return mock.conn -} - -// ExpectPrepare expects Query to be prepared -func ExpectPrepare() Mock { - e := &expectedPrepare{} - mock.conn.expectations = append(mock.conn.expectations, e) - mock.conn.active = e - return mock.conn -} - -// WillReturnError the expectation will return an error -func (c *conn) WillReturnError(err error) Mock { - c.active.setError(err) - return c -} - -// ExpectExec expects database Exec to be triggered, which will match -// the given query string as a regular expression -func ExpectExec(sqlRegexStr string) Mock { - e := &expectedExec{} - e.sqlRegex = regexp.MustCompile(sqlRegexStr) - mock.conn.expectations = append(mock.conn.expectations, e) - mock.conn.active = e - return mock.conn -} - -// ExpectQuery database Query to be triggered, which will match -// the given query string as a regular expression -func ExpectQuery(sqlRegexStr string) Mock { - e := &expectedQuery{} - e.sqlRegex = regexp.MustCompile(sqlRegexStr) - - mock.conn.expectations = append(mock.conn.expectations, e) - mock.conn.active = e - return mock.conn -} - -// WithArgs expectation should be called with given arguments. -// Works with Exec and Query expectations -func (c *conn) WithArgs(args ...driver.Value) Mock { - eq, ok := c.active.(*expectedQuery) - if !ok { - ee, ok := c.active.(*expectedExec) - if !ok { - panic(fmt.Sprintf("arguments may be expected only with query based expectations, current is %T", c.active)) +func (c *Sqlmock) next() (e expectation) { + for _, e = range c.expected { + if !e.fulfilled() { + return } - ee.args = args - } else { - eq.args = args } - return c + return nil // all expectations were fulfilled } -// WillReturnResult expectation will return a Result. -// Works only with Exec expectations -func (c *conn) WillReturnResult(result driver.Result) Mock { - eq, ok := c.active.(*expectedExec) - if !ok { - panic(fmt.Sprintf("driver.result may be returned only by exec expectations, current is %T", c.active)) - } - eq.result = result - return c +// ExpectClose queues an expectation for this database +// action to be triggered. the *ExpectedClose allows +// to mock database response +func (c *Sqlmock) ExpectClose() *ExpectedClose { + e := &ExpectedClose{} + c.expected = append(c.expected, e) + return e } -// WillReturnRows expectation will return Rows. -// Works only with Query expectations -func (c *conn) WillReturnRows(rows driver.Rows) Mock { - eq, ok := c.active.(*expectedQuery) - if !ok { - panic(fmt.Sprintf("driver.rows may be returned only by query expectations, current is %T", c.active)) +// Close a mock database driver connection. It may or may not +// be called depending on the sircumstances, but if it is called +// there must be an *ExpectedClose expectation satisfied. +// meets http://golang.org/pkg/database/sql/driver/#Conn interface +func (c *Sqlmock) Close() error { + c.drv.Lock() + defer c.drv.Unlock() + + c.opened-- + if c.opened == 0 { + delete(c.drv.conns, c.dsn) + } + e := c.next() + if e == nil { + return fmt.Errorf("all expectations were already fulfilled, call to database Close was not expected") + } + + t, ok := e.(*ExpectedClose) + if !ok { + return fmt.Errorf("call to database Close, was not expected, next expectation is %T as %+v", e, e) + } + t.triggered = true + return t.err +} + +// ExpectationsWereMet checks whether all queued expectations +// were met in order. If any of them was not met - an error is returned. +func (c *Sqlmock) ExpectationsWereMet() error { + for _, e := range c.expected { + if !e.fulfilled() { + return fmt.Errorf("there is a remaining expectation %T which was not matched yet", e) + } + } + return nil +} + +// Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface +func (c *Sqlmock) Begin() (driver.Tx, error) { + e := c.next() + if e == nil { + return nil, fmt.Errorf("all expectations were already fulfilled, call to begin transaction was not expected") + } + + t, ok := e.(*ExpectedBegin) + if !ok { + return nil, fmt.Errorf("call to begin transaction, was not expected, next expectation is %T as %+v", e, e) + } + t.triggered = true + return c, t.err +} + +// ExpectBegin expects *sql.DB.Begin to be called. +// the *ExpectedBegin allows to mock database response +func (c *Sqlmock) ExpectBegin() *ExpectedBegin { + e := &ExpectedBegin{} + c.expected = append(c.expected, e) + return e +} + +// Exec meets http://golang.org/pkg/database/sql/driver/#Execer +func (c *Sqlmock) Exec(query string, args []driver.Value) (res driver.Result, err error) { + e := c.next() + query = stripQuery(query) + if e == nil { + return nil, fmt.Errorf("all expectations were already fulfilled, call to exec '%s' query with args %+v was not expected", query, args) + } + + t, ok := e.(*ExpectedExec) + if !ok { + return nil, fmt.Errorf("call to exec query '%s' with args %+v, was not expected, next expectation is %T as %+v", query, args, e, e) + } + + t.triggered = true + if t.err != nil { + return nil, t.err // mocked to return error + } + + if t.result == nil { + return nil, fmt.Errorf("exec query '%s' with args %+v, must return a database/sql/driver.result, but it was not set for expectation %T as %+v", query, args, t, t) + } + + defer argMatcherErrorHandler(&err) // converts panic to error in case of reflect value type mismatch + + if !t.queryMatches(query) { + return nil, fmt.Errorf("exec query '%s', does not match regex '%s'", query, t.sqlRegex.String()) + } + + if !t.argsMatches(args) { + return nil, fmt.Errorf("exec query '%s', args %+v does not match expected %+v", query, args, t.args) + } + + return t.result, err +} + +// ExpectExec expects Exec() to be called with sql query +// which match sqlRegexStr given regexp. +// the *ExpectedExec allows to mock database response +func (c *Sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec { + e := &ExpectedExec{} + e.sqlRegex = regexp.MustCompile(sqlRegexStr) + c.expected = append(c.expected, e) + return e +} + +// Prepare meets http://golang.org/pkg/database/sql/driver/#Conn interface +func (c *Sqlmock) Prepare(query string) (driver.Stmt, error) { + e := c.next() + + query = stripQuery(query) + if e == nil { + return nil, fmt.Errorf("all expectations were already fulfilled, call to Prepare '%s' query was not expected", query) + } + t, ok := e.(*ExpectedPrepare) + if !ok { + return nil, fmt.Errorf("call to Prepare stetement with query '%s', was not expected, next expectation is %T as %+v", query, e, e) + } + + t.triggered = true + if t.err != nil { + return nil, t.err // mocked to return error + } + + return &statement{c, query, t.closeErr}, nil +} + +// ExpectPrepare expects Prepare() to be called with sql query +// which match sqlRegexStr given regexp. +// the *ExpectedPrepare allows to mock database response. +// Note that you may expect Query() or Exec() on the *ExpectedPrepare +// statement to prevent repeating sqlRegexStr +func (c *Sqlmock) ExpectPrepare(sqlRegexStr string) *ExpectedPrepare { + e := &ExpectedPrepare{sqlRegex: regexp.MustCompile(sqlRegexStr), mock: c} + c.expected = append(c.expected, e) + return e +} + +// Query meets http://golang.org/pkg/database/sql/driver/#Queryer +func (c *Sqlmock) Query(query string, args []driver.Value) (rw driver.Rows, err error) { + e := c.next() + query = stripQuery(query) + if e == nil { + return nil, fmt.Errorf("all expectations were already fulfilled, call to query '%s' with args %+v was not expected", query, args) + } + + t, ok := e.(*ExpectedQuery) + if !ok { + return nil, fmt.Errorf("call to query '%s' with args %+v, was not expected, next expectation is %T as %+v", query, args, e, e) + } + + t.triggered = true + if t.err != nil { + return nil, t.err // mocked to return error + } + + if t.rows == nil { + return nil, fmt.Errorf("query '%s' with args %+v, must return a database/sql/driver.rows, but it was not set for expectation %T as %+v", query, args, t, t) + } + + defer argMatcherErrorHandler(&err) // converts panic to error in case of reflect value type mismatch + + if !t.queryMatches(query) { + return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, t.sqlRegex.String()) + } + + if !t.argsMatches(args) { + return nil, fmt.Errorf("query '%s', args %+v does not match expected %+v", query, args, t.args) + } + + return t.rows, err +} + +// ExpectQuery expects Query() or QueryRow() to be called with sql query +// which match sqlRegexStr given regexp. +// the *ExpectedQuery allows to mock database response. +func (c *Sqlmock) ExpectQuery(sqlRegexStr string) *ExpectedQuery { + e := &ExpectedQuery{} + e.sqlRegex = regexp.MustCompile(sqlRegexStr) + c.expected = append(c.expected, e) + return e +} + +// ExpectCommit expects *sql.Tx.Commit to be called. +// the *ExpectedCommit allows to mock database response +func (c *Sqlmock) ExpectCommit() *ExpectedCommit { + e := &ExpectedCommit{} + c.expected = append(c.expected, e) + return e +} + +// ExpectRollback expects *sql.Tx.Rollback to be called. +// the *ExpectedRollback allows to mock database response +func (c *Sqlmock) ExpectRollback() *ExpectedRollback { + e := &ExpectedRollback{} + c.expected = append(c.expected, e) + return e +} + +// Commit meets http://golang.org/pkg/database/sql/driver/#Tx +func (c *Sqlmock) Commit() error { + e := c.next() + if e == nil { + return fmt.Errorf("all expectations were already fulfilled, call to commit transaction was not expected") + } + + t, ok := e.(*ExpectedCommit) + if !ok { + return fmt.Errorf("call to commit transaction, was not expected, next expectation was %v", e) + } + t.triggered = true + return t.err +} + +// Rollback meets http://golang.org/pkg/database/sql/driver/#Tx +func (c *Sqlmock) Rollback() error { + e := c.next() + if e == nil { + return fmt.Errorf("all expectations were already fulfilled, call to rollback transaction was not expected") + } + + t, ok := e.(*ExpectedRollback) + if !ok { + return fmt.Errorf("call to rollback transaction, was not expected, next expectation was %v", e) + } + t.triggered = true + return t.err +} + +func argMatcherErrorHandler(errp *error) { + if e := recover(); e != nil { + if se, ok := e.(*reflect.ValueError); ok { // catch reflect error, failed type conversion + *errp = fmt.Errorf("Failed to compare query arguments: %s", se) + } else { + panic(e) // overwise panic + } } - eq.rows = rows - return c } diff --git a/sqlmock_test.go b/sqlmock_test.go index d488034..e36a225 100644 --- a/sqlmock_test.go +++ b/sqlmock_test.go @@ -7,12 +7,55 @@ import ( "time" ) +func cancelOrder(db *sql.DB, orderID int) error { + tx, _ := db.Begin() + _, _ = tx.Query("SELECT * FROM orders {0} FOR UPDATE", orderID) + _ = tx.Rollback() + return nil +} + +func Example() { + // Open new mock database + db, mock, err := New() + if err != nil { + fmt.Println("error creating mock database") + return + } + // columns to be used for result + columns := []string{"id", "status"} + // expect transaction begin + mock.ExpectBegin() + // expect query to fetch order, match it with regexp + mock.ExpectQuery("SELECT (.+) FROM orders (.+) FOR UPDATE"). + WithArgs(1). + WillReturnRows(NewRows(columns).AddRow(1, 1)) + // expect transaction rollback, since order status is "cancelled" + mock.ExpectRollback() + + // run the cancel order function + someOrderID := 1 + // call a function which executes expected database operations + err = cancelOrder(db, someOrderID) + if err != nil { + fmt.Printf("unexpected error: %s", err) + return + } + + // ensure all expectations have been met + if err = mock.ExpectationsWereMet(); err != nil { + fmt.Printf("unmet expectation error: %s", err) + } + // Output: +} + func TestIssue14EscapeSQL(t *testing.T) { - db, err := New() + t.Parallel() + db, mock, err := New() if err != nil { t.Errorf("an error '%s' was not expected when opening a stub database connection", err) } - ExpectExec("INSERT INTO mytable\\(a, b\\)"). + defer db.Close() + mock.ExpectExec("INSERT INTO mytable\\(a, b\\)"). WithArgs("A", "B"). WillReturnResult(NewResult(1, 1)) @@ -21,37 +64,40 @@ func TestIssue14EscapeSQL(t *testing.T) { t.Errorf("error '%s' was not expected, while inserting a row", err) } - err = db.Close() - if err != nil { - t.Errorf("error '%s' was not expected while closing the database", err) + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) } } // test the case when db is not triggered and expectations // are not asserted on close func TestIssue4(t *testing.T) { - db, err := New() + t.Parallel() + db, mock, err := New() if err != nil { t.Errorf("an error '%s' was not expected when opening a stub database connection", err) } - ExpectQuery("some sql query which will not be called"). + defer db.Close() + + mock.ExpectQuery("some sql query which will not be called"). WillReturnRows(NewRows([]string{"id"})) - err = db.Close() - if err == nil { - t.Errorf("Was expecting an error, since expected query was not matched") + if err := mock.ExpectationsWereMet(); err == nil { + t.Errorf("was expecting an error since query was not triggered") } } func TestMockQuery(t *testing.T) { - db, err := sql.Open("mock", "") + t.Parallel() + db, mock, err := New() if err != nil { t.Errorf("an error '%s' was not expected when opening a stub database connection", err) } + defer db.Close() rs := NewRows([]string{"id", "title"}).FromCSVString("5,hello world") - ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). + mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). WithArgs(5). WillReturnRows(rs) @@ -59,11 +105,13 @@ func TestMockQuery(t *testing.T) { if err != nil { t.Errorf("error '%s' was not expected while retrieving mock rows", err) } + defer func() { if er := rows.Close(); er != nil { t.Error("Unexpected error while trying to close rows") } }() + if !rows.Next() { t.Error("it must have had one row as result, but got empty result set instead") } @@ -84,16 +132,18 @@ func TestMockQuery(t *testing.T) { t.Errorf("expected mocked title to be 'hello world', but got '%s' instead", title) } - if err = db.Close(); err != nil { - t.Errorf("error '%s' was not expected while closing the database", err) + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) } } func TestMockQueryTypes(t *testing.T) { - db, err := sql.Open("mock", "") + t.Parallel() + db, mock, err := New() if err != nil { t.Errorf("an error '%s' was not expected when opening a stub database connection", err) } + defer db.Close() columns := []string{"id", "timestamp", "sold"} @@ -101,7 +151,7 @@ func TestMockQueryTypes(t *testing.T) { rs := NewRows(columns) rs.AddRow(5, timestamp, true) - ExpectQuery("SELECT (.+) FROM sales WHERE id = ?"). + mock.ExpectQuery("SELECT (.+) FROM sales WHERE id = ?"). WithArgs(5). WillReturnRows(rs) @@ -139,20 +189,22 @@ func TestMockQueryTypes(t *testing.T) { t.Errorf("expected mocked boolean to be true, but got %v instead", sold) } - if err = db.Close(); err != nil { - t.Errorf("error '%s' was not expected while closing the database", err) + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) } } func TestTransactionExpectations(t *testing.T) { - db, err := sql.Open("mock", "") + t.Parallel() + db, mock, err := New() if err != nil { t.Errorf("an error '%s' was not expected when opening a stub database connection", err) } + defer db.Close() // begin and commit - ExpectBegin() - ExpectCommit() + mock.ExpectBegin() + mock.ExpectCommit() tx, err := db.Begin() if err != nil { @@ -165,8 +217,8 @@ func TestTransactionExpectations(t *testing.T) { } // begin and rollback - ExpectBegin() - ExpectRollback() + mock.ExpectBegin() + mock.ExpectRollback() tx, err = db.Begin() if err != nil { @@ -179,25 +231,28 @@ func TestTransactionExpectations(t *testing.T) { } // begin with an error - ExpectBegin().WillReturnError(fmt.Errorf("some err")) + mock.ExpectBegin().WillReturnError(fmt.Errorf("some err")) tx, err = db.Begin() if err == nil { t.Error("an error was expected when beginning a transaction, but got none") } - if err = db.Close(); err != nil { - t.Errorf("error '%s' was not expected while closing the database", err) + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) } } func TestPrepareExpectations(t *testing.T) { - db, err := sql.Open("mock", "") + t.Parallel() + db, mock, err := New() if err != nil { t.Errorf("an error '%s' was not expected when opening a stub database connection", err) } + defer db.Close() + + mock.ExpectPrepare("SELECT (.+) FROM articles WHERE id = ?") - // no expectations, w/o ExpectPrepare() stmt, err := db.Prepare("SELECT (.+) FROM articles WHERE id = ?") if err != nil { t.Errorf("error '%s' was not expected while creating a prepared statement", err) @@ -211,36 +266,19 @@ func TestPrepareExpectations(t *testing.T) { var title string rs := NewRows([]string{"id", "title"}).FromCSVString("5,hello world") - ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). + mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). WithArgs(5). WillReturnRows(rs) - stmt, err = db.Prepare("SELECT (.+) FROM articles WHERE id = ?") - if err != nil { - t.Errorf("error '%s' was not expected while creating a prepared statement", err) - } - if stmt == nil { - t.Errorf("stmt was expected while creating a prepared statement") - } - err = stmt.QueryRow(5).Scan(&id, &title) if err != nil { t.Errorf("error '%s' was not expected while retrieving mock rows", err) } - // expect normal result - ExpectPrepare() - stmt, err = db.Prepare("SELECT (.+) FROM articles WHERE id = ?") - if err != nil { - t.Errorf("error '%s' was not expected while creating a prepared statement", err) - } - if stmt == nil { - t.Errorf("stmt was expected while creating a prepared statement") - } + mock.ExpectPrepare("SELECT (.+) FROM articles WHERE id = ?"). + WillReturnError(fmt.Errorf("Some DB error occurred")) - // expect error result - ExpectPrepare().WillReturnError(fmt.Errorf("Some DB error occurred")) - stmt, err = db.Prepare("SELECT (.+) FROM articles WHERE id = ?") + stmt, err = db.Prepare("SELECT id FROM articles WHERE id = ?") if err == nil { t.Error("error was expected while creating a prepared statement") } @@ -248,35 +286,38 @@ func TestPrepareExpectations(t *testing.T) { t.Errorf("stmt was not expected while creating a prepared statement returning error") } - if err = db.Close(); err != nil { - t.Errorf("error '%s' was not expected while closing the database", err) + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) } } func TestPreparedQueryExecutions(t *testing.T) { - db, err := sql.Open("mock", "") + t.Parallel() + db, mock, err := New() if err != nil { t.Errorf("an error '%s' was not expected when opening a stub database connection", err) } + defer db.Close() + + mock.ExpectPrepare("SELECT (.+) FROM articles WHERE id = ?") rs1 := NewRows([]string{"id", "title"}).FromCSVString("5,hello world") - ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). + mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). WithArgs(5). WillReturnRows(rs1) rs2 := NewRows([]string{"id", "title"}).FromCSVString("2,whoop") - ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). + mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). WithArgs(2). WillReturnRows(rs2) - stmt, err := db.Prepare("SELECT (.+) FROM articles WHERE id = ?") + stmt, err := db.Prepare("SELECT id, title FROM articles WHERE id = ?") if err != nil { t.Errorf("error '%s' was not expected while creating a prepared statement", err) } var id int var title string - err = stmt.QueryRow(5).Scan(&id, &title) if err != nil { t.Errorf("error '%s' was not expected querying row from statement and scanning", err) @@ -303,18 +344,21 @@ func TestPreparedQueryExecutions(t *testing.T) { t.Errorf("expected mocked title to be 'whoop', but got '%s' instead", title) } - if err = db.Close(); err != nil { - t.Errorf("error '%s' was not expected while closing the database", err) + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) } } func TestUnexpectedOperations(t *testing.T) { - db, err := sql.Open("mock", "") + t.Parallel() + db, mock, err := New() if err != nil { t.Errorf("an error '%s' was not expected when opening a stub database connection", err) } + defer db.Close() - stmt, err := db.Prepare("SELECT (.+) FROM articles WHERE id = ?") + mock.ExpectPrepare("SELECT (.+) FROM articles WHERE id = ?") + stmt, err := db.Prepare("SELECT id, title FROM articles WHERE id = ?") if err != nil { t.Errorf("error '%s' was not expected while creating a prepared statement", err) } @@ -327,39 +371,35 @@ func TestUnexpectedOperations(t *testing.T) { t.Error("error was expected querying row, since there was no such expectation") } - ExpectRollback() + mock.ExpectRollback() - err = db.Close() - if err == nil { - t.Error("error was expected while closing the database, expectation was not fulfilled", err) + if err := mock.ExpectationsWereMet(); err == nil { + t.Errorf("was expecting an error since query was not triggered") } } func TestWrongExpectations(t *testing.T) { - db, err := sql.Open("mock", "") + t.Parallel() + db, mock, err := New() if err != nil { t.Errorf("an error '%s' was not expected when opening a stub database connection", err) } + defer db.Close() - ExpectBegin() + mock.ExpectBegin() rs1 := NewRows([]string{"id", "title"}).FromCSVString("5,hello world") - ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). + mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). WithArgs(5). WillReturnRows(rs1) - ExpectCommit().WillReturnError(fmt.Errorf("deadlock occured")) - ExpectRollback() // won't be triggered - - stmt, err := db.Prepare("SELECT (.+) FROM articles WHERE id = ? FOR UPDATE") - if err != nil { - t.Errorf("error '%s' was not expected while creating a prepared statement", err) - } + mock.ExpectCommit().WillReturnError(fmt.Errorf("deadlock occured")) + mock.ExpectRollback() // won't be triggered var id int var title string - err = stmt.QueryRow(5).Scan(&id, &title) + err = db.QueryRow("SELECT id, title FROM articles WHERE id = ? FOR UPDATE", 5).Scan(&id, &title) if err == nil { t.Error("error was expected while querying row, since there begin transaction expectation is not fulfilled") } @@ -370,7 +410,7 @@ func TestWrongExpectations(t *testing.T) { t.Errorf("an error '%s' was not expected when beginning a transaction", err) } - err = stmt.QueryRow(5).Scan(&id, &title) + err = db.QueryRow("SELECT id, title FROM articles WHERE id = ? FOR UPDATE", 5).Scan(&id, &title) if err != nil { t.Errorf("error '%s' was not expected while querying row, since transaction was started", err) } @@ -380,20 +420,21 @@ func TestWrongExpectations(t *testing.T) { t.Error("a deadlock error was expected when commiting a transaction", err) } - err = db.Close() - if err == nil { - t.Error("error was expected while closing the database, expectation was not fulfilled", err) + if err := mock.ExpectationsWereMet(); err == nil { + t.Errorf("was expecting an error since query was not triggered") } } func TestExecExpectations(t *testing.T) { - db, err := sql.Open("mock", "") + t.Parallel() + db, mock, err := New() if err != nil { t.Errorf("an error '%s' was not expected when opening a stub database connection", err) } + defer db.Close() result := NewResult(1, 1) - ExpectExec("^INSERT INTO articles"). + mock.ExpectExec("^INSERT INTO articles"). WithArgs("hello"). WillReturnResult(result) @@ -420,22 +461,24 @@ func TestExecExpectations(t *testing.T) { t.Errorf("expected affected rows to be 1, but got %d instead", affected) } - if err = db.Close(); err != nil { - t.Errorf("error '%s' was not expected while closing the database", err) + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) } } func TestRowBuilderAndNilTypes(t *testing.T) { - db, err := sql.Open("mock", "") + t.Parallel() + db, mock, err := New() if err != nil { t.Errorf("an error '%s' was not expected when opening a stub database connection", err) } + defer db.Close() rs := NewRows([]string{"id", "active", "created", "status"}). AddRow(1, true, time.Now(), 5). AddRow(2, false, nil, nil) - ExpectQuery("SELECT (.+) FROM sales").WillReturnRows(rs) + mock.ExpectQuery("SELECT (.+) FROM sales").WillReturnRows(rs) rows, err := db.Query("SELECT * FROM sales") if err != nil { @@ -510,20 +553,22 @@ func TestRowBuilderAndNilTypes(t *testing.T) { t.Errorf("expected 'status' to be invalid, but it %+v is not", status) } - if err = db.Close(); err != nil { - t.Errorf("error '%s' was not expected while closing the database", err) + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) } } func TestArgumentReflectValueTypeError(t *testing.T) { - db, err := sql.Open("mock", "") + t.Parallel() + db, mock, err := New() if err != nil { t.Errorf("an error '%s' was not expected when opening a stub database connection", err) } + defer db.Close() rs := NewRows([]string{"id"}).AddRow(1) - ExpectQuery("SELECT (.+) FROM sales").WithArgs(5.5).WillReturnRows(rs) + mock.ExpectQuery("SELECT (.+) FROM sales").WithArgs(5.5).WillReturnRows(rs) _, err = db.Query("SELECT * FROM sales WHERE x = ?", 5) if err == nil { diff --git a/statement.go b/statement.go index 1e4af08..f84b094 100644 --- a/statement.go +++ b/statement.go @@ -5,12 +5,13 @@ import ( ) type statement struct { - conn *conn + conn *Sqlmock query string + err error } func (stmt *statement) Close() error { - return nil + return stmt.err } func (stmt *statement) NumInput() int { diff --git a/transaction.go b/transaction.go deleted file mode 100644 index be59a6b..0000000 --- a/transaction.go +++ /dev/null @@ -1,37 +0,0 @@ -package sqlmock - -import ( - "fmt" -) - -type transaction struct { - conn *conn -} - -func (tx *transaction) Commit() error { - e := tx.conn.next() - if e == nil { - return fmt.Errorf("all expectations were already fulfilled, call to commit transaction was not expected") - } - - etc, ok := e.(*expectedCommit) - if !ok { - return fmt.Errorf("call to commit transaction, was not expected, next expectation was %v", e) - } - etc.triggered = true - return etc.err -} - -func (tx *transaction) Rollback() error { - e := tx.conn.next() - if e == nil { - return fmt.Errorf("all expectations were already fulfilled, call to rollback transaction was not expected") - } - - etr, ok := e.(*expectedRollback) - if !ok { - return fmt.Errorf("call to rollback transaction, was not expected, next expectation was %v", e) - } - etr.triggered = true - return etr.err -} diff --git a/util.go b/util.go index 070e8b4..072e380 100644 --- a/util.go +++ b/util.go @@ -5,11 +5,7 @@ import ( "strings" ) -var re *regexp.Regexp - -func init() { - re = regexp.MustCompile("\\s+") -} +var re = regexp.MustCompile("\\s+") // strip out new lines and trim spaces func stripQuery(q string) (s string) {