1
0
mirror of https://github.com/DATA-DOG/go-sqlmock.git synced 2025-07-17 01:22:23 +02:00

Merge pull request #21 from DATA-DOG/refactor

concurrency support, closes #20 and closes #9 and closes #15
This commit is contained in:
Gediminas Morkevicius
2015-08-28 10:37:42 +03:00
25 changed files with 2026 additions and 1213 deletions

4
.gitignore vendored
View File

@ -1 +1,3 @@
/*.test /examples/blog/blog
/examples/orders/orders
/examples/basic/basic

View File

@ -1,4 +1,5 @@
language: go language: go
sudo: false
go: go:
- 1.2 - 1.2
- 1.3 - 1.3
@ -7,10 +8,5 @@ go:
- tip - tip
script: script:
- go get github.com/kisielk/errcheck
- go get ./...
- go test -v ./... - go test -v ./...
- go test -race ./... - go test -race ./...
- errcheck github.com/DATA-DOG/go-sqlmock

View File

@ -1,6 +1,6 @@
The three clause BSD license (http://en.wikipedia.org/wiki/BSD_licenses) The three clause BSD license (http://en.wikipedia.org/wiki/BSD_licenses)
Copyright (c) 2013, DataDog.lt team Copyright (c) 2013-2015, DataDog.lt team
All rights reserved. All rights reserved.
Redistribution and use in source and binary forms, with or without Redistribution and use in source and binary forms, with or without

399
README.md
View File

@ -5,336 +5,153 @@
This is a **mock** driver as **database/sql/driver** which is very flexible and pragmatic to This is a **mock** driver as **database/sql/driver** which is very flexible and pragmatic to
manage and mock expected queries. All the expectations should be met and all queries and actions manage and mock expected queries. All the expectations should be met and all queries and actions
triggered should be mocked in order to pass a test. triggered should be mocked in order to pass a test. The package has no 3rd party dependencies.
**NOTE:** regarding major issues #20 and #9 the **api** has changed to support concurrency and more than
one database connection.
If you need an old version, checkout **go-sqlmock** at gopkg.in:
go get gopkg.in/DATA-DOG/go-sqlmock.v0
Otherwise use the **v1** branch from master which should be stable afterwards, because all the issues which
were known will be fixed in this version.
## Install ## Install
go get github.com/DATA-DOG/go-sqlmock go get gopkg.in/DATA-DOG/go-sqlmock.v1
## Use it with pleasure Or take an older version:
An example of some database interaction which you may want to test: go get gopkg.in/DATA-DOG/go-sqlmock.v0
## Documentation and Examples
Visit [godoc](http://godoc.org/github.com/DATA-DOG/go-sqlmock) for general examples and public api reference.
See **.travis.yml** for supported **go** versions.
Different use case, is to functionally test with a real database - [go-txdb](https://github.com/DATA-DOG/go-txdb)
all database related actions are isolated within a single transaction so the database can remain in the same state.
See implementation examples:
- [blog API server](https://github.com/DATA-DOG/go-sqlmock/tree/master/examples/blog)
- [the same orders example](https://github.com/DATA-DOG/go-sqlmock/tree/master/examples/orders)
### Something you may want to test
``` go ``` go
package main package main
import ( import "database/sql"
"database/sql"
_ "github.com/go-sql-driver/mysql"
"github.com/kisielk/sqlstruct"
"fmt"
"log"
)
const ORDER_PENDING = 0 func recordStats(db *sql.DB, userID, productID int64) (err error) {
const ORDER_CANCELLED = 1 tx, err := db.Begin()
if err != nil {
return
}
type User struct { defer func() {
Id int `sql:"id"` switch err {
Username string `sql:"username"` case nil:
Balance float64 `sql:"balance"` err = tx.Commit()
} default:
tx.Rollback()
}
}()
type Order struct { if _, err = tx.Exec("UPDATE products SET views = views + 1"); err != nil {
Id int `sql:"id"` return
Value float64 `sql:"value"` }
ReservedFee float64 `sql:"reserved_fee"` if _, err = tx.Exec("INSERT INTO product_viewers (user_id, product_id) VALUES (?, ?)", userID, productID); err != nil {
Status int `sql:"status"` return
} }
return
func cancelOrder(id int, db *sql.DB) (err error) {
tx, err := db.Begin()
if err != nil {
return
}
var order Order
var user User
sql := fmt.Sprintf(`
SELECT %s, %s
FROM orders AS o
INNER JOIN users AS u ON o.buyer_id = u.id
WHERE o.id = ?
FOR UPDATE`,
sqlstruct.ColumnsAliased(order, "o"),
sqlstruct.ColumnsAliased(user, "u"))
// fetch order to cancel
rows, err := tx.Query(sql, id)
if err != nil {
tx.Rollback()
return
}
defer rows.Close()
// no rows, nothing to do
if !rows.Next() {
tx.Rollback()
return
}
// read order
err = sqlstruct.ScanAliased(&order, rows, "o")
if err != nil {
tx.Rollback()
return
}
// ensure order status
if order.Status != ORDER_PENDING {
tx.Rollback()
return
}
// read user
err = sqlstruct.ScanAliased(&user, rows, "u")
if err != nil {
tx.Rollback()
return
}
rows.Close() // manually close before other prepared statements
// refund order value
sql = "UPDATE users SET balance = balance + ? WHERE id = ?"
refundStmt, err := tx.Prepare(sql)
if err != nil {
tx.Rollback()
return
}
defer refundStmt.Close()
_, err = refundStmt.Exec(order.Value + order.ReservedFee, user.Id)
if err != nil {
tx.Rollback()
return
}
// update order status
order.Status = ORDER_CANCELLED
sql = "UPDATE orders SET status = ?, updated = NOW() WHERE id = ?"
orderUpdStmt, err := tx.Prepare(sql)
if err != nil {
tx.Rollback()
return
}
defer orderUpdStmt.Close()
_, err = orderUpdStmt.Exec(order.Status, order.Id)
if err != nil {
tx.Rollback()
return
}
return tx.Commit()
} }
func main() { func main() {
db, err := sql.Open("mysql", "root:nimda@/test") // @NOTE: the real connection is not required for tests
if err != nil { db, err := sql.Open("mysql", "root@/blog")
log.Fatal(err) if err != nil {
} panic(err)
defer db.Close() }
err = cancelOrder(1, db) defer db.Close()
if err != nil {
log.Fatal(err) if err = recordStats(db, 1 /*some user id*/, 5 /*some product id*/); err != nil {
} panic(err)
}
} }
``` ```
And the clean nice test: ### Tests with sqlmock
``` go ``` go
package main package main
import ( import (
"database/sql" "fmt"
"github.com/DATA-DOG/go-sqlmock" "testing"
"testing"
"fmt" "github.com/DATA-DOG/go-sqlmock"
) )
// will test that order with a different status, cannot be cancelled // a successful case
func TestShouldNotCancelOrderWithNonPendingStatus(t *testing.T) { func TestShouldUpdateStats(t *testing.T) {
// open database stub db, mock, err := sqlmock.New()
db, err := sqlmock.New() if err != nil {
if err != nil { t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
t.Errorf("An error '%s' was not expected when opening a stub database connection", err) }
} defer db.Close()
// columns are prefixed with "o" since we used sqlstruct to generate them mock.ExpectBegin()
columns := []string{"o_id", "o_status"} mock.ExpectExec("UPDATE products").WillReturnResult(sqlmock.NewResult(1, 1))
// expect transaction begin mock.ExpectExec("INSERT INTO product_viewers").WithArgs(2, 3).WillReturnResult(sqlmock.NewResult(1, 1))
sqlmock.ExpectBegin() mock.ExpectCommit()
// expect query to fetch order and user, match it with regexp
sqlmock.ExpectQuery("SELECT (.+) FROM orders AS o INNER JOIN users AS u (.+) FOR UPDATE").
WithArgs(1).
WillReturnRows(sqlmock.NewRows(columns).FromCSVString("1,1"))
// expect transaction rollback, since order status is "cancelled"
sqlmock.ExpectRollback()
// run the cancel order function // now we execute our method
err = cancelOrder(1, db) if err = recordStats(db, 2, 3); err != nil {
if err != nil { t.Errorf("error was not expected while updating stats: %s", err)
t.Errorf("Expected no error, but got %s instead", err) }
}
// db.Close() ensures that all expectations have been met // we make sure that all expectations were met
if err = db.Close(); err != nil { if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("Error '%s' was not expected while closing the database", err) t.Errorf("there were unfulfilled expections: %s", err)
} }
} }
// will test order cancellation // a failing test case
func TestShouldRefundUserWhenOrderIsCancelled(t *testing.T) { func TestShouldRollbackStatUpdatesOnFailure(t *testing.T) {
// open database stub db, mock, err := sqlmock.New()
db, err := sqlmock.New() if err != nil {
if err != nil { t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
t.Errorf("An error '%s' was not expected when opening a stub database connection", err) }
} defer db.Close()
// columns are prefixed with "o" since we used sqlstruct to generate them mock.ExpectBegin()
columns := []string{"o_id", "o_status", "o_value", "o_reserved_fee", "u_id", "u_balance"} mock.ExpectExec("UPDATE products").WillReturnResult(sqlmock.NewResult(1, 1))
// expect transaction begin mock.ExpectExec("INSERT INTO product_viewers").
sqlmock.ExpectBegin() WithArgs(2, 3).
// expect query to fetch order and user, match it with regexp WillReturnError(fmt.Errorf("some error"))
sqlmock.ExpectQuery("SELECT (.+) FROM orders AS o INNER JOIN users AS u (.+) FOR UPDATE"). mock.ExpectRollback()
WithArgs(1).
WillReturnRows(sqlmock.NewRows(columns).AddRow(1, 0, 25.75, 3.25, 2, 10.00))
// expect user balance update
sqlmock.ExpectExec("UPDATE users SET balance").
WithArgs(25.75 + 3.25, 2). // refund amount, user id
WillReturnResult(sqlmock.NewResult(0, 1)) // no insert id, 1 affected row
// expect order status update
sqlmock.ExpectExec("UPDATE orders SET status").
WithArgs(ORDER_CANCELLED, 1). // status, id
WillReturnResult(sqlmock.NewResult(0, 1)) // no insert id, 1 affected row
// expect a transaction commit
sqlmock.ExpectCommit()
// run the cancel order function // now we execute our method
err = cancelOrder(1, db) if err = recordStats(db, 2, 3); err == nil {
if err != nil { t.Errorf("was expecting an error, but there was none")
t.Errorf("Expected no error, but got %s instead", err) }
}
// db.Close() ensures that all expectations have been met // we make sure that all expectations were met
if err = db.Close(); err != nil { if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("Error '%s' was not expected while closing the database", err) t.Errorf("there were unfulfilled expections: %s", err)
} }
} }
// will test order cancellation
func TestShouldRollbackOnError(t *testing.T) {
// open database stub
db, err := sqlmock.New()
if err != nil {
t.Errorf("An error '%s' was not expected when opening a stub database connection", err)
}
// expect transaction begin
sqlmock.ExpectBegin()
// expect query to fetch order and user, match it with regexp
sqlmock.ExpectQuery("SELECT (.+) FROM orders AS o INNER JOIN users AS u (.+) FOR UPDATE").
WithArgs(1).
WillReturnError(fmt.Errorf("Some error"))
// should rollback since error was returned from query execution
sqlmock.ExpectRollback()
// run the cancel order function
err = cancelOrder(1, db)
// error should return back
if err == nil {
t.Error("Expected error, but got none")
}
// db.Close() ensures that all expectations have been met
if err = db.Close(); err != nil {
t.Errorf("Error '%s' was not expected while closing the database", err)
}
}
```
## Expectations
All **Expect** methods return a **Mock** interface which allow you to describe
expectations in more details: return an error, expect specific arguments, return rows and so on.
**NOTE:** that if you call **WithArgs** on a non query based expectation, it will panic
A **Mock** interface:
``` go
type Mock interface {
WithArgs(...driver.Value) Mock
WillReturnError(error) Mock
WillReturnRows(driver.Rows) Mock
WillReturnResult(driver.Result) Mock
}
```
As an example we can expect a transaction commit and simulate an error for it:
``` go
sqlmock.ExpectCommit().WillReturnError(fmt.Errorf("Deadlock occured"))
```
In same fashion, we can expect queries to match arguments. If there are any, it must be matched.
Instead of result we can return error.
``` go
sqlmock.ExpectQuery("SELECT (.*) FROM orders").
WithArgs("string value").
WillReturnRows(sqlmock.NewRows([]string{"col"}).AddRow("val"))
```
**NOTE:** it matches a regular expression. Some regex special characters must be escaped if you want to match them.
For example if we want to match a subselect:
``` go
sqlmock.ExpectQuery("SELECT (.*) FROM orders WHERE id IN \\(SELECT id FROM finished WHERE status = 1\\)").
WithArgs("string value").
WillReturnRows(sqlmock.NewRows([]string{"col"}).AddRow("val"))
```
**WithArgs** expectation, compares values based on their type, for usual values like **string, float, int**
it matches the actual value. Types like **time** are compared only by type. Other types might require different ways
to compare them correctly, this may be improved.
You can build rows either from CSV string or from interface values:
**Rows** interface, which satisfies sql driver.Rows:
``` go
type Rows interface {
AddRow(...driver.Value) Rows
FromCSVString(s string) Rows
Next([]driver.Value) error
Columns() []string
Close() error
}
```
Example for to build rows:
``` go
rs := sqlmock.NewRows([]string{"column1", "column2"}).
FromCSVString("one,1\ntwo,2").
AddRow("three", 3)
```
**Prepare** will ignore other expectations if ExpectPrepare not set. When set, can expect normal result or simulate an error:
``` go
rs := sqlmock.ExpectPrepare().
WillReturnError(fmt.Errorf("Query prepare failed"))
``` ```
## Run tests ## Run tests
go test go test -race
## Documentation
Visit [godoc](http://godoc.org/github.com/DATA-DOG/go-sqlmock)
See **.travis.yml** for supported **go** versions
Different use case, is to functionally test with a real database - [go-txdb](https://github.com/DATA-DOG/go-txdb)
all database related actions are isolated within a single transaction so the database can remain in the same state.
## Changes ## Changes
- **2015-08-27** - **v1** api change, concurrency support, all known issues fixed.
- **2014-08-16** instead of **panic** during reflect type mismatch when comparing query arguments - now return error - **2014-08-16** instead of **panic** during reflect type mismatch when comparing query arguments - now return error
- **2014-08-14** added **sqlmock.NewErrorResult** which gives an option to return driver.Result with errors for - **2014-08-14** added **sqlmock.NewErrorResult** which gives an option to return driver.Result with errors for
interface methods, see [issue](https://github.com/DATA-DOG/go-sqlmock/issues/5) interface methods, see [issue](https://github.com/DATA-DOG/go-sqlmock/issues/5)

View File

@ -1,151 +0,0 @@
package sqlmock
import (
"database/sql/driver"
"fmt"
"reflect"
)
type conn struct {
expectations []expectation
active expectation
}
// Close a mock database driver connection. It should
// be always called to ensure that all expectations
// were met successfully. Returns error if there is any
func (c *conn) Close() (err error) {
for _, e := range mock.conn.expectations {
if !e.fulfilled() {
err = fmt.Errorf("there is a remaining expectation %T which was not matched yet", e)
break
}
}
mock.conn.expectations = []expectation{}
mock.conn.active = nil
return err
}
func (c *conn) Begin() (driver.Tx, error) {
e := c.next()
if e == nil {
return nil, fmt.Errorf("all expectations were already fulfilled, call to begin transaction was not expected")
}
etb, ok := e.(*expectedBegin)
if !ok {
return nil, fmt.Errorf("call to begin transaction, was not expected, next expectation is %T as %+v", e, e)
}
etb.triggered = true
return &transaction{c}, etb.err
}
// get next unfulfilled expectation
func (c *conn) next() (e expectation) {
for _, e = range c.expectations {
if !e.fulfilled() {
return
}
}
return nil // all expectations were fulfilled
}
func (c *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) {
e := c.next()
query = stripQuery(query)
if e == nil {
return nil, fmt.Errorf("all expectations were already fulfilled, call to exec '%s' query with args %+v was not expected", query, args)
}
eq, ok := e.(*expectedExec)
if !ok {
return nil, fmt.Errorf("call to exec query '%s' with args %+v, was not expected, next expectation is %T as %+v", query, args, e, e)
}
eq.triggered = true
defer argMatcherErrorHandler(&err) // converts panic to error in case of reflect value type mismatch
if !eq.queryMatches(query) {
return nil, fmt.Errorf("exec query '%s', does not match regex '%s'", query, eq.sqlRegex.String())
}
if !eq.argsMatches(args) {
return nil, fmt.Errorf("exec query '%s', args %+v does not match expected %+v", query, args, eq.args)
}
if eq.err != nil {
return nil, eq.err // mocked to return error
}
if eq.result == nil {
return nil, fmt.Errorf("exec query '%s' with args %+v, must return a database/sql/driver.result, but it was not set for expectation %T as %+v", query, args, eq, eq)
}
return eq.result, err
}
func (c *conn) Prepare(query string) (driver.Stmt, error) {
e := c.next()
// for backwards compatibility, ignore when Prepare not expected
if e == nil {
return &statement{mock.conn, stripQuery(query)}, nil
}
eq, ok := e.(*expectedPrepare)
if !ok {
return &statement{mock.conn, stripQuery(query)}, nil
}
eq.triggered = true
if eq.err != nil {
return nil, eq.err // mocked to return error
}
return &statement{mock.conn, stripQuery(query)}, nil
}
func (c *conn) Query(query string, args []driver.Value) (rw driver.Rows, err error) {
e := c.next()
query = stripQuery(query)
if e == nil {
return nil, fmt.Errorf("all expectations were already fulfilled, call to query '%s' with args %+v was not expected", query, args)
}
eq, ok := e.(*expectedQuery)
if !ok {
return nil, fmt.Errorf("call to query '%s' with args %+v, was not expected, next expectation is %T as %+v", query, args, e, e)
}
eq.triggered = true
defer argMatcherErrorHandler(&err) // converts panic to error in case of reflect value type mismatch
if !eq.queryMatches(query) {
return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, eq.sqlRegex.String())
}
if !eq.argsMatches(args) {
return nil, fmt.Errorf("query '%s', args %+v does not match expected %+v", query, args, eq.args)
}
if eq.err != nil {
return nil, eq.err // mocked to return error
}
if eq.rows == nil {
return nil, fmt.Errorf("query '%s' with args %+v, must return a database/sql/driver.rows, but it was not set for expectation %T as %+v", query, args, eq, eq)
}
return eq.rows, err
}
func argMatcherErrorHandler(errp *error) {
if e := recover(); e != nil {
if se, ok := e.(*reflect.ValueError); ok { // catch reflect error, failed type conversion
*errp = fmt.Errorf("Failed to compare query arguments: %s", se)
} else {
panic(e) // overwise panic
}
}
}

View File

@ -1,378 +0,0 @@
package sqlmock
import (
"database/sql/driver"
"errors"
"regexp"
"testing"
)
func TestExecNoExpectations(t *testing.T) {
c := &conn{
expectations: []expectation{
&expectedExec{
queryBasedExpectation: queryBasedExpectation{
commonExpectation: commonExpectation{
triggered: true,
err: errors.New("WillReturnError"),
},
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("otherquery")),
args: []driver.Value{456},
},
},
},
}
res, err := c.Exec("query", []driver.Value{123})
if res != nil {
t.Error("Result should be nil")
}
if err == nil {
t.Error("error should not be nil")
}
pattern := regexp.MustCompile(regexp.QuoteMeta("all expectations were already fulfilled, call to exec"))
if !pattern.MatchString(err.Error()) {
t.Errorf("error should match expected error message (actual: %s)", err.Error())
}
}
func TestExecExpectationMismatch(t *testing.T) {
c := &conn{
expectations: []expectation{
&expectedQuery{
queryBasedExpectation: queryBasedExpectation{
commonExpectation: commonExpectation{
err: errors.New("WillReturnError"),
},
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("otherquery")),
args: []driver.Value{456},
},
},
},
}
res, err := c.Exec("query", []driver.Value{123})
if res != nil {
t.Error("Result should be nil")
}
if err == nil {
t.Error("error should not be nil")
}
pattern := regexp.MustCompile(regexp.QuoteMeta("was not expected, next expectation is"))
if !pattern.MatchString(err.Error()) {
t.Errorf("error should match expected error message (actual: %s)", err.Error())
}
}
func TestExecQueryMismatch(t *testing.T) {
c := &conn{
expectations: []expectation{
&expectedExec{
queryBasedExpectation: queryBasedExpectation{
commonExpectation: commonExpectation{
err: errors.New("WillReturnError"),
},
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("otherquery")),
args: []driver.Value{456},
},
},
},
}
res, err := c.Exec("query", []driver.Value{123})
if res != nil {
t.Error("Result should be nil")
}
if err == nil {
t.Error("error should not be nil")
}
pattern := regexp.MustCompile(regexp.QuoteMeta("does not match regex"))
if !pattern.MatchString(err.Error()) {
t.Errorf("error should match expected error message (actual: %s)", err.Error())
}
}
func TestExecArgsMismatch(t *testing.T) {
c := &conn{
expectations: []expectation{
&expectedExec{
queryBasedExpectation: queryBasedExpectation{
commonExpectation: commonExpectation{
err: errors.New("WillReturnError"),
},
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("query")),
args: []driver.Value{456},
},
},
},
}
res, err := c.Exec("query", []driver.Value{123})
if res != nil {
t.Error("Result should be nil")
}
if err == nil {
t.Error("error should not be nil")
}
pattern := regexp.MustCompile(regexp.QuoteMeta("does not match expected"))
if !pattern.MatchString(err.Error()) {
t.Errorf("error should match expected error message (actual: %s)", err.Error())
}
}
func TestExecWillReturnError(t *testing.T) {
c := &conn{
expectations: []expectation{
&expectedExec{
queryBasedExpectation: queryBasedExpectation{
commonExpectation: commonExpectation{
err: errors.New("WillReturnError"),
},
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("query")),
},
},
},
}
res, err := c.Exec("query", []driver.Value{123})
if res != nil {
t.Error("Result should be nil")
}
if err == nil {
t.Error("error should not be nil")
}
if err.Error() != "WillReturnError" {
t.Errorf("error should match expected error message (actual: %s)", err.Error())
}
}
func TestExecMissingResult(t *testing.T) {
c := &conn{
expectations: []expectation{
&expectedExec{
queryBasedExpectation: queryBasedExpectation{
commonExpectation: commonExpectation{},
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("query")),
args: []driver.Value{123},
},
},
},
}
res, err := c.Exec("query", []driver.Value{123})
if res != nil {
t.Error("Result should be nil")
}
if err == nil {
t.Error("error should not be nil")
}
pattern := regexp.MustCompile(regexp.QuoteMeta("must return a database/sql/driver.result, but it was not set for expectation"))
if !pattern.MatchString(err.Error()) {
t.Errorf("error should match expected error message (actual: %s)", err.Error())
}
}
func TestExec(t *testing.T) {
expectedResult := driver.Result(&result{})
c := &conn{
expectations: []expectation{
&expectedExec{
queryBasedExpectation: queryBasedExpectation{
commonExpectation: commonExpectation{},
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("query")),
args: []driver.Value{123},
},
result: expectedResult,
},
},
}
res, err := c.Exec("query", []driver.Value{123})
if res == nil {
t.Error("Result should not be nil")
}
if res != expectedResult {
t.Errorf("Result should match expected Result (actual %+v)", res)
}
if err != nil {
t.Errorf("error should be nil (actual %s)", err.Error())
}
}
func TestQueryNoExpectations(t *testing.T) {
c := &conn{
expectations: []expectation{
&expectedQuery{
queryBasedExpectation: queryBasedExpectation{
commonExpectation: commonExpectation{
triggered: true,
err: errors.New("WillReturnError"),
},
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("otherquery")),
args: []driver.Value{456},
},
},
},
}
res, err := c.Query("query", []driver.Value{123})
if res != nil {
t.Error("Rows should be nil")
}
if err == nil {
t.Error("error should not be nil")
}
pattern := regexp.MustCompile(regexp.QuoteMeta("all expectations were already fulfilled, call to query"))
if !pattern.MatchString(err.Error()) {
t.Errorf("error should match expected error message (actual: %s)", err.Error())
}
}
func TestQueryExpectationMismatch(t *testing.T) {
c := &conn{
expectations: []expectation{
&expectedExec{
queryBasedExpectation: queryBasedExpectation{
commonExpectation: commonExpectation{
err: errors.New("WillReturnError"),
},
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("otherquery")),
args: []driver.Value{456},
},
},
},
}
res, err := c.Query("query", []driver.Value{123})
if res != nil {
t.Error("Rows should be nil")
}
if err == nil {
t.Error("error should not be nil")
}
pattern := regexp.MustCompile(regexp.QuoteMeta("was not expected, next expectation is"))
if !pattern.MatchString(err.Error()) {
t.Errorf("error should match expected error message (actual: %s)", err.Error())
}
}
func TestQueryQueryMismatch(t *testing.T) {
c := &conn{
expectations: []expectation{
&expectedQuery{
queryBasedExpectation: queryBasedExpectation{
commonExpectation: commonExpectation{
err: errors.New("WillReturnError"),
},
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("otherquery")),
args: []driver.Value{456},
},
},
},
}
res, err := c.Query("query", []driver.Value{123})
if res != nil {
t.Error("Rows should be nil")
}
if err == nil {
t.Error("error should not be nil")
}
pattern := regexp.MustCompile(regexp.QuoteMeta("does not match regex"))
if !pattern.MatchString(err.Error()) {
t.Errorf("error should match expected error message (actual: %s)", err.Error())
}
}
func TestQueryArgsMismatch(t *testing.T) {
c := &conn{
expectations: []expectation{
&expectedQuery{
queryBasedExpectation: queryBasedExpectation{
commonExpectation: commonExpectation{
err: errors.New("WillReturnError"),
},
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("query")),
args: []driver.Value{456},
},
},
},
}
res, err := c.Query("query", []driver.Value{123})
if res != nil {
t.Error("Rows should be nil")
}
if err == nil {
t.Error("error should not be nil")
}
pattern := regexp.MustCompile(regexp.QuoteMeta("does not match expected"))
if !pattern.MatchString(err.Error()) {
t.Errorf("error should match expected error message (actual: %s)", err.Error())
}
}
func TestQueryWillReturnError(t *testing.T) {
c := &conn{
expectations: []expectation{
&expectedQuery{
queryBasedExpectation: queryBasedExpectation{
commonExpectation: commonExpectation{
err: errors.New("WillReturnError"),
},
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("query")),
},
},
},
}
res, err := c.Query("query", []driver.Value{123})
if res != nil {
t.Error("Rows should be nil")
}
if err == nil {
t.Error("error should not be nil")
}
if err.Error() != "WillReturnError" {
t.Errorf("error should match expected error message (actual: %s)", err.Error())
}
}
func TestQueryMissingRows(t *testing.T) {
c := &conn{
expectations: []expectation{
&expectedQuery{
queryBasedExpectation: queryBasedExpectation{
commonExpectation: commonExpectation{},
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("query")),
args: []driver.Value{123},
},
},
},
}
res, err := c.Query("query", []driver.Value{123})
if res != nil {
t.Error("Rows should be nil")
}
if err == nil {
t.Error("error should not be nil")
}
pattern := regexp.MustCompile(regexp.QuoteMeta("must return a database/sql/driver.rows, but it was not set for expectation"))
if !pattern.MatchString(err.Error()) {
t.Errorf("error should match expected error message (actual: %s)", err.Error())
}
}
func TestQuery(t *testing.T) {
expectedRows := driver.Rows(&rows{})
c := &conn{
expectations: []expectation{
&expectedQuery{
queryBasedExpectation: queryBasedExpectation{
commonExpectation: commonExpectation{},
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("query")),
args: []driver.Value{123},
},
rows: expectedRows,
},
},
}
rows, err := c.Query("query", []driver.Value{123})
if rows == nil {
t.Error("Rows should not be nil")
}
if rows != expectedRows {
t.Errorf("Rows should match expected Rows (actual %+v)", rows)
}
if err != nil {
t.Errorf("error should be nil (actual %s)", err.Error())
}
}

56
driver.go Normal file
View File

@ -0,0 +1,56 @@
package sqlmock
import (
"database/sql"
"database/sql/driver"
"fmt"
"sync"
)
var pool *mockDriver
func init() {
pool = &mockDriver{
conns: make(map[string]*Sqlmock),
}
sql.Register("sqlmock", pool)
}
type mockDriver struct {
sync.Mutex
counter int
conns map[string]*Sqlmock
}
func (d *mockDriver) Open(dsn string) (driver.Conn, error) {
d.Lock()
defer d.Unlock()
c, ok := d.conns[dsn]
if !ok {
return c, fmt.Errorf("expected a connection to be available, but it is not")
}
c.opened++
return c, nil
}
// New creates sqlmock database connection
// and a mock to manage expectations.
// Pings db so that all expectations could be
// asserted.
func New() (db *sql.DB, mock *Sqlmock, err error) {
pool.Lock()
dsn := fmt.Sprintf("sqlmock_db_%d", pool.counter)
pool.counter++
mock = &Sqlmock{dsn: dsn, drv: pool, MatchExpectationsInOrder: true}
pool.conns[dsn] = mock
pool.Unlock()
db, err = sql.Open("sqlmock", dsn)
if err != nil {
return
}
return db, mock, db.Ping()
}

83
driver_test.go Normal file
View File

@ -0,0 +1,83 @@
package sqlmock
import (
"fmt"
"testing"
)
func ExampleNew() {
db, mock, err := New()
if err != nil {
fmt.Println("expected no error, but got:", err)
return
}
defer db.Close()
// now we can expect operations performed on db
mock.ExpectBegin().WillReturnError(fmt.Errorf("an error will occur on db.Begin() call"))
}
func TestShouldOpenConnectionIssue15(t *testing.T) {
db, mock, err := New()
if err != nil {
t.Errorf("expected no error, but got: %s", err)
}
if len(pool.conns) != 1 {
t.Errorf("expected 1 connection in pool, but there is: %d", len(pool.conns))
}
if mock.opened != 1 {
t.Errorf("expected 1 connection on mock to be opened, but there is: %d", mock.opened)
}
// defer so the rows gets closed first
defer func() {
if mock.opened != 0 {
t.Errorf("expected no connections on mock to be opened, but there is: %d", mock.opened)
}
}()
mock.ExpectQuery("SELECT").WillReturnRows(NewRows([]string{"one", "two"}).AddRow("val1", "val2"))
rows, err := db.Query("SELECT")
if err != nil {
t.Errorf("unexpected error: %s", err)
}
defer rows.Close()
mock.ExpectExec("UPDATE").WillReturnResult(NewResult(1, 1))
if _, err = db.Exec("UPDATE"); err != nil {
t.Errorf("unexpected error: %s", err)
}
// now there should be two connections open
if mock.opened != 2 {
t.Errorf("expected 2 connection on mock to be opened, but there is: %d", mock.opened)
}
mock.ExpectClose()
if err = db.Close(); err != nil {
t.Errorf("expected no error on close, but got: %s", err)
}
// one is still reserved for rows
if mock.opened != 1 {
t.Errorf("expected 1 connection on mock to be still reserved for rows, but there is: %d", mock.opened)
}
}
func TestTwoOpenConnectionsOnTheSameDSN(t *testing.T) {
db, mock, err := New()
if err != nil {
t.Errorf("expected no error, but got: %s", err)
}
db2, mock2, err := New()
if len(pool.conns) != 2 {
t.Errorf("expected 2 connection in pool, but there is: %d", len(pool.conns))
}
if db == db2 {
t.Errorf("expected not the same database instance, but it is the same")
}
if mock == mock2 {
t.Errorf("expected not the same mock instance, but it is the same")
}
}

40
examples/basic/basic.go Normal file
View File

@ -0,0 +1,40 @@
package main
import "database/sql"
func recordStats(db *sql.DB, userID, productID int64) (err error) {
tx, err := db.Begin()
if err != nil {
return
}
defer func() {
switch err {
case nil:
err = tx.Commit()
default:
tx.Rollback()
}
}()
if _, err = tx.Exec("UPDATE products SET views = views + 1"); err != nil {
return
}
if _, err = tx.Exec("INSERT INTO product_viewers (user_id, product_id) VALUES (?, ?)", userID, productID); err != nil {
return
}
return
}
func main() {
// @NOTE: the real connection is not required for tests
db, err := sql.Open("mysql", "root@/blog")
if err != nil {
panic(err)
}
defer db.Close()
if err = recordStats(db, 1 /*some user id*/, 5 /*some product id*/); err != nil {
panic(err)
}
}

View File

@ -0,0 +1,58 @@
package main
import (
"fmt"
"testing"
"github.com/DATA-DOG/go-sqlmock"
)
// a successful case
func TestShouldUpdateStats(t *testing.T) {
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
mock.ExpectBegin()
mock.ExpectExec("UPDATE products").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("INSERT INTO product_viewers").WithArgs(2, 3).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
// now we execute our method
if err = recordStats(db, 2, 3); err != nil {
t.Errorf("error was not expected while updating stats: %s", err)
}
// we make sure that all expectations were met
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}
// a failing test case
func TestShouldRollbackStatUpdatesOnFailure(t *testing.T) {
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
mock.ExpectBegin()
mock.ExpectExec("UPDATE products").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("INSERT INTO product_viewers").
WithArgs(2, 3).
WillReturnError(fmt.Errorf("some error"))
mock.ExpectRollback()
// now we execute our method
if err = recordStats(db, 2, 3); err == nil {
t.Errorf("was expecting an error, but there was none")
}
// we make sure that all expectations were met
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}

81
examples/blog/blog.go Normal file
View File

@ -0,0 +1,81 @@
package main
import (
"database/sql"
"encoding/json"
"net/http"
)
type api struct {
db *sql.DB
}
type post struct {
ID int
Title string
Body string
}
func (a *api) posts(w http.ResponseWriter, r *http.Request) {
rows, err := a.db.Query("SELECT id, title, body FROM posts")
if err != nil {
a.fail(w, "failed to fetch posts: "+err.Error(), 500)
return
}
defer rows.Close()
var posts []*post
for rows.Next() {
p := &post{}
if err := rows.Scan(&p.ID, &p.Title, &p.Body); err != nil {
a.fail(w, "failed to scan post: "+err.Error(), 500)
return
}
posts = append(posts, p)
}
if rows.Err() != nil {
a.fail(w, "failed to read all posts: "+rows.Err().Error(), 500)
return
}
data := struct {
Posts []*post
}{posts}
a.ok(w, data)
}
func main() {
// @NOTE: the real connection is not required for tests
db, err := sql.Open("mysql", "root@/blog")
if err != nil {
panic(err)
}
app := &api{db: db}
http.HandleFunc("/posts", app.posts)
http.ListenAndServe(":8080", nil)
}
func (a *api) fail(w http.ResponseWriter, msg string, status int) {
w.Header().Set("Content-Type", "application/json")
data := struct {
Error string
}{Error: msg}
resp, _ := json.Marshal(data)
w.WriteHeader(status)
w.Write(resp)
}
func (a *api) ok(w http.ResponseWriter, data interface{}) {
w.Header().Set("Content-Type", "application/json")
resp, err := json.Marshal(data)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
a.fail(w, "oops something evil has happened", 500)
return
}
w.Write(resp)
}

102
examples/blog/blog_test.go Normal file
View File

@ -0,0 +1,102 @@
package main
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/DATA-DOG/go-sqlmock"
)
func (a *api) assertJSON(actual []byte, data interface{}, t *testing.T) {
expected, err := json.Marshal(data)
if err != nil {
t.Fatalf("an error '%s' was not expected when marshaling expected json data", err)
}
if bytes.Compare(expected, actual) != 0 {
t.Errorf("the expected json: %s is different from actual %s", expected, actual)
}
}
func TestShouldGetPosts(t *testing.T) {
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
// create app with mocked db, request and response to test
app := &api{db}
req, err := http.NewRequest("GET", "http://localhost/posts", nil)
if err != nil {
t.Fatalf("an error '%s' was not expected while creating request", err)
}
w := httptest.NewRecorder()
// before we actually execute our api function, we need to expect required DB actions
rows := sqlmock.NewRows([]string{"id", "title", "body"}).
AddRow(1, "post 1", "hello").
AddRow(2, "post 2", "world")
mock.ExpectQuery("^SELECT (.+) FROM posts$").WillReturnRows(rows)
// now we execute our request
app.posts(w, req)
if w.Code != 200 {
t.Fatalf("expected status code to be 200, but got: %d", w.Code)
}
data := struct {
Posts []*post
}{Posts: []*post{
{ID: 1, Title: "post 1", Body: "hello"},
{ID: 2, Title: "post 2", Body: "world"},
}}
app.assertJSON(w.Body.Bytes(), data, t)
// we make sure that all expectations were met
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}
func TestShouldRespondWithErrorOnFailure(t *testing.T) {
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
// create app with mocked db, request and response to test
app := &api{db}
req, err := http.NewRequest("GET", "http://localhost/posts", nil)
if err != nil {
t.Fatalf("an error '%s' was not expected while creating request", err)
}
w := httptest.NewRecorder()
// before we actually execute our api function, we need to expect required DB actions
mock.ExpectQuery("^SELECT (.+) FROM posts$").WillReturnError(fmt.Errorf("some error"))
// now we execute our request
app.posts(w, req)
if w.Code != 500 {
t.Fatalf("expected status code to be 500, but got: %d", w.Code)
}
data := struct {
Error string
}{"failed to fetch posts: some error"}
app.assertJSON(w.Body.Bytes(), data, t)
// we make sure that all expectations were met
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}

1
examples/doc.go Normal file
View File

@ -0,0 +1 @@
package examples

121
examples/orders/orders.go Normal file
View File

@ -0,0 +1,121 @@
package main
import (
"database/sql"
"fmt"
"log"
"github.com/kisielk/sqlstruct"
)
const ORDER_PENDING = 0
const ORDER_CANCELLED = 1
type User struct {
Id int `sql:"id"`
Username string `sql:"username"`
Balance float64 `sql:"balance"`
}
type Order struct {
Id int `sql:"id"`
Value float64 `sql:"value"`
ReservedFee float64 `sql:"reserved_fee"`
Status int `sql:"status"`
}
func cancelOrder(id int, db *sql.DB) (err error) {
tx, err := db.Begin()
if err != nil {
return
}
var order Order
var user User
sql := fmt.Sprintf(`
SELECT %s, %s
FROM orders AS o
INNER JOIN users AS u ON o.buyer_id = u.id
WHERE o.id = ?
FOR UPDATE`,
sqlstruct.ColumnsAliased(order, "o"),
sqlstruct.ColumnsAliased(user, "u"))
// fetch order to cancel
rows, err := tx.Query(sql, id)
if err != nil {
tx.Rollback()
return
}
defer rows.Close()
// no rows, nothing to do
if !rows.Next() {
tx.Rollback()
return
}
// read order
err = sqlstruct.ScanAliased(&order, rows, "o")
if err != nil {
tx.Rollback()
return
}
// ensure order status
if order.Status != ORDER_PENDING {
tx.Rollback()
return
}
// read user
err = sqlstruct.ScanAliased(&user, rows, "u")
if err != nil {
tx.Rollback()
return
}
rows.Close() // manually close before other prepared statements
// refund order value
sql = "UPDATE users SET balance = balance + ? WHERE id = ?"
refundStmt, err := tx.Prepare(sql)
if err != nil {
tx.Rollback()
return
}
defer refundStmt.Close()
_, err = refundStmt.Exec(order.Value+order.ReservedFee, user.Id)
if err != nil {
tx.Rollback()
return
}
// update order status
order.Status = ORDER_CANCELLED
sql = "UPDATE orders SET status = ?, updated = NOW() WHERE id = ?"
orderUpdStmt, err := tx.Prepare(sql)
if err != nil {
tx.Rollback()
return
}
defer orderUpdStmt.Close()
_, err = orderUpdStmt.Exec(order.Status, order.Id)
if err != nil {
tx.Rollback()
return
}
return tx.Commit()
}
func main() {
// @NOTE: the real connection is not required for tests
db, err := sql.Open("mysql", "root:@/orders")
if err != nil {
log.Fatal(err)
}
defer db.Close()
err = cancelOrder(1, db)
if err != nil {
log.Fatal(err)
}
}

View File

@ -0,0 +1,108 @@
package main
import (
"fmt"
"testing"
"github.com/DATA-DOG/go-sqlmock"
)
// will test that order with a different status, cannot be cancelled
func TestShouldNotCancelOrderWithNonPendingStatus(t *testing.T) {
// open database stub
db, mock, err := sqlmock.New()
if err != nil {
t.Errorf("An error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
// columns are prefixed with "o" since we used sqlstruct to generate them
columns := []string{"o_id", "o_status"}
// expect transaction begin
mock.ExpectBegin()
// expect query to fetch order and user, match it with regexp
mock.ExpectQuery("SELECT (.+) FROM orders AS o INNER JOIN users AS u (.+) FOR UPDATE").
WithArgs(1).
WillReturnRows(sqlmock.NewRows(columns).FromCSVString("1,1"))
// expect transaction rollback, since order status is "cancelled"
mock.ExpectRollback()
// run the cancel order function
err = cancelOrder(1, db)
if err != nil {
t.Errorf("Expected no error, but got %s instead", err)
}
// we make sure that all expectations were met
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}
// will test order cancellation
func TestShouldRefundUserWhenOrderIsCancelled(t *testing.T) {
// open database stub
db, mock, err := sqlmock.New()
if err != nil {
t.Errorf("An error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
// columns are prefixed with "o" since we used sqlstruct to generate them
columns := []string{"o_id", "o_status", "o_value", "o_reserved_fee", "u_id", "u_balance"}
// expect transaction begin
mock.ExpectBegin()
// expect query to fetch order and user, match it with regexp
mock.ExpectQuery("SELECT (.+) FROM orders AS o INNER JOIN users AS u (.+) FOR UPDATE").
WithArgs(1).
WillReturnRows(sqlmock.NewRows(columns).AddRow(1, 0, 25.75, 3.25, 2, 10.00))
// expect user balance update
mock.ExpectPrepare("UPDATE users SET balance").ExpectExec().
WithArgs(25.75+3.25, 2). // refund amount, user id
WillReturnResult(sqlmock.NewResult(0, 1)) // no insert id, 1 affected row
// expect order status update
mock.ExpectPrepare("UPDATE orders SET status").ExpectExec().
WithArgs(ORDER_CANCELLED, 1). // status, id
WillReturnResult(sqlmock.NewResult(0, 1)) // no insert id, 1 affected row
// expect a transaction commit
mock.ExpectCommit()
// run the cancel order function
err = cancelOrder(1, db)
if err != nil {
t.Errorf("Expected no error, but got %s instead", err)
}
// we make sure that all expectations were met
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}
// will test order cancellation
func TestShouldRollbackOnError(t *testing.T) {
// open database stub
db, mock, err := sqlmock.New()
if err != nil {
t.Errorf("An error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
// expect transaction begin
mock.ExpectBegin()
// expect query to fetch order and user, match it with regexp
mock.ExpectQuery("SELECT (.+) FROM orders AS o INNER JOIN users AS u (.+) FOR UPDATE").
WithArgs(1).
WillReturnError(fmt.Errorf("Some error"))
// should rollback since error was returned from query execution
mock.ExpectRollback()
// run the cancel order function
err = cancelOrder(1, db)
// error should return back
if err == nil {
t.Error("Expected error, but got none")
}
// we make sure that all expectations were met
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}

View File

@ -2,12 +2,16 @@ package sqlmock
import ( import (
"database/sql/driver" "database/sql/driver"
"fmt"
"reflect" "reflect"
"regexp" "regexp"
"strings"
"sync"
) )
// Argument interface allows to match // Argument interface allows to match
// any argument in specific way // any argument in specific way when used with
// ExpectedQuery and ExpectedExec expectations.
type Argument interface { type Argument interface {
Match(driver.Value) bool Match(driver.Value) bool
} }
@ -15,12 +19,15 @@ type Argument interface {
// an expectation interface // an expectation interface
type expectation interface { type expectation interface {
fulfilled() bool fulfilled() bool
setError(err error) Lock()
Unlock()
String() string
} }
// common expectation struct // common expectation struct
// satisfies the expectation interface // satisfies the expectation interface
type commonExpectation struct { type commonExpectation struct {
sync.Mutex
triggered bool triggered bool
err error err error
} }
@ -29,8 +36,267 @@ func (e *commonExpectation) fulfilled() bool {
return e.triggered return e.triggered
} }
func (e *commonExpectation) setError(err error) { // ExpectedClose is used to manage *sql.DB.Close expectation
// returned by *Sqlmock.ExpectClose.
type ExpectedClose struct {
commonExpectation
}
// WillReturnError allows to set an error for *sql.DB.Close action
func (e *ExpectedClose) WillReturnError(err error) *ExpectedClose {
e.err = err e.err = err
return e
}
// String returns string representation
func (e *ExpectedClose) String() string {
msg := "ExpectedClose => expecting database Close"
if e.err != nil {
msg += fmt.Sprintf(", which should return error: %s", e.err)
}
return msg
}
// ExpectedBegin is used to manage *sql.DB.Begin expectation
// returned by *Sqlmock.ExpectBegin.
type ExpectedBegin struct {
commonExpectation
}
// WillReturnError allows to set an error for *sql.DB.Begin action
func (e *ExpectedBegin) WillReturnError(err error) *ExpectedBegin {
e.err = err
return e
}
// String returns string representation
func (e *ExpectedBegin) String() string {
msg := "ExpectedBegin => expecting database transaction Begin"
if e.err != nil {
msg += fmt.Sprintf(", which should return error: %s", e.err)
}
return msg
}
// ExpectedCommit is used to manage *sql.Tx.Commit expectation
// returned by *Sqlmock.ExpectCommit.
type ExpectedCommit struct {
commonExpectation
}
// WillReturnError allows to set an error for *sql.Tx.Close action
func (e *ExpectedCommit) WillReturnError(err error) *ExpectedCommit {
e.err = err
return e
}
// String returns string representation
func (e *ExpectedCommit) String() string {
msg := "ExpectedCommit => expecting transaction Commit"
if e.err != nil {
msg += fmt.Sprintf(", which should return error: %s", e.err)
}
return msg
}
// ExpectedRollback is used to manage *sql.Tx.Rollback expectation
// returned by *Sqlmock.ExpectRollback.
type ExpectedRollback struct {
commonExpectation
}
// WillReturnError allows to set an error for *sql.Tx.Rollback action
func (e *ExpectedRollback) WillReturnError(err error) *ExpectedRollback {
e.err = err
return e
}
// String returns string representation
func (e *ExpectedRollback) String() string {
msg := "ExpectedRollback => expecting transaction Rollback"
if e.err != nil {
msg += fmt.Sprintf(", which should return error: %s", e.err)
}
return msg
}
// ExpectedQuery is used to manage *sql.DB.Query, *dql.DB.QueryRow, *sql.Tx.Query,
// *sql.Tx.QueryRow, *sql.Stmt.Query or *sql.Stmt.QueryRow expectations.
// Returned by *Sqlmock.ExpectQuery.
type ExpectedQuery struct {
queryBasedExpectation
rows driver.Rows
}
// WithArgs will match given expected args to actual database query arguments.
// if at least one argument does not match, it will return an error. For specific
// arguments an sqlmock.Argument interface can be used to match an argument.
func (e *ExpectedQuery) WithArgs(args ...driver.Value) *ExpectedQuery {
e.args = args
return e
}
// WillReturnError allows to set an error for expected database query
func (e *ExpectedQuery) WillReturnError(err error) *ExpectedQuery {
e.err = err
return e
}
// WillReturnRows specifies the set of resulting rows that will be returned
// by the triggered query
func (e *ExpectedQuery) WillReturnRows(rows driver.Rows) *ExpectedQuery {
e.rows = rows
return e
}
// String returns string representation
func (e *ExpectedQuery) String() string {
msg := "ExpectedQuery => expecting Query or QueryRow which:"
msg += "\n - matches sql: '" + e.sqlRegex.String() + "'"
if len(e.args) == 0 {
msg += "\n - is without arguments"
} else {
msg += "\n - is with arguments:\n"
for i, arg := range e.args {
msg += fmt.Sprintf(" %d - %+v\n", i, arg)
}
msg = strings.TrimSpace(msg)
}
if e.rows != nil {
msg += "\n - should return rows:\n"
rs, _ := e.rows.(*rows)
for i, row := range rs.rows {
msg += fmt.Sprintf(" %d - %+v\n", i, row)
}
msg = strings.TrimSpace(msg)
}
if e.err != nil {
msg += fmt.Sprintf("\n - should return error: %s", e.err)
}
return msg
}
// ExpectedExec is used to manage *sql.DB.Exec, *sql.Tx.Exec or *sql.Stmt.Exec expectations.
// Returned by *Sqlmock.ExpectExec.
type ExpectedExec struct {
queryBasedExpectation
result driver.Result
}
// WithArgs will match given expected args to actual database exec operation arguments.
// if at least one argument does not match, it will return an error. For specific
// arguments an sqlmock.Argument interface can be used to match an argument.
func (e *ExpectedExec) WithArgs(args ...driver.Value) *ExpectedExec {
e.args = args
return e
}
// WillReturnError allows to set an error for expected database exec action
func (e *ExpectedExec) WillReturnError(err error) *ExpectedExec {
e.err = err
return e
}
// String returns string representation
func (e *ExpectedExec) String() string {
msg := "ExpectedExec => expecting Exec which:"
msg += "\n - matches sql: '" + e.sqlRegex.String() + "'"
if len(e.args) == 0 {
msg += "\n - is without arguments"
} else {
msg += "\n - is with arguments:\n"
var margs []string
for i, arg := range e.args {
margs = append(margs, fmt.Sprintf(" %d - %+v", i, arg))
}
msg += strings.Join(margs, "\n")
}
if e.result != nil {
res, _ := e.result.(*result)
msg += "\n - should return Result having:"
msg += fmt.Sprintf("\n LastInsertId: %d", res.insertID)
msg += fmt.Sprintf("\n RowsAffected: %d", res.rowsAffected)
if res.err != nil {
msg += fmt.Sprintf("\n Error: %s", res.err)
}
}
if e.err != nil {
msg += fmt.Sprintf("\n - should return error: %s", e.err)
}
return msg
}
// WillReturnResult arranges for an expected Exec() to return a particular
// result, there is sqlmock.NewResult(lastInsertID int64, affectedRows int64) method
// to build a corresponding result. Or if actions needs to be tested against errors
// sqlmock.NewErrorResult(err error) to return a given error.
func (e *ExpectedExec) WillReturnResult(result driver.Result) *ExpectedExec {
e.result = result
return e
}
// ExpectedPrepare is used to manage *sql.DB.Prepare or *sql.Tx.Prepare expectations.
// Returned by *Sqlmock.ExpectPrepare.
type ExpectedPrepare struct {
commonExpectation
mock *Sqlmock
sqlRegex *regexp.Regexp
statement driver.Stmt
closeErr error
}
// WillReturnError allows to set an error for the expected *sql.DB.Prepare or *sql.Tx.Prepare action.
func (e *ExpectedPrepare) WillReturnError(err error) *ExpectedPrepare {
e.err = err
return e
}
// WillReturnCloseError allows to set an error for this prapared statement Close action
func (e *ExpectedPrepare) WillReturnCloseError(err error) *ExpectedPrepare {
e.closeErr = err
return e
}
// ExpectQuery allows to expect Query() or QueryRow() on this prepared statement.
// this method is convenient in order to prevent duplicating sql query string matching.
func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery {
eq := &ExpectedQuery{}
eq.sqlRegex = e.sqlRegex
e.mock.expected = append(e.mock.expected, eq)
return eq
}
// ExpectExec allows to expect Exec() on this prepared statement.
// this method is convenient in order to prevent duplicating sql query string matching.
func (e *ExpectedPrepare) ExpectExec() *ExpectedExec {
eq := &ExpectedExec{}
eq.sqlRegex = e.sqlRegex
e.mock.expected = append(e.mock.expected, eq)
return eq
}
// String returns string representation
func (e *ExpectedPrepare) String() string {
msg := "ExpectedPrepare => expecting Prepare statement which:"
msg += "\n - matches sql: '" + e.sqlRegex.String() + "'"
if e.err != nil {
msg += fmt.Sprintf("\n - should return error: %s", e.err)
}
if e.closeErr != nil {
msg += fmt.Sprintf("\n - should return error on Close: %s", e.closeErr)
}
return msg
} }
// query based expectation // query based expectation
@ -41,6 +307,19 @@ type queryBasedExpectation struct {
args []driver.Value args []driver.Value
} }
func (e *queryBasedExpectation) attemptMatch(sql string, args []driver.Value) (ret bool) {
if !e.queryMatches(sql) {
return
}
defer recover() // ignore panic since we attempt a match
if e.argsMatches(args) {
return true
}
return
}
func (e *queryBasedExpectation) queryMatches(sql string) bool { func (e *queryBasedExpectation) queryMatches(sql string) bool {
return e.sqlRegex.MatchString(sql) return e.sqlRegex.MatchString(sql)
} }
@ -88,39 +367,3 @@ func (e *queryBasedExpectation) argsMatches(args []driver.Value) bool {
} }
return true return true
} }
// begin transaction
type expectedBegin struct {
commonExpectation
}
// tx commit
type expectedCommit struct {
commonExpectation
}
// tx rollback
type expectedRollback struct {
commonExpectation
}
// query expectation
type expectedQuery struct {
queryBasedExpectation
rows driver.Rows
}
// exec query expectation
type expectedExec struct {
queryBasedExpectation
result driver.Result
}
// Prepare expectation
type expectedPrepare struct {
commonExpectation
statement driver.Stmt
}

View File

@ -2,6 +2,7 @@ package sqlmock
import ( import (
"database/sql/driver" "database/sql/driver"
"fmt"
"regexp" "regexp"
"testing" "testing"
"time" "time"
@ -60,7 +61,7 @@ func TestQueryExpectationArgComparison(t *testing.T) {
} }
func TestQueryExpectationSqlMatch(t *testing.T) { func TestQueryExpectationSqlMatch(t *testing.T) {
e := &expectedExec{} e := &ExpectedExec{}
e.sqlRegex = regexp.MustCompile("SELECT x FROM") e.sqlRegex = regexp.MustCompile("SELECT x FROM")
if !e.queryMatches("SELECT x FROM someting") { if !e.queryMatches("SELECT x FROM someting") {
t.Errorf("Sql must have matched the query") t.Errorf("Sql must have matched the query")
@ -71,3 +72,13 @@ func TestQueryExpectationSqlMatch(t *testing.T) {
t.Errorf("Sql must have matched the query") t.Errorf("Sql must have matched the query")
} }
} }
func ExampleExpectExec() {
db, mock, _ := New()
result := NewErrorResult(fmt.Errorf("some error"))
mock.ExpectExec("^INSERT (.+)").WillReturnResult(result)
res, _ := db.Exec("INSERT something")
_, err := res.LastInsertId()
fmt.Println(err)
// Output: some error
}

View File

@ -5,6 +5,32 @@ import (
"testing" "testing"
) )
// used for examples
var mock = &Sqlmock{}
func ExampleNewErrorResult() {
db, mock, _ := New()
result := NewErrorResult(fmt.Errorf("some error"))
mock.ExpectExec("^INSERT (.+)").WillReturnResult(result)
res, _ := db.Exec("INSERT something")
_, err := res.LastInsertId()
fmt.Println(err)
// Output: some error
}
func ExampleNewResult() {
var lastInsertID, affected int64
result := NewResult(lastInsertID, affected)
mock.ExpectExec("^INSERT (.+)").WillReturnResult(result)
fmt.Println(mock.ExpectationsWereMet())
// Output: there is a remaining expectation which was not matched: ExpectedExec => expecting Exec which:
// - matches sql: '^INSERT (.+)'
// - is without arguments
// - should return Result having:
// LastInsertId: 0
// RowsAffected: 0
}
func TestShouldReturnValidSqlDriverResult(t *testing.T) { func TestShouldReturnValidSqlDriverResult(t *testing.T) {
result := NewResult(1, 2) result := NewResult(1, 2)
id, err := result.LastInsertId() id, err := result.LastInsertId()

106
rows.go
View File

@ -7,19 +7,56 @@ import (
"strings" "strings"
) )
// CSVColumnParser is a function which converts trimmed csv
// column string to a []byte representation. currently
// transforms NULL to nil
var CSVColumnParser = func(s string) []byte {
switch {
case strings.ToLower(s) == "null":
return nil
}
return []byte(s)
}
// Rows interface allows to construct rows // Rows interface allows to construct rows
// which also satisfies database/sql/driver.Rows interface // which also satisfies database/sql/driver.Rows interface
type Rows interface { type Rows interface {
driver.Rows // composed interface, supports sql driver.Rows // composed interface, supports sql driver.Rows
AddRow(...driver.Value) Rows driver.Rows
// AddRow composed from database driver.Value slice
// return the same instance to perform subsequent actions.
// Note that the number of values must match the number
// of columns
AddRow(columns ...driver.Value) Rows
// FromCSVString build rows from csv string.
// return the same instance to perform subsequent actions.
// Note that the number of values must match the number
// of columns
FromCSVString(s string) Rows FromCSVString(s string) Rows
// RowError allows to set an error
// which will be returned when a given
// row number is read
RowError(row int, err error) Rows
// CloseError allows to set an error
// which will be returned by rows.Close
// function.
//
// The close error will be triggered only in cases
// when rows.Next() EOF was not yet reached, that is
// a default sql library behavior
CloseError(err error) Rows
} }
// a struct which implements database/sql/driver.Rows
type rows struct { type rows struct {
cols []string cols []string
rows [][]driver.Value rows [][]driver.Value
pos int pos int
nextErr map[int]error
closeErr error
} }
func (r *rows) Columns() []string { func (r *rows) Columns() []string {
@ -27,11 +64,7 @@ func (r *rows) Columns() []string {
} }
func (r *rows) Close() error { func (r *rows) Close() error {
return nil return r.closeErr
}
func (r *rows) Err() error {
return nil
} }
// advances to next row // advances to next row
@ -45,19 +78,26 @@ func (r *rows) Next(dest []driver.Value) error {
dest[i] = col dest[i] = col
} }
return nil return r.nextErr[r.pos-1]
} }
// NewRows allows Rows to be created from a group of // NewRows allows Rows to be created from a
// sql driver.Value or from the CSV string and // sql driver.Value slice or from the CSV string and
// to be used as sql driver.Rows // to be used as sql driver.Rows
func NewRows(columns []string) Rows { func NewRows(columns []string) Rows {
return &rows{cols: columns} return &rows{cols: columns, nextErr: make(map[int]error)}
}
func (r *rows) CloseError(err error) Rows {
r.closeErr = err
return r
}
func (r *rows) RowError(row int, err error) Rows {
r.nextErr[row] = err
return r
} }
// AddRow adds a row which is built from arguments
// in the same column order, returns sql driver.Rows
// compatible interface
func (r *rows) AddRow(values ...driver.Value) Rows { func (r *rows) AddRow(values ...driver.Value) Rows {
if len(values) != len(r.cols) { if len(values) != len(r.cols) {
panic("Expected number of values to match number of columns") panic("Expected number of values to match number of columns")
@ -72,8 +112,6 @@ func (r *rows) AddRow(values ...driver.Value) Rows {
return r return r
} }
// FromCSVString adds rows from CSV string.
// Returns sql driver.Rows compatible interface
func (r *rows) FromCSVString(s string) Rows { func (r *rows) FromCSVString(s string) Rows {
res := strings.NewReader(strings.TrimSpace(s)) res := strings.NewReader(strings.TrimSpace(s))
csvReader := csv.NewReader(res) csvReader := csv.NewReader(res)
@ -86,35 +124,9 @@ func (r *rows) FromCSVString(s string) Rows {
row := make([]driver.Value, len(r.cols)) row := make([]driver.Value, len(r.cols))
for i, v := range res { for i, v := range res {
row[i] = []byte(strings.TrimSpace(v)) row[i] = CSVColumnParser(strings.TrimSpace(v))
} }
r.rows = append(r.rows, row) r.rows = append(r.rows, row)
} }
return r return r
} }
// RowsFromCSVString creates Rows from CSV string
// to be used for mocked queries. Returns sql driver Rows interface
// ** DEPRECATED ** will be removed in the future, use Rows.FromCSVString
func RowsFromCSVString(columns []string, s string) driver.Rows {
rs := &rows{}
rs.cols = columns
r := strings.NewReader(strings.TrimSpace(s))
csvReader := csv.NewReader(r)
for {
r, err := csvReader.Read()
if err != nil || r == nil {
break
}
row := make([]driver.Value, len(columns))
for i, v := range r {
v := strings.TrimSpace(v)
row[i] = []byte(v)
}
rs.rows = append(rs.rows, row)
}
return rs
}

248
rows_test.go Normal file
View File

@ -0,0 +1,248 @@
package sqlmock
import (
"database/sql"
"fmt"
"testing"
)
func ExampleRows() {
db, mock, err := New()
if err != nil {
fmt.Println("failed to open sqlmock database:", err)
}
defer db.Close()
rows := NewRows([]string{"id", "title"}).
AddRow(1, "one").
AddRow(2, "two")
mock.ExpectQuery("SELECT").WillReturnRows(rows)
rs, _ := db.Query("SELECT")
defer rs.Close()
for rs.Next() {
var id int
var title string
rs.Scan(&id, &title)
fmt.Println("scanned id:", id, "and title:", title)
}
if rs.Err() != nil {
fmt.Println("got rows error:", rs.Err())
}
// Output: scanned id: 1 and title: one
// scanned id: 2 and title: two
}
func ExampleRows_rowError() {
db, mock, err := New()
if err != nil {
fmt.Println("failed to open sqlmock database:", err)
}
defer db.Close()
rows := NewRows([]string{"id", "title"}).
AddRow(0, "one").
AddRow(1, "two").
RowError(1, fmt.Errorf("row error"))
mock.ExpectQuery("SELECT").WillReturnRows(rows)
rs, _ := db.Query("SELECT")
defer rs.Close()
for rs.Next() {
var id int
var title string
rs.Scan(&id, &title)
fmt.Println("scanned id:", id, "and title:", title)
}
if rs.Err() != nil {
fmt.Println("got rows error:", rs.Err())
}
// Output: scanned id: 0 and title: one
// got rows error: row error
}
func ExampleRows_closeError() {
db, mock, err := New()
if err != nil {
fmt.Println("failed to open sqlmock database:", err)
}
defer db.Close()
rows := NewRows([]string{"id", "title"}).CloseError(fmt.Errorf("close error"))
mock.ExpectQuery("SELECT").WillReturnRows(rows)
rs, _ := db.Query("SELECT")
// Note: that close will return error only before rows EOF
// that is a default sql package behavior. If you run rs.Next()
// it will handle the error internally and return nil bellow
if err := rs.Close(); err != nil {
fmt.Println("got error:", err)
}
// Output: got error: close error
}
func TestAllowsToSetRowsErrors(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
rows := NewRows([]string{"id", "title"}).
AddRow(0, "one").
AddRow(1, "two").
RowError(1, fmt.Errorf("error"))
mock.ExpectQuery("SELECT").WillReturnRows(rows)
rs, err := db.Query("SELECT")
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
defer rs.Close()
if !rs.Next() {
t.Fatal("expected the first row to be available")
}
if rs.Err() != nil {
t.Fatalf("unexpected error: %s", rs.Err())
}
if rs.Next() {
t.Fatal("was not expecting the second row, since there should be an error")
}
if rs.Err() == nil {
t.Fatal("expected an error, but got none")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Fatal(err)
}
}
func TestRowsCloseError(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
rows := NewRows([]string{"id"}).CloseError(fmt.Errorf("close error"))
mock.ExpectQuery("SELECT").WillReturnRows(rows)
rs, err := db.Query("SELECT")
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if err := rs.Close(); err == nil {
t.Fatal("expected a close error")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Fatal(err)
}
}
func TestQuerySingleRow(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
rows := NewRows([]string{"id"}).
AddRow(1).
AddRow(2)
mock.ExpectQuery("SELECT").WillReturnRows(rows)
var id int
if err := db.QueryRow("SELECT").Scan(&id); err != nil {
t.Fatalf("unexpected error: %s", err)
}
mock.ExpectQuery("SELECT").WillReturnRows(NewRows([]string{"id"}))
if err := db.QueryRow("SELECT").Scan(&id); err != sql.ErrNoRows {
t.Fatal("expected sql no rows error")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Fatal(err)
}
}
func TestRowsScanError(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
r := NewRows([]string{"col1", "col2"}).AddRow("one", "two").AddRow("one", nil)
mock.ExpectQuery("SELECT").WillReturnRows(r)
rs, err := db.Query("SELECT")
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
defer rs.Close()
var one, two string
if !rs.Next() || rs.Err() != nil || rs.Scan(&one, &two) != nil {
t.Fatal("unexpected error on first row scan")
}
if !rs.Next() || rs.Err() != nil {
t.Fatal("unexpected error on second row read")
}
err = rs.Scan(&one, &two)
if err == nil {
t.Fatal("expected an error for scan, but got none")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Fatal(err)
}
}
func TestCSVRowParser(t *testing.T) {
t.Parallel()
rs := NewRows([]string{"col1", "col2"}).FromCSVString("a,NULL")
db, mock, err := New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
mock.ExpectQuery("SELECT").WillReturnRows(rs)
rw, err := db.Query("SELECT")
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
defer rw.Close()
var col1 string
var col2 []byte
rw.Next()
if err = rw.Scan(&col1, &col2); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if col1 != "a" {
t.Fatalf("expected col1 to be 'a', but got [%T]:%+v", col1, col1)
}
if col2 != nil {
t.Fatalf("expected col2 to be nil, but got [%T]:%+v", col2, col2)
}
}

View File

@ -1,195 +1,444 @@
/* /*
Package sqlmock provides sql driver mock connecection, which allows to test database, Package sqlmock provides sql driver connection, which allows to test database
create expectations and ensure the correct execution flow of any database operations. interactions by expected calls and simulate their results or errors.
It hooks into Go standard library's database/sql package.
The package provides convenient methods to mock database queries, transactions and It does not require any modifications to your source code in order to test
expect the right execution flow, compare query arguments or even return error instead and mock database operations. It does not even require a real database in order
to simulate failures. See the example bellow, which illustrates how convenient it is to test your application.
to work with:
package main
import (
"database/sql"
"github.com/DATA-DOG/go-sqlmock"
"testing"
"fmt"
)
// will test that order with a different status, cannot be cancelled
func TestShouldNotCancelOrderWithNonPendingStatus(t *testing.T) {
// open database stub
db, err := sql.Open("mock", "")
if err != nil {
t.Errorf("An error '%s' was not expected when opening a stub database connection", err)
}
// columns to be used for result
columns := []string{"id", "status"}
// expect transaction begin
sqlmock.ExpectBegin()
// expect query to fetch order, match it with regexp
sqlmock.ExpectQuery("SELECT (.+) FROM orders (.+) FOR UPDATE").
WithArgs(1).
WillReturnRows(sqlmock.NewRows(columns).FromCSVString("1,1"))
// expect transaction rollback, since order status is "cancelled"
sqlmock.ExpectRollback()
// run the cancel order function
someOrderId := 1
// call a function which executes expected database operations
err = cancelOrder(someOrderId, db)
if err != nil {
t.Errorf("Expected no error, but got %s instead", err)
}
// db.Close() ensures that all expectations have been met
if err = db.Close(); err != nil {
t.Errorf("Error '%s' was not expected while closing the database", err)
}
}
The driver allows to mock any sql driver method behavior. Concurrent actions
are also supported.
*/ */
package sqlmock package sqlmock
import ( import (
"database/sql"
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"reflect"
"regexp" "regexp"
) )
var mock *mockDriver // Sqlmock type satisfies required sql.driver interfaces
// to simulate actual database and also serves to
// create expectations for any kind of database action
// in order to mock and test real database behavior.
type Sqlmock struct {
// Mock interface defines a mock which is returned // MatchExpectationsInOrder gives an option whether to match all
// by any expectation and can be detailed further // expectations in the order they were set or not.
// with the methods this interface provides //
type Mock interface { // By default it is set to - true. But if you use goroutines
WithArgs(...driver.Value) Mock // to parallelize your query executation, that option may
WillReturnError(error) Mock // be handy.
WillReturnRows(driver.Rows) Mock MatchExpectationsInOrder bool
WillReturnResult(driver.Result) Mock
dsn string
opened int
drv *mockDriver
expected []expectation
} }
type mockDriver struct { // ExpectClose queues an expectation for this database
conn *conn // action to be triggered. the *ExpectedClose allows
// to mock database response
func (c *Sqlmock) ExpectClose() *ExpectedClose {
e := &ExpectedClose{}
c.expected = append(c.expected, e)
return e
} }
func (d *mockDriver) Open(dsn string) (driver.Conn, error) { // Close a mock database driver connection. It may or may not
return mock.conn, nil // be called depending on the sircumstances, but if it is called
} // there must be an *ExpectedClose expectation satisfied.
// meets http://golang.org/pkg/database/sql/driver/#Conn interface
func (c *Sqlmock) Close() error {
c.drv.Lock()
defer c.drv.Unlock()
func init() { c.opened--
mock = &mockDriver{&conn{}} if c.opened == 0 {
sql.Register("mock", mock) delete(c.drv.conns, c.dsn)
}
// New creates sqlmock database connection
// and pings it so that all expectations could be
// asserted on Close.
func New() (db *sql.DB, err error) {
db, err = sql.Open("mock", "")
if err != nil {
return
} }
// ensure open connection, otherwise Close does not assert expectations
return db, db.Ping()
}
// ExpectBegin expects transaction to be started var expected *ExpectedClose
func ExpectBegin() Mock { var fulfilled int
e := &expectedBegin{} var ok bool
mock.conn.expectations = append(mock.conn.expectations, e) for _, next := range c.expected {
mock.conn.active = e next.Lock()
return mock.conn if next.fulfilled() {
} next.Unlock()
fulfilled++
// ExpectCommit expects transaction to be commited continue
func ExpectCommit() Mock { }
e := &expectedCommit{}
mock.conn.expectations = append(mock.conn.expectations, e) if expected, ok = next.(*ExpectedClose); ok {
mock.conn.active = e break
return mock.conn }
}
next.Unlock()
// ExpectRollback expects transaction to be rolled back if c.MatchExpectationsInOrder {
func ExpectRollback() Mock { return fmt.Errorf("call to database Close, was not expected, next expectation is: %s", next)
e := &expectedRollback{}
mock.conn.expectations = append(mock.conn.expectations, e)
mock.conn.active = e
return mock.conn
}
// ExpectPrepare expects Query to be prepared
func ExpectPrepare() Mock {
e := &expectedPrepare{}
mock.conn.expectations = append(mock.conn.expectations, e)
mock.conn.active = e
return mock.conn
}
// WillReturnError the expectation will return an error
func (c *conn) WillReturnError(err error) Mock {
c.active.setError(err)
return c
}
// ExpectExec expects database Exec to be triggered, which will match
// the given query string as a regular expression
func ExpectExec(sqlRegexStr string) Mock {
e := &expectedExec{}
e.sqlRegex = regexp.MustCompile(sqlRegexStr)
mock.conn.expectations = append(mock.conn.expectations, e)
mock.conn.active = e
return mock.conn
}
// ExpectQuery database Query to be triggered, which will match
// the given query string as a regular expression
func ExpectQuery(sqlRegexStr string) Mock {
e := &expectedQuery{}
e.sqlRegex = regexp.MustCompile(sqlRegexStr)
mock.conn.expectations = append(mock.conn.expectations, e)
mock.conn.active = e
return mock.conn
}
// WithArgs expectation should be called with given arguments.
// Works with Exec and Query expectations
func (c *conn) WithArgs(args ...driver.Value) Mock {
eq, ok := c.active.(*expectedQuery)
if !ok {
ee, ok := c.active.(*expectedExec)
if !ok {
panic(fmt.Sprintf("arguments may be expected only with query based expectations, current is %T", c.active))
} }
ee.args = args
} else {
eq.args = args
} }
return c
if expected == nil {
msg := "call to database Close was not expected"
if fulfilled == len(c.expected) {
msg = "all expectations were already fulfilled, " + msg
}
return fmt.Errorf(msg)
}
expected.triggered = true
expected.Unlock()
return expected.err
} }
// WillReturnResult expectation will return a Result. // ExpectationsWereMet checks whether all queued expectations
// Works only with Exec expectations // were met in order. If any of them was not met - an error is returned.
func (c *conn) WillReturnResult(result driver.Result) Mock { func (c *Sqlmock) ExpectationsWereMet() error {
eq, ok := c.active.(*expectedExec) for _, e := range c.expected {
if !ok { if !e.fulfilled() {
panic(fmt.Sprintf("driver.result may be returned only by exec expectations, current is %T", c.active)) return fmt.Errorf("there is a remaining expectation which was not matched: %s", e)
}
} }
eq.result = result return nil
return c
} }
// WillReturnRows expectation will return Rows. // Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface
// Works only with Query expectations func (c *Sqlmock) Begin() (driver.Tx, error) {
func (c *conn) WillReturnRows(rows driver.Rows) Mock { var expected *ExpectedBegin
eq, ok := c.active.(*expectedQuery) var ok bool
if !ok { var fulfilled int
panic(fmt.Sprintf("driver.rows may be returned only by query expectations, current is %T", c.active)) for _, next := range c.expected {
next.Lock()
if next.fulfilled() {
next.Unlock()
fulfilled++
continue
}
if expected, ok = next.(*ExpectedBegin); ok {
break
}
next.Unlock()
if c.MatchExpectationsInOrder {
return nil, fmt.Errorf("call to database transaction Begin, was not expected, next expectation is: %s", next)
}
} }
eq.rows = rows if expected == nil {
return c msg := "call to database transaction Begin was not expected"
if fulfilled == len(c.expected) {
msg = "all expectations were already fulfilled, " + msg
}
return nil, fmt.Errorf(msg)
}
expected.triggered = true
expected.Unlock()
return c, expected.err
}
// ExpectBegin expects *sql.DB.Begin to be called.
// the *ExpectedBegin allows to mock database response
func (c *Sqlmock) ExpectBegin() *ExpectedBegin {
e := &ExpectedBegin{}
c.expected = append(c.expected, e)
return e
}
// Exec meets http://golang.org/pkg/database/sql/driver/#Execer
func (c *Sqlmock) Exec(query string, args []driver.Value) (res driver.Result, err error) {
query = stripQuery(query)
var expected *ExpectedExec
var fulfilled int
var ok bool
for _, next := range c.expected {
next.Lock()
if next.fulfilled() {
next.Unlock()
fulfilled++
continue
}
if c.MatchExpectationsInOrder {
if expected, ok = next.(*ExpectedExec); ok {
break
}
next.Unlock()
return nil, fmt.Errorf("call to exec query '%s' with args %+v, was not expected, next expectation is: %s", query, args, next)
}
if exec, ok := next.(*ExpectedExec); ok {
if exec.attemptMatch(query, args) {
expected = exec
break
}
}
next.Unlock()
}
if expected == nil {
msg := "call to exec '%s' query with args %+v was not expected"
if fulfilled == len(c.expected) {
msg = "all expectations were already fulfilled, " + msg
}
return nil, fmt.Errorf(msg, query, args)
}
defer expected.Unlock()
expected.triggered = true
// converts panic to error in case of reflect value type mismatch
defer func(errp *error, exp *ExpectedExec, q string, a []driver.Value) {
if e := recover(); e != nil {
if se, ok := e.(*reflect.ValueError); ok { // catch reflect error, failed type conversion
msg := "exec query \"%s\", args \"%+v\" failed to match with error \"%s\" expectation: %s"
*errp = fmt.Errorf(msg, q, a, se, exp)
} else {
panic(e) // overwise if unknown error panic
}
}
}(&err, expected, query, args)
if !expected.queryMatches(query) {
return nil, fmt.Errorf("exec query '%s', does not match regex '%s'", query, expected.sqlRegex.String())
}
if !expected.argsMatches(args) {
return nil, fmt.Errorf("exec query '%s', args %+v does not match expected %+v", query, args, expected.args)
}
if expected.err != nil {
return nil, expected.err // mocked to return error
}
if expected.result == nil {
return nil, fmt.Errorf("exec query '%s' with args %+v, must return a database/sql/driver.result, but it was not set for expectation %T as %+v", query, args, expected, expected)
}
return expected.result, err
}
// ExpectExec expects Exec() to be called with sql query
// which match sqlRegexStr given regexp.
// the *ExpectedExec allows to mock database response
func (c *Sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec {
e := &ExpectedExec{}
e.sqlRegex = regexp.MustCompile(sqlRegexStr)
c.expected = append(c.expected, e)
return e
}
// Prepare meets http://golang.org/pkg/database/sql/driver/#Conn interface
func (c *Sqlmock) Prepare(query string) (driver.Stmt, error) {
var expected *ExpectedPrepare
var fulfilled int
var ok bool
for _, next := range c.expected {
next.Lock()
if next.fulfilled() {
next.Unlock()
fulfilled++
continue
}
if expected, ok = next.(*ExpectedPrepare); ok {
break
}
next.Unlock()
if c.MatchExpectationsInOrder {
return nil, fmt.Errorf("call to Prepare stetement with query '%s', was not expected, next expectation is: %s", query, next)
}
}
query = stripQuery(query)
if expected == nil {
msg := "call to Prepare '%s' query was not expected"
if fulfilled == len(c.expected) {
msg = "all expectations were already fulfilled, " + msg
}
return nil, fmt.Errorf(msg, query)
}
expected.triggered = true
expected.Unlock()
return &statement{c, query, expected.closeErr}, expected.err
}
// ExpectPrepare expects Prepare() to be called with sql query
// which match sqlRegexStr given regexp.
// the *ExpectedPrepare allows to mock database response.
// Note that you may expect Query() or Exec() on the *ExpectedPrepare
// statement to prevent repeating sqlRegexStr
func (c *Sqlmock) ExpectPrepare(sqlRegexStr string) *ExpectedPrepare {
e := &ExpectedPrepare{sqlRegex: regexp.MustCompile(sqlRegexStr), mock: c}
c.expected = append(c.expected, e)
return e
}
// Query meets http://golang.org/pkg/database/sql/driver/#Queryer
func (c *Sqlmock) Query(query string, args []driver.Value) (rw driver.Rows, err error) {
query = stripQuery(query)
var expected *ExpectedQuery
var fulfilled int
var ok bool
for _, next := range c.expected {
next.Lock()
if next.fulfilled() {
next.Unlock()
fulfilled++
continue
}
if c.MatchExpectationsInOrder {
if expected, ok = next.(*ExpectedQuery); ok {
break
}
next.Unlock()
return nil, fmt.Errorf("call to query '%s' with args %+v, was not expected, next expectation is: %s", query, args, next)
}
if qr, ok := next.(*ExpectedQuery); ok {
if qr.attemptMatch(query, args) {
expected = qr
break
}
}
next.Unlock()
}
if expected == nil {
msg := "call to query '%s' with args %+v was not expected"
if fulfilled == len(c.expected) {
msg = "all expectations were already fulfilled, " + msg
}
return nil, fmt.Errorf(msg, query, args)
}
defer expected.Unlock()
expected.triggered = true
// converts panic to error in case of reflect value type mismatch
defer func(errp *error, exp *ExpectedQuery, q string, a []driver.Value) {
if e := recover(); e != nil {
if se, ok := e.(*reflect.ValueError); ok { // catch reflect error, failed type conversion
msg := "query \"%s\", args \"%+v\" failed to match with error \"%s\" expectation: %s"
*errp = fmt.Errorf(msg, q, a, se, exp)
} else {
panic(e) // overwise if unknown error panic
}
}
}(&err, expected, query, args)
if !expected.queryMatches(query) {
return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, expected.sqlRegex.String())
}
if !expected.argsMatches(args) {
return nil, fmt.Errorf("query '%s', args %+v does not match expected %+v", query, args, expected.args)
}
if expected.err != nil {
return nil, expected.err // mocked to return error
}
if expected.rows == nil {
return nil, fmt.Errorf("query '%s' with args %+v, must return a database/sql/driver.rows, but it was not set for expectation %T as %+v", query, args, expected, expected)
}
return expected.rows, err
}
// ExpectQuery expects Query() or QueryRow() to be called with sql query
// which match sqlRegexStr given regexp.
// the *ExpectedQuery allows to mock database response.
func (c *Sqlmock) ExpectQuery(sqlRegexStr string) *ExpectedQuery {
e := &ExpectedQuery{}
e.sqlRegex = regexp.MustCompile(sqlRegexStr)
c.expected = append(c.expected, e)
return e
}
// ExpectCommit expects *sql.Tx.Commit to be called.
// the *ExpectedCommit allows to mock database response
func (c *Sqlmock) ExpectCommit() *ExpectedCommit {
e := &ExpectedCommit{}
c.expected = append(c.expected, e)
return e
}
// ExpectRollback expects *sql.Tx.Rollback to be called.
// the *ExpectedRollback allows to mock database response
func (c *Sqlmock) ExpectRollback() *ExpectedRollback {
e := &ExpectedRollback{}
c.expected = append(c.expected, e)
return e
}
// Commit meets http://golang.org/pkg/database/sql/driver/#Tx
func (c *Sqlmock) Commit() error {
var expected *ExpectedCommit
var fulfilled int
var ok bool
for _, next := range c.expected {
next.Lock()
if next.fulfilled() {
next.Unlock()
fulfilled++
continue
}
if expected, ok = next.(*ExpectedCommit); ok {
break
}
next.Unlock()
if c.MatchExpectationsInOrder {
return fmt.Errorf("call to commit transaction, was not expected, next expectation is: %s", next)
}
}
if expected == nil {
msg := "call to commit transaction was not expected"
if fulfilled == len(c.expected) {
msg = "all expectations were already fulfilled, " + msg
}
return fmt.Errorf(msg)
}
expected.triggered = true
expected.Unlock()
return expected.err
}
// Rollback meets http://golang.org/pkg/database/sql/driver/#Tx
func (c *Sqlmock) Rollback() error {
var expected *ExpectedRollback
var fulfilled int
var ok bool
for _, next := range c.expected {
next.Lock()
if next.fulfilled() {
next.Unlock()
fulfilled++
continue
}
if expected, ok = next.(*ExpectedRollback); ok {
break
}
next.Unlock()
if c.MatchExpectationsInOrder {
return fmt.Errorf("call to rollback transaction, was not expected, next expectation is: %s", next)
}
}
if expected == nil {
msg := "call to rollback transaction was not expected"
if fulfilled == len(c.expected) {
msg = "all expectations were already fulfilled, " + msg
}
return fmt.Errorf(msg)
}
expected.triggered = true
expected.Unlock()
return expected.err
} }

View File

@ -3,16 +3,60 @@ package sqlmock
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"sync"
"testing" "testing"
"time" "time"
) )
func cancelOrder(db *sql.DB, orderID int) error {
tx, _ := db.Begin()
_, _ = tx.Query("SELECT * FROM orders {0} FOR UPDATE", orderID)
_ = tx.Rollback()
return nil
}
func Example() {
// Open new mock database
db, mock, err := New()
if err != nil {
fmt.Println("error creating mock database")
return
}
// columns to be used for result
columns := []string{"id", "status"}
// expect transaction begin
mock.ExpectBegin()
// expect query to fetch order, match it with regexp
mock.ExpectQuery("SELECT (.+) FROM orders (.+) FOR UPDATE").
WithArgs(1).
WillReturnRows(NewRows(columns).AddRow(1, 1))
// expect transaction rollback, since order status is "cancelled"
mock.ExpectRollback()
// run the cancel order function
someOrderID := 1
// call a function which executes expected database operations
err = cancelOrder(db, someOrderID)
if err != nil {
fmt.Printf("unexpected error: %s", err)
return
}
// ensure all expectations have been met
if err = mock.ExpectationsWereMet(); err != nil {
fmt.Printf("unmet expectation error: %s", err)
}
// Output:
}
func TestIssue14EscapeSQL(t *testing.T) { func TestIssue14EscapeSQL(t *testing.T) {
db, err := New() t.Parallel()
db, mock, err := New()
if err != nil { if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err) t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
} }
ExpectExec("INSERT INTO mytable\\(a, b\\)"). defer db.Close()
mock.ExpectExec("INSERT INTO mytable\\(a, b\\)").
WithArgs("A", "B"). WithArgs("A", "B").
WillReturnResult(NewResult(1, 1)) WillReturnResult(NewResult(1, 1))
@ -21,37 +65,40 @@ func TestIssue14EscapeSQL(t *testing.T) {
t.Errorf("error '%s' was not expected, while inserting a row", err) t.Errorf("error '%s' was not expected, while inserting a row", err)
} }
err = db.Close() if err := mock.ExpectationsWereMet(); err != nil {
if err != nil { t.Errorf("there were unfulfilled expections: %s", err)
t.Errorf("error '%s' was not expected while closing the database", err)
} }
} }
// test the case when db is not triggered and expectations // test the case when db is not triggered and expectations
// are not asserted on close // are not asserted on close
func TestIssue4(t *testing.T) { func TestIssue4(t *testing.T) {
db, err := New() t.Parallel()
db, mock, err := New()
if err != nil { if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err) t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
} }
ExpectQuery("some sql query which will not be called"). defer db.Close()
mock.ExpectQuery("some sql query which will not be called").
WillReturnRows(NewRows([]string{"id"})) WillReturnRows(NewRows([]string{"id"}))
err = db.Close() if err := mock.ExpectationsWereMet(); err == nil {
if err == nil { t.Errorf("was expecting an error since query was not triggered")
t.Errorf("Was expecting an error, since expected query was not matched")
} }
} }
func TestMockQuery(t *testing.T) { func TestMockQuery(t *testing.T) {
db, err := sql.Open("mock", "") t.Parallel()
db, mock, err := New()
if err != nil { if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err) t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
} }
defer db.Close()
rs := NewRows([]string{"id", "title"}).FromCSVString("5,hello world") rs := NewRows([]string{"id", "title"}).FromCSVString("5,hello world")
ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?").
WithArgs(5). WithArgs(5).
WillReturnRows(rs) WillReturnRows(rs)
@ -59,11 +106,13 @@ func TestMockQuery(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("error '%s' was not expected while retrieving mock rows", err) t.Errorf("error '%s' was not expected while retrieving mock rows", err)
} }
defer func() { defer func() {
if er := rows.Close(); er != nil { if er := rows.Close(); er != nil {
t.Error("Unexpected error while trying to close rows") t.Error("Unexpected error while trying to close rows")
} }
}() }()
if !rows.Next() { if !rows.Next() {
t.Error("it must have had one row as result, but got empty result set instead") t.Error("it must have had one row as result, but got empty result set instead")
} }
@ -84,16 +133,18 @@ func TestMockQuery(t *testing.T) {
t.Errorf("expected mocked title to be 'hello world', but got '%s' instead", title) t.Errorf("expected mocked title to be 'hello world', but got '%s' instead", title)
} }
if err = db.Close(); err != nil { if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("error '%s' was not expected while closing the database", err) t.Errorf("there were unfulfilled expections: %s", err)
} }
} }
func TestMockQueryTypes(t *testing.T) { func TestMockQueryTypes(t *testing.T) {
db, err := sql.Open("mock", "") t.Parallel()
db, mock, err := New()
if err != nil { if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err) t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
} }
defer db.Close()
columns := []string{"id", "timestamp", "sold"} columns := []string{"id", "timestamp", "sold"}
@ -101,7 +152,7 @@ func TestMockQueryTypes(t *testing.T) {
rs := NewRows(columns) rs := NewRows(columns)
rs.AddRow(5, timestamp, true) rs.AddRow(5, timestamp, true)
ExpectQuery("SELECT (.+) FROM sales WHERE id = ?"). mock.ExpectQuery("SELECT (.+) FROM sales WHERE id = ?").
WithArgs(5). WithArgs(5).
WillReturnRows(rs) WillReturnRows(rs)
@ -139,20 +190,22 @@ func TestMockQueryTypes(t *testing.T) {
t.Errorf("expected mocked boolean to be true, but got %v instead", sold) t.Errorf("expected mocked boolean to be true, but got %v instead", sold)
} }
if err = db.Close(); err != nil { if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("error '%s' was not expected while closing the database", err) t.Errorf("there were unfulfilled expections: %s", err)
} }
} }
func TestTransactionExpectations(t *testing.T) { func TestTransactionExpectations(t *testing.T) {
db, err := sql.Open("mock", "") t.Parallel()
db, mock, err := New()
if err != nil { if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err) t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
} }
defer db.Close()
// begin and commit // begin and commit
ExpectBegin() mock.ExpectBegin()
ExpectCommit() mock.ExpectCommit()
tx, err := db.Begin() tx, err := db.Begin()
if err != nil { if err != nil {
@ -165,8 +218,8 @@ func TestTransactionExpectations(t *testing.T) {
} }
// begin and rollback // begin and rollback
ExpectBegin() mock.ExpectBegin()
ExpectRollback() mock.ExpectRollback()
tx, err = db.Begin() tx, err = db.Begin()
if err != nil { if err != nil {
@ -179,25 +232,28 @@ func TestTransactionExpectations(t *testing.T) {
} }
// begin with an error // begin with an error
ExpectBegin().WillReturnError(fmt.Errorf("some err")) mock.ExpectBegin().WillReturnError(fmt.Errorf("some err"))
tx, err = db.Begin() tx, err = db.Begin()
if err == nil { if err == nil {
t.Error("an error was expected when beginning a transaction, but got none") t.Error("an error was expected when beginning a transaction, but got none")
} }
if err = db.Close(); err != nil { if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("error '%s' was not expected while closing the database", err) t.Errorf("there were unfulfilled expections: %s", err)
} }
} }
func TestPrepareExpectations(t *testing.T) { func TestPrepareExpectations(t *testing.T) {
db, err := sql.Open("mock", "") t.Parallel()
db, mock, err := New()
if err != nil { if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err) t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
} }
defer db.Close()
mock.ExpectPrepare("SELECT (.+) FROM articles WHERE id = ?")
// no expectations, w/o ExpectPrepare()
stmt, err := db.Prepare("SELECT (.+) FROM articles WHERE id = ?") stmt, err := db.Prepare("SELECT (.+) FROM articles WHERE id = ?")
if err != nil { if err != nil {
t.Errorf("error '%s' was not expected while creating a prepared statement", err) t.Errorf("error '%s' was not expected while creating a prepared statement", err)
@ -211,36 +267,19 @@ func TestPrepareExpectations(t *testing.T) {
var title string var title string
rs := NewRows([]string{"id", "title"}).FromCSVString("5,hello world") rs := NewRows([]string{"id", "title"}).FromCSVString("5,hello world")
ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?").
WithArgs(5). WithArgs(5).
WillReturnRows(rs) WillReturnRows(rs)
stmt, err = db.Prepare("SELECT (.+) FROM articles WHERE id = ?")
if err != nil {
t.Errorf("error '%s' was not expected while creating a prepared statement", err)
}
if stmt == nil {
t.Errorf("stmt was expected while creating a prepared statement")
}
err = stmt.QueryRow(5).Scan(&id, &title) err = stmt.QueryRow(5).Scan(&id, &title)
if err != nil { if err != nil {
t.Errorf("error '%s' was not expected while retrieving mock rows", err) t.Errorf("error '%s' was not expected while retrieving mock rows", err)
} }
// expect normal result mock.ExpectPrepare("SELECT (.+) FROM articles WHERE id = ?").
ExpectPrepare() WillReturnError(fmt.Errorf("Some DB error occurred"))
stmt, err = db.Prepare("SELECT (.+) FROM articles WHERE id = ?")
if err != nil {
t.Errorf("error '%s' was not expected while creating a prepared statement", err)
}
if stmt == nil {
t.Errorf("stmt was expected while creating a prepared statement")
}
// expect error result stmt, err = db.Prepare("SELECT id FROM articles WHERE id = ?")
ExpectPrepare().WillReturnError(fmt.Errorf("Some DB error occurred"))
stmt, err = db.Prepare("SELECT (.+) FROM articles WHERE id = ?")
if err == nil { if err == nil {
t.Error("error was expected while creating a prepared statement") t.Error("error was expected while creating a prepared statement")
} }
@ -248,35 +287,38 @@ func TestPrepareExpectations(t *testing.T) {
t.Errorf("stmt was not expected while creating a prepared statement returning error") t.Errorf("stmt was not expected while creating a prepared statement returning error")
} }
if err = db.Close(); err != nil { if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("error '%s' was not expected while closing the database", err) t.Errorf("there were unfulfilled expections: %s", err)
} }
} }
func TestPreparedQueryExecutions(t *testing.T) { func TestPreparedQueryExecutions(t *testing.T) {
db, err := sql.Open("mock", "") t.Parallel()
db, mock, err := New()
if err != nil { if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err) t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
} }
defer db.Close()
mock.ExpectPrepare("SELECT (.+) FROM articles WHERE id = ?")
rs1 := NewRows([]string{"id", "title"}).FromCSVString("5,hello world") rs1 := NewRows([]string{"id", "title"}).FromCSVString("5,hello world")
ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?").
WithArgs(5). WithArgs(5).
WillReturnRows(rs1) WillReturnRows(rs1)
rs2 := NewRows([]string{"id", "title"}).FromCSVString("2,whoop") rs2 := NewRows([]string{"id", "title"}).FromCSVString("2,whoop")
ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?").
WithArgs(2). WithArgs(2).
WillReturnRows(rs2) WillReturnRows(rs2)
stmt, err := db.Prepare("SELECT (.+) FROM articles WHERE id = ?") stmt, err := db.Prepare("SELECT id, title FROM articles WHERE id = ?")
if err != nil { if err != nil {
t.Errorf("error '%s' was not expected while creating a prepared statement", err) t.Errorf("error '%s' was not expected while creating a prepared statement", err)
} }
var id int var id int
var title string var title string
err = stmt.QueryRow(5).Scan(&id, &title) err = stmt.QueryRow(5).Scan(&id, &title)
if err != nil { if err != nil {
t.Errorf("error '%s' was not expected querying row from statement and scanning", err) t.Errorf("error '%s' was not expected querying row from statement and scanning", err)
@ -303,18 +345,21 @@ func TestPreparedQueryExecutions(t *testing.T) {
t.Errorf("expected mocked title to be 'whoop', but got '%s' instead", title) t.Errorf("expected mocked title to be 'whoop', but got '%s' instead", title)
} }
if err = db.Close(); err != nil { if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("error '%s' was not expected while closing the database", err) t.Errorf("there were unfulfilled expections: %s", err)
} }
} }
func TestUnexpectedOperations(t *testing.T) { func TestUnexpectedOperations(t *testing.T) {
db, err := sql.Open("mock", "") t.Parallel()
db, mock, err := New()
if err != nil { if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err) t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
} }
defer db.Close()
stmt, err := db.Prepare("SELECT (.+) FROM articles WHERE id = ?") mock.ExpectPrepare("SELECT (.+) FROM articles WHERE id = ?")
stmt, err := db.Prepare("SELECT id, title FROM articles WHERE id = ?")
if err != nil { if err != nil {
t.Errorf("error '%s' was not expected while creating a prepared statement", err) t.Errorf("error '%s' was not expected while creating a prepared statement", err)
} }
@ -327,39 +372,35 @@ func TestUnexpectedOperations(t *testing.T) {
t.Error("error was expected querying row, since there was no such expectation") t.Error("error was expected querying row, since there was no such expectation")
} }
ExpectRollback() mock.ExpectRollback()
err = db.Close() if err := mock.ExpectationsWereMet(); err == nil {
if err == nil { t.Errorf("was expecting an error since query was not triggered")
t.Error("error was expected while closing the database, expectation was not fulfilled", err)
} }
} }
func TestWrongExpectations(t *testing.T) { func TestWrongExpectations(t *testing.T) {
db, err := sql.Open("mock", "") t.Parallel()
db, mock, err := New()
if err != nil { if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err) t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
} }
defer db.Close()
ExpectBegin() mock.ExpectBegin()
rs1 := NewRows([]string{"id", "title"}).FromCSVString("5,hello world") rs1 := NewRows([]string{"id", "title"}).FromCSVString("5,hello world")
ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?").
WithArgs(5). WithArgs(5).
WillReturnRows(rs1) WillReturnRows(rs1)
ExpectCommit().WillReturnError(fmt.Errorf("deadlock occured")) mock.ExpectCommit().WillReturnError(fmt.Errorf("deadlock occured"))
ExpectRollback() // won't be triggered mock.ExpectRollback() // won't be triggered
stmt, err := db.Prepare("SELECT (.+) FROM articles WHERE id = ? FOR UPDATE")
if err != nil {
t.Errorf("error '%s' was not expected while creating a prepared statement", err)
}
var id int var id int
var title string var title string
err = stmt.QueryRow(5).Scan(&id, &title) err = db.QueryRow("SELECT id, title FROM articles WHERE id = ? FOR UPDATE", 5).Scan(&id, &title)
if err == nil { if err == nil {
t.Error("error was expected while querying row, since there begin transaction expectation is not fulfilled") t.Error("error was expected while querying row, since there begin transaction expectation is not fulfilled")
} }
@ -370,7 +411,7 @@ func TestWrongExpectations(t *testing.T) {
t.Errorf("an error '%s' was not expected when beginning a transaction", err) t.Errorf("an error '%s' was not expected when beginning a transaction", err)
} }
err = stmt.QueryRow(5).Scan(&id, &title) err = db.QueryRow("SELECT id, title FROM articles WHERE id = ? FOR UPDATE", 5).Scan(&id, &title)
if err != nil { if err != nil {
t.Errorf("error '%s' was not expected while querying row, since transaction was started", err) t.Errorf("error '%s' was not expected while querying row, since transaction was started", err)
} }
@ -380,20 +421,21 @@ func TestWrongExpectations(t *testing.T) {
t.Error("a deadlock error was expected when commiting a transaction", err) t.Error("a deadlock error was expected when commiting a transaction", err)
} }
err = db.Close() if err := mock.ExpectationsWereMet(); err == nil {
if err == nil { t.Errorf("was expecting an error since query was not triggered")
t.Error("error was expected while closing the database, expectation was not fulfilled", err)
} }
} }
func TestExecExpectations(t *testing.T) { func TestExecExpectations(t *testing.T) {
db, err := sql.Open("mock", "") t.Parallel()
db, mock, err := New()
if err != nil { if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err) t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
} }
defer db.Close()
result := NewResult(1, 1) result := NewResult(1, 1)
ExpectExec("^INSERT INTO articles"). mock.ExpectExec("^INSERT INTO articles").
WithArgs("hello"). WithArgs("hello").
WillReturnResult(result) WillReturnResult(result)
@ -420,22 +462,24 @@ func TestExecExpectations(t *testing.T) {
t.Errorf("expected affected rows to be 1, but got %d instead", affected) t.Errorf("expected affected rows to be 1, but got %d instead", affected)
} }
if err = db.Close(); err != nil { if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("error '%s' was not expected while closing the database", err) t.Errorf("there were unfulfilled expections: %s", err)
} }
} }
func TestRowBuilderAndNilTypes(t *testing.T) { func TestRowBuilderAndNilTypes(t *testing.T) {
db, err := sql.Open("mock", "") t.Parallel()
db, mock, err := New()
if err != nil { if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err) t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
} }
defer db.Close()
rs := NewRows([]string{"id", "active", "created", "status"}). rs := NewRows([]string{"id", "active", "created", "status"}).
AddRow(1, true, time.Now(), 5). AddRow(1, true, time.Now(), 5).
AddRow(2, false, nil, nil) AddRow(2, false, nil, nil)
ExpectQuery("SELECT (.+) FROM sales").WillReturnRows(rs) mock.ExpectQuery("SELECT (.+) FROM sales").WillReturnRows(rs)
rows, err := db.Query("SELECT * FROM sales") rows, err := db.Query("SELECT * FROM sales")
if err != nil { if err != nil {
@ -510,23 +554,107 @@ func TestRowBuilderAndNilTypes(t *testing.T) {
t.Errorf("expected 'status' to be invalid, but it %+v is not", status) t.Errorf("expected 'status' to be invalid, but it %+v is not", status)
} }
if err = db.Close(); err != nil { if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("error '%s' was not expected while closing the database", err) t.Errorf("there were unfulfilled expections: %s", err)
} }
} }
func TestArgumentReflectValueTypeError(t *testing.T) { func TestArgumentReflectValueTypeError(t *testing.T) {
db, err := sql.Open("mock", "") t.Parallel()
db, mock, err := New()
if err != nil { if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err) t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
} }
defer db.Close()
rs := NewRows([]string{"id"}).AddRow(1) rs := NewRows([]string{"id"}).AddRow(1)
ExpectQuery("SELECT (.+) FROM sales").WithArgs(5.5).WillReturnRows(rs) mock.ExpectQuery("SELECT (.+) FROM sales").WithArgs(5.5).WillReturnRows(rs)
_, err = db.Query("SELECT * FROM sales WHERE x = ?", 5) _, err = db.Query("SELECT * FROM sales WHERE x = ?", 5)
if err == nil { if err == nil {
t.Error("Expected error, but got none") t.Error("Expected error, but got none")
} }
} }
func TestGoroutineExecutionWithUnorderedExpectationMatching(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
// note this line is important for unordered expectation matching
mock.MatchExpectationsInOrder = false
result := NewResult(1, 1)
mock.ExpectExec("^UPDATE one").WithArgs("one").WillReturnResult(result)
mock.ExpectExec("^UPDATE two").WithArgs("one", "two").WillReturnResult(result)
mock.ExpectExec("^UPDATE three").WithArgs("one", "two", "three").WillReturnResult(result)
var wg sync.WaitGroup
queries := map[string][]interface{}{
"one": []interface{}{"one"},
"two": []interface{}{"one", "two"},
"three": []interface{}{"one", "two", "three"},
}
wg.Add(len(queries))
for table, args := range queries {
go func(tbl string, a []interface{}) {
if _, err := db.Exec("UPDATE "+tbl, a...); err != nil {
t.Errorf("error was not expected: %s", err)
}
wg.Done()
}(table, args)
}
wg.Wait()
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}
func ExampleSqlmock_goroutines() {
db, mock, err := New()
if err != nil {
fmt.Println("failed to open sqlmock database:", err)
}
defer db.Close()
// note this line is important for unordered expectation matching
mock.MatchExpectationsInOrder = false
result := NewResult(1, 1)
mock.ExpectExec("^UPDATE one").WithArgs("one").WillReturnResult(result)
mock.ExpectExec("^UPDATE two").WithArgs("one", "two").WillReturnResult(result)
mock.ExpectExec("^UPDATE three").WithArgs("one", "two", "three").WillReturnResult(result)
var wg sync.WaitGroup
queries := map[string][]interface{}{
"one": []interface{}{"one"},
"two": []interface{}{"one", "two"},
"three": []interface{}{"one", "two", "three"},
}
wg.Add(len(queries))
for table, args := range queries {
go func(tbl string, a []interface{}) {
if _, err := db.Exec("UPDATE "+tbl, a...); err != nil {
fmt.Println("error was not expected:", err)
}
wg.Done()
}(table, args)
}
wg.Wait()
if err := mock.ExpectationsWereMet(); err != nil {
fmt.Println("there were unfulfilled expections:", err)
}
// Output:
}

View File

@ -5,12 +5,13 @@ import (
) )
type statement struct { type statement struct {
conn *conn conn *Sqlmock
query string query string
err error
} }
func (stmt *statement) Close() error { func (stmt *statement) Close() error {
return nil return stmt.err
} }
func (stmt *statement) NumInput() int { func (stmt *statement) NumInput() int {

View File

@ -1,37 +0,0 @@
package sqlmock
import (
"fmt"
)
type transaction struct {
conn *conn
}
func (tx *transaction) Commit() error {
e := tx.conn.next()
if e == nil {
return fmt.Errorf("all expectations were already fulfilled, call to commit transaction was not expected")
}
etc, ok := e.(*expectedCommit)
if !ok {
return fmt.Errorf("call to commit transaction, was not expected, next expectation was %v", e)
}
etc.triggered = true
return etc.err
}
func (tx *transaction) Rollback() error {
e := tx.conn.next()
if e == nil {
return fmt.Errorf("all expectations were already fulfilled, call to rollback transaction was not expected")
}
etr, ok := e.(*expectedRollback)
if !ok {
return fmt.Errorf("call to rollback transaction, was not expected, next expectation was %v", e)
}
etr.triggered = true
return etr.err
}

View File

@ -5,11 +5,7 @@ import (
"strings" "strings"
) )
var re *regexp.Regexp var re = regexp.MustCompile("\\s+")
func init() {
re = regexp.MustCompile("\\s+")
}
// strip out new lines and trim spaces // strip out new lines and trim spaces
func stripQuery(q string) (s string) { func stripQuery(q string) (s string) {