mirror of
https://github.com/DATA-DOG/go-sqlmock.git
synced 2024-12-10 10:10:05 +02:00
Merge pull request #21 from DATA-DOG/refactor
concurrency support, closes #20 and closes #9 and closes #15
This commit is contained in:
commit
711064c51d
4
.gitignore
vendored
4
.gitignore
vendored
@ -1 +1,3 @@
|
||||
/*.test
|
||||
/examples/blog/blog
|
||||
/examples/orders/orders
|
||||
/examples/basic/basic
|
||||
|
@ -1,4 +1,5 @@
|
||||
language: go
|
||||
sudo: false
|
||||
go:
|
||||
- 1.2
|
||||
- 1.3
|
||||
@ -7,10 +8,5 @@ go:
|
||||
- tip
|
||||
|
||||
script:
|
||||
- go get github.com/kisielk/errcheck
|
||||
- go get ./...
|
||||
|
||||
- go test -v ./...
|
||||
- go test -race ./...
|
||||
- errcheck github.com/DATA-DOG/go-sqlmock
|
||||
|
||||
|
2
LICENSE
2
LICENSE
@ -1,6 +1,6 @@
|
||||
The three clause BSD license (http://en.wikipedia.org/wiki/BSD_licenses)
|
||||
|
||||
Copyright (c) 2013, DataDog.lt team
|
||||
Copyright (c) 2013-2015, DataDog.lt team
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
|
399
README.md
399
README.md
@ -5,336 +5,153 @@
|
||||
|
||||
This is a **mock** driver as **database/sql/driver** which is very flexible and pragmatic to
|
||||
manage and mock expected queries. All the expectations should be met and all queries and actions
|
||||
triggered should be mocked in order to pass a test.
|
||||
triggered should be mocked in order to pass a test. The package has no 3rd party dependencies.
|
||||
|
||||
**NOTE:** regarding major issues #20 and #9 the **api** has changed to support concurrency and more than
|
||||
one database connection.
|
||||
|
||||
If you need an old version, checkout **go-sqlmock** at gopkg.in:
|
||||
|
||||
go get gopkg.in/DATA-DOG/go-sqlmock.v0
|
||||
|
||||
Otherwise use the **v1** branch from master which should be stable afterwards, because all the issues which
|
||||
were known will be fixed in this version.
|
||||
|
||||
## Install
|
||||
|
||||
go get github.com/DATA-DOG/go-sqlmock
|
||||
go get gopkg.in/DATA-DOG/go-sqlmock.v1
|
||||
|
||||
## Use it with pleasure
|
||||
Or take an older version:
|
||||
|
||||
An example of some database interaction which you may want to test:
|
||||
go get gopkg.in/DATA-DOG/go-sqlmock.v0
|
||||
|
||||
## Documentation and Examples
|
||||
|
||||
Visit [godoc](http://godoc.org/github.com/DATA-DOG/go-sqlmock) for general examples and public api reference.
|
||||
See **.travis.yml** for supported **go** versions.
|
||||
Different use case, is to functionally test with a real database - [go-txdb](https://github.com/DATA-DOG/go-txdb)
|
||||
all database related actions are isolated within a single transaction so the database can remain in the same state.
|
||||
|
||||
See implementation examples:
|
||||
|
||||
- [blog API server](https://github.com/DATA-DOG/go-sqlmock/tree/master/examples/blog)
|
||||
- [the same orders example](https://github.com/DATA-DOG/go-sqlmock/tree/master/examples/orders)
|
||||
|
||||
### Something you may want to test
|
||||
|
||||
``` go
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/kisielk/sqlstruct"
|
||||
"fmt"
|
||||
"log"
|
||||
)
|
||||
import "database/sql"
|
||||
|
||||
const ORDER_PENDING = 0
|
||||
const ORDER_CANCELLED = 1
|
||||
func recordStats(db *sql.DB, userID, productID int64) (err error) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
type User struct {
|
||||
Id int `sql:"id"`
|
||||
Username string `sql:"username"`
|
||||
Balance float64 `sql:"balance"`
|
||||
}
|
||||
defer func() {
|
||||
switch err {
|
||||
case nil:
|
||||
err = tx.Commit()
|
||||
default:
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
type Order struct {
|
||||
Id int `sql:"id"`
|
||||
Value float64 `sql:"value"`
|
||||
ReservedFee float64 `sql:"reserved_fee"`
|
||||
Status int `sql:"status"`
|
||||
}
|
||||
|
||||
func cancelOrder(id int, db *sql.DB) (err error) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var order Order
|
||||
var user User
|
||||
sql := fmt.Sprintf(`
|
||||
SELECT %s, %s
|
||||
FROM orders AS o
|
||||
INNER JOIN users AS u ON o.buyer_id = u.id
|
||||
WHERE o.id = ?
|
||||
FOR UPDATE`,
|
||||
sqlstruct.ColumnsAliased(order, "o"),
|
||||
sqlstruct.ColumnsAliased(user, "u"))
|
||||
|
||||
// fetch order to cancel
|
||||
rows, err := tx.Query(sql, id)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
// no rows, nothing to do
|
||||
if !rows.Next() {
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
|
||||
// read order
|
||||
err = sqlstruct.ScanAliased(&order, rows, "o")
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
|
||||
// ensure order status
|
||||
if order.Status != ORDER_PENDING {
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
|
||||
// read user
|
||||
err = sqlstruct.ScanAliased(&user, rows, "u")
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
rows.Close() // manually close before other prepared statements
|
||||
|
||||
// refund order value
|
||||
sql = "UPDATE users SET balance = balance + ? WHERE id = ?"
|
||||
refundStmt, err := tx.Prepare(sql)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
defer refundStmt.Close()
|
||||
_, err = refundStmt.Exec(order.Value + order.ReservedFee, user.Id)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
|
||||
// update order status
|
||||
order.Status = ORDER_CANCELLED
|
||||
sql = "UPDATE orders SET status = ?, updated = NOW() WHERE id = ?"
|
||||
orderUpdStmt, err := tx.Prepare(sql)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
defer orderUpdStmt.Close()
|
||||
_, err = orderUpdStmt.Exec(order.Status, order.Id)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
return tx.Commit()
|
||||
if _, err = tx.Exec("UPDATE products SET views = views + 1"); err != nil {
|
||||
return
|
||||
}
|
||||
if _, err = tx.Exec("INSERT INTO product_viewers (user_id, product_id) VALUES (?, ?)", userID, productID); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func main() {
|
||||
db, err := sql.Open("mysql", "root:nimda@/test")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
err = cancelOrder(1, db)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
// @NOTE: the real connection is not required for tests
|
||||
db, err := sql.Open("mysql", "root@/blog")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
if err = recordStats(db, 1 /*some user id*/, 5 /*some product id*/); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
And the clean nice test:
|
||||
### Tests with sqlmock
|
||||
|
||||
``` go
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"testing"
|
||||
"fmt"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
)
|
||||
|
||||
// will test that order with a different status, cannot be cancelled
|
||||
func TestShouldNotCancelOrderWithNonPendingStatus(t *testing.T) {
|
||||
// open database stub
|
||||
db, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Errorf("An error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
// a successful case
|
||||
func TestShouldUpdateStats(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// columns are prefixed with "o" since we used sqlstruct to generate them
|
||||
columns := []string{"o_id", "o_status"}
|
||||
// expect transaction begin
|
||||
sqlmock.ExpectBegin()
|
||||
// expect query to fetch order and user, match it with regexp
|
||||
sqlmock.ExpectQuery("SELECT (.+) FROM orders AS o INNER JOIN users AS u (.+) FOR UPDATE").
|
||||
WithArgs(1).
|
||||
WillReturnRows(sqlmock.NewRows(columns).FromCSVString("1,1"))
|
||||
// expect transaction rollback, since order status is "cancelled"
|
||||
sqlmock.ExpectRollback()
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE products").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectExec("INSERT INTO product_viewers").WithArgs(2, 3).WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
// run the cancel order function
|
||||
err = cancelOrder(1, db)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, but got %s instead", err)
|
||||
}
|
||||
// db.Close() ensures that all expectations have been met
|
||||
if err = db.Close(); err != nil {
|
||||
t.Errorf("Error '%s' was not expected while closing the database", err)
|
||||
}
|
||||
// now we execute our method
|
||||
if err = recordStats(db, 2, 3); err != nil {
|
||||
t.Errorf("error was not expected while updating stats: %s", err)
|
||||
}
|
||||
|
||||
// we make sure that all expectations were met
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("there were unfulfilled expections: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// will test order cancellation
|
||||
func TestShouldRefundUserWhenOrderIsCancelled(t *testing.T) {
|
||||
// open database stub
|
||||
db, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Errorf("An error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
// a failing test case
|
||||
func TestShouldRollbackStatUpdatesOnFailure(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// columns are prefixed with "o" since we used sqlstruct to generate them
|
||||
columns := []string{"o_id", "o_status", "o_value", "o_reserved_fee", "u_id", "u_balance"}
|
||||
// expect transaction begin
|
||||
sqlmock.ExpectBegin()
|
||||
// expect query to fetch order and user, match it with regexp
|
||||
sqlmock.ExpectQuery("SELECT (.+) FROM orders AS o INNER JOIN users AS u (.+) FOR UPDATE").
|
||||
WithArgs(1).
|
||||
WillReturnRows(sqlmock.NewRows(columns).AddRow(1, 0, 25.75, 3.25, 2, 10.00))
|
||||
// expect user balance update
|
||||
sqlmock.ExpectExec("UPDATE users SET balance").
|
||||
WithArgs(25.75 + 3.25, 2). // refund amount, user id
|
||||
WillReturnResult(sqlmock.NewResult(0, 1)) // no insert id, 1 affected row
|
||||
// expect order status update
|
||||
sqlmock.ExpectExec("UPDATE orders SET status").
|
||||
WithArgs(ORDER_CANCELLED, 1). // status, id
|
||||
WillReturnResult(sqlmock.NewResult(0, 1)) // no insert id, 1 affected row
|
||||
// expect a transaction commit
|
||||
sqlmock.ExpectCommit()
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE products").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectExec("INSERT INTO product_viewers").
|
||||
WithArgs(2, 3).
|
||||
WillReturnError(fmt.Errorf("some error"))
|
||||
mock.ExpectRollback()
|
||||
|
||||
// run the cancel order function
|
||||
err = cancelOrder(1, db)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, but got %s instead", err)
|
||||
}
|
||||
// db.Close() ensures that all expectations have been met
|
||||
if err = db.Close(); err != nil {
|
||||
t.Errorf("Error '%s' was not expected while closing the database", err)
|
||||
}
|
||||
// now we execute our method
|
||||
if err = recordStats(db, 2, 3); err == nil {
|
||||
t.Errorf("was expecting an error, but there was none")
|
||||
}
|
||||
|
||||
// we make sure that all expectations were met
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("there were unfulfilled expections: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// will test order cancellation
|
||||
func TestShouldRollbackOnError(t *testing.T) {
|
||||
// open database stub
|
||||
db, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Errorf("An error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
|
||||
// expect transaction begin
|
||||
sqlmock.ExpectBegin()
|
||||
// expect query to fetch order and user, match it with regexp
|
||||
sqlmock.ExpectQuery("SELECT (.+) FROM orders AS o INNER JOIN users AS u (.+) FOR UPDATE").
|
||||
WithArgs(1).
|
||||
WillReturnError(fmt.Errorf("Some error"))
|
||||
// should rollback since error was returned from query execution
|
||||
sqlmock.ExpectRollback()
|
||||
|
||||
// run the cancel order function
|
||||
err = cancelOrder(1, db)
|
||||
// error should return back
|
||||
if err == nil {
|
||||
t.Error("Expected error, but got none")
|
||||
}
|
||||
// db.Close() ensures that all expectations have been met
|
||||
if err = db.Close(); err != nil {
|
||||
t.Errorf("Error '%s' was not expected while closing the database", err)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Expectations
|
||||
|
||||
All **Expect** methods return a **Mock** interface which allow you to describe
|
||||
expectations in more details: return an error, expect specific arguments, return rows and so on.
|
||||
**NOTE:** that if you call **WithArgs** on a non query based expectation, it will panic
|
||||
|
||||
A **Mock** interface:
|
||||
|
||||
``` go
|
||||
type Mock interface {
|
||||
WithArgs(...driver.Value) Mock
|
||||
WillReturnError(error) Mock
|
||||
WillReturnRows(driver.Rows) Mock
|
||||
WillReturnResult(driver.Result) Mock
|
||||
}
|
||||
```
|
||||
|
||||
As an example we can expect a transaction commit and simulate an error for it:
|
||||
|
||||
``` go
|
||||
sqlmock.ExpectCommit().WillReturnError(fmt.Errorf("Deadlock occured"))
|
||||
```
|
||||
|
||||
In same fashion, we can expect queries to match arguments. If there are any, it must be matched.
|
||||
Instead of result we can return error.
|
||||
|
||||
``` go
|
||||
sqlmock.ExpectQuery("SELECT (.*) FROM orders").
|
||||
WithArgs("string value").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"col"}).AddRow("val"))
|
||||
```
|
||||
|
||||
**NOTE:** it matches a regular expression. Some regex special characters must be escaped if you want to match them.
|
||||
For example if we want to match a subselect:
|
||||
|
||||
``` go
|
||||
sqlmock.ExpectQuery("SELECT (.*) FROM orders WHERE id IN \\(SELECT id FROM finished WHERE status = 1\\)").
|
||||
WithArgs("string value").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"col"}).AddRow("val"))
|
||||
```
|
||||
|
||||
**WithArgs** expectation, compares values based on their type, for usual values like **string, float, int**
|
||||
it matches the actual value. Types like **time** are compared only by type. Other types might require different ways
|
||||
to compare them correctly, this may be improved.
|
||||
|
||||
You can build rows either from CSV string or from interface values:
|
||||
|
||||
**Rows** interface, which satisfies sql driver.Rows:
|
||||
|
||||
``` go
|
||||
type Rows interface {
|
||||
AddRow(...driver.Value) Rows
|
||||
FromCSVString(s string) Rows
|
||||
Next([]driver.Value) error
|
||||
Columns() []string
|
||||
Close() error
|
||||
}
|
||||
```
|
||||
|
||||
Example for to build rows:
|
||||
|
||||
``` go
|
||||
rs := sqlmock.NewRows([]string{"column1", "column2"}).
|
||||
FromCSVString("one,1\ntwo,2").
|
||||
AddRow("three", 3)
|
||||
```
|
||||
|
||||
**Prepare** will ignore other expectations if ExpectPrepare not set. When set, can expect normal result or simulate an error:
|
||||
|
||||
``` go
|
||||
rs := sqlmock.ExpectPrepare().
|
||||
WillReturnError(fmt.Errorf("Query prepare failed"))
|
||||
```
|
||||
|
||||
## Run tests
|
||||
|
||||
go test
|
||||
|
||||
## Documentation
|
||||
|
||||
Visit [godoc](http://godoc.org/github.com/DATA-DOG/go-sqlmock)
|
||||
See **.travis.yml** for supported **go** versions
|
||||
Different use case, is to functionally test with a real database - [go-txdb](https://github.com/DATA-DOG/go-txdb)
|
||||
all database related actions are isolated within a single transaction so the database can remain in the same state.
|
||||
go test -race
|
||||
|
||||
## Changes
|
||||
|
||||
- **2015-08-27** - **v1** api change, concurrency support, all known issues fixed.
|
||||
- **2014-08-16** instead of **panic** during reflect type mismatch when comparing query arguments - now return error
|
||||
- **2014-08-14** added **sqlmock.NewErrorResult** which gives an option to return driver.Result with errors for
|
||||
interface methods, see [issue](https://github.com/DATA-DOG/go-sqlmock/issues/5)
|
||||
|
151
connection.go
151
connection.go
@ -1,151 +0,0 @@
|
||||
package sqlmock
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type conn struct {
|
||||
expectations []expectation
|
||||
active expectation
|
||||
}
|
||||
|
||||
// Close a mock database driver connection. It should
|
||||
// be always called to ensure that all expectations
|
||||
// were met successfully. Returns error if there is any
|
||||
func (c *conn) Close() (err error) {
|
||||
for _, e := range mock.conn.expectations {
|
||||
if !e.fulfilled() {
|
||||
err = fmt.Errorf("there is a remaining expectation %T which was not matched yet", e)
|
||||
break
|
||||
}
|
||||
}
|
||||
mock.conn.expectations = []expectation{}
|
||||
mock.conn.active = nil
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *conn) Begin() (driver.Tx, error) {
|
||||
e := c.next()
|
||||
if e == nil {
|
||||
return nil, fmt.Errorf("all expectations were already fulfilled, call to begin transaction was not expected")
|
||||
}
|
||||
|
||||
etb, ok := e.(*expectedBegin)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("call to begin transaction, was not expected, next expectation is %T as %+v", e, e)
|
||||
}
|
||||
etb.triggered = true
|
||||
return &transaction{c}, etb.err
|
||||
}
|
||||
|
||||
// get next unfulfilled expectation
|
||||
func (c *conn) next() (e expectation) {
|
||||
for _, e = range c.expectations {
|
||||
if !e.fulfilled() {
|
||||
return
|
||||
}
|
||||
}
|
||||
return nil // all expectations were fulfilled
|
||||
}
|
||||
|
||||
func (c *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) {
|
||||
e := c.next()
|
||||
query = stripQuery(query)
|
||||
if e == nil {
|
||||
return nil, fmt.Errorf("all expectations were already fulfilled, call to exec '%s' query with args %+v was not expected", query, args)
|
||||
}
|
||||
|
||||
eq, ok := e.(*expectedExec)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("call to exec query '%s' with args %+v, was not expected, next expectation is %T as %+v", query, args, e, e)
|
||||
}
|
||||
|
||||
eq.triggered = true
|
||||
|
||||
defer argMatcherErrorHandler(&err) // converts panic to error in case of reflect value type mismatch
|
||||
|
||||
if !eq.queryMatches(query) {
|
||||
return nil, fmt.Errorf("exec query '%s', does not match regex '%s'", query, eq.sqlRegex.String())
|
||||
}
|
||||
|
||||
if !eq.argsMatches(args) {
|
||||
return nil, fmt.Errorf("exec query '%s', args %+v does not match expected %+v", query, args, eq.args)
|
||||
}
|
||||
|
||||
if eq.err != nil {
|
||||
return nil, eq.err // mocked to return error
|
||||
}
|
||||
|
||||
if eq.result == nil {
|
||||
return nil, fmt.Errorf("exec query '%s' with args %+v, must return a database/sql/driver.result, but it was not set for expectation %T as %+v", query, args, eq, eq)
|
||||
}
|
||||
|
||||
return eq.result, err
|
||||
}
|
||||
|
||||
func (c *conn) Prepare(query string) (driver.Stmt, error) {
|
||||
e := c.next()
|
||||
|
||||
// for backwards compatibility, ignore when Prepare not expected
|
||||
if e == nil {
|
||||
return &statement{mock.conn, stripQuery(query)}, nil
|
||||
}
|
||||
eq, ok := e.(*expectedPrepare)
|
||||
if !ok {
|
||||
return &statement{mock.conn, stripQuery(query)}, nil
|
||||
}
|
||||
|
||||
eq.triggered = true
|
||||
if eq.err != nil {
|
||||
return nil, eq.err // mocked to return error
|
||||
}
|
||||
|
||||
return &statement{mock.conn, stripQuery(query)}, nil
|
||||
}
|
||||
|
||||
func (c *conn) Query(query string, args []driver.Value) (rw driver.Rows, err error) {
|
||||
e := c.next()
|
||||
query = stripQuery(query)
|
||||
if e == nil {
|
||||
return nil, fmt.Errorf("all expectations were already fulfilled, call to query '%s' with args %+v was not expected", query, args)
|
||||
}
|
||||
|
||||
eq, ok := e.(*expectedQuery)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("call to query '%s' with args %+v, was not expected, next expectation is %T as %+v", query, args, e, e)
|
||||
}
|
||||
|
||||
eq.triggered = true
|
||||
|
||||
defer argMatcherErrorHandler(&err) // converts panic to error in case of reflect value type mismatch
|
||||
|
||||
if !eq.queryMatches(query) {
|
||||
return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, eq.sqlRegex.String())
|
||||
}
|
||||
|
||||
if !eq.argsMatches(args) {
|
||||
return nil, fmt.Errorf("query '%s', args %+v does not match expected %+v", query, args, eq.args)
|
||||
}
|
||||
|
||||
if eq.err != nil {
|
||||
return nil, eq.err // mocked to return error
|
||||
}
|
||||
|
||||
if eq.rows == nil {
|
||||
return nil, fmt.Errorf("query '%s' with args %+v, must return a database/sql/driver.rows, but it was not set for expectation %T as %+v", query, args, eq, eq)
|
||||
}
|
||||
|
||||
return eq.rows, err
|
||||
}
|
||||
|
||||
func argMatcherErrorHandler(errp *error) {
|
||||
if e := recover(); e != nil {
|
||||
if se, ok := e.(*reflect.ValueError); ok { // catch reflect error, failed type conversion
|
||||
*errp = fmt.Errorf("Failed to compare query arguments: %s", se)
|
||||
} else {
|
||||
panic(e) // overwise panic
|
||||
}
|
||||
}
|
||||
}
|
@ -1,378 +0,0 @@
|
||||
package sqlmock
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"regexp"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExecNoExpectations(t *testing.T) {
|
||||
c := &conn{
|
||||
expectations: []expectation{
|
||||
&expectedExec{
|
||||
queryBasedExpectation: queryBasedExpectation{
|
||||
commonExpectation: commonExpectation{
|
||||
triggered: true,
|
||||
err: errors.New("WillReturnError"),
|
||||
},
|
||||
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("otherquery")),
|
||||
args: []driver.Value{456},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
res, err := c.Exec("query", []driver.Value{123})
|
||||
if res != nil {
|
||||
t.Error("Result should be nil")
|
||||
}
|
||||
if err == nil {
|
||||
t.Error("error should not be nil")
|
||||
}
|
||||
pattern := regexp.MustCompile(regexp.QuoteMeta("all expectations were already fulfilled, call to exec"))
|
||||
if !pattern.MatchString(err.Error()) {
|
||||
t.Errorf("error should match expected error message (actual: %s)", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecExpectationMismatch(t *testing.T) {
|
||||
c := &conn{
|
||||
expectations: []expectation{
|
||||
&expectedQuery{
|
||||
queryBasedExpectation: queryBasedExpectation{
|
||||
commonExpectation: commonExpectation{
|
||||
err: errors.New("WillReturnError"),
|
||||
},
|
||||
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("otherquery")),
|
||||
args: []driver.Value{456},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
res, err := c.Exec("query", []driver.Value{123})
|
||||
if res != nil {
|
||||
t.Error("Result should be nil")
|
||||
}
|
||||
if err == nil {
|
||||
t.Error("error should not be nil")
|
||||
}
|
||||
pattern := regexp.MustCompile(regexp.QuoteMeta("was not expected, next expectation is"))
|
||||
if !pattern.MatchString(err.Error()) {
|
||||
t.Errorf("error should match expected error message (actual: %s)", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecQueryMismatch(t *testing.T) {
|
||||
c := &conn{
|
||||
expectations: []expectation{
|
||||
&expectedExec{
|
||||
queryBasedExpectation: queryBasedExpectation{
|
||||
commonExpectation: commonExpectation{
|
||||
err: errors.New("WillReturnError"),
|
||||
},
|
||||
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("otherquery")),
|
||||
args: []driver.Value{456},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
res, err := c.Exec("query", []driver.Value{123})
|
||||
if res != nil {
|
||||
t.Error("Result should be nil")
|
||||
}
|
||||
if err == nil {
|
||||
t.Error("error should not be nil")
|
||||
}
|
||||
pattern := regexp.MustCompile(regexp.QuoteMeta("does not match regex"))
|
||||
if !pattern.MatchString(err.Error()) {
|
||||
t.Errorf("error should match expected error message (actual: %s)", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecArgsMismatch(t *testing.T) {
|
||||
c := &conn{
|
||||
expectations: []expectation{
|
||||
&expectedExec{
|
||||
queryBasedExpectation: queryBasedExpectation{
|
||||
commonExpectation: commonExpectation{
|
||||
err: errors.New("WillReturnError"),
|
||||
},
|
||||
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("query")),
|
||||
args: []driver.Value{456},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
res, err := c.Exec("query", []driver.Value{123})
|
||||
if res != nil {
|
||||
t.Error("Result should be nil")
|
||||
}
|
||||
if err == nil {
|
||||
t.Error("error should not be nil")
|
||||
}
|
||||
pattern := regexp.MustCompile(regexp.QuoteMeta("does not match expected"))
|
||||
if !pattern.MatchString(err.Error()) {
|
||||
t.Errorf("error should match expected error message (actual: %s)", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecWillReturnError(t *testing.T) {
|
||||
c := &conn{
|
||||
expectations: []expectation{
|
||||
&expectedExec{
|
||||
queryBasedExpectation: queryBasedExpectation{
|
||||
commonExpectation: commonExpectation{
|
||||
err: errors.New("WillReturnError"),
|
||||
},
|
||||
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("query")),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
res, err := c.Exec("query", []driver.Value{123})
|
||||
if res != nil {
|
||||
t.Error("Result should be nil")
|
||||
}
|
||||
if err == nil {
|
||||
t.Error("error should not be nil")
|
||||
}
|
||||
if err.Error() != "WillReturnError" {
|
||||
t.Errorf("error should match expected error message (actual: %s)", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecMissingResult(t *testing.T) {
|
||||
c := &conn{
|
||||
expectations: []expectation{
|
||||
&expectedExec{
|
||||
queryBasedExpectation: queryBasedExpectation{
|
||||
commonExpectation: commonExpectation{},
|
||||
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("query")),
|
||||
args: []driver.Value{123},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
res, err := c.Exec("query", []driver.Value{123})
|
||||
if res != nil {
|
||||
t.Error("Result should be nil")
|
||||
}
|
||||
if err == nil {
|
||||
t.Error("error should not be nil")
|
||||
}
|
||||
pattern := regexp.MustCompile(regexp.QuoteMeta("must return a database/sql/driver.result, but it was not set for expectation"))
|
||||
if !pattern.MatchString(err.Error()) {
|
||||
t.Errorf("error should match expected error message (actual: %s)", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExec(t *testing.T) {
|
||||
expectedResult := driver.Result(&result{})
|
||||
c := &conn{
|
||||
expectations: []expectation{
|
||||
&expectedExec{
|
||||
queryBasedExpectation: queryBasedExpectation{
|
||||
commonExpectation: commonExpectation{},
|
||||
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("query")),
|
||||
args: []driver.Value{123},
|
||||
},
|
||||
result: expectedResult,
|
||||
},
|
||||
},
|
||||
}
|
||||
res, err := c.Exec("query", []driver.Value{123})
|
||||
if res == nil {
|
||||
t.Error("Result should not be nil")
|
||||
}
|
||||
if res != expectedResult {
|
||||
t.Errorf("Result should match expected Result (actual %+v)", res)
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("error should be nil (actual %s)", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryNoExpectations(t *testing.T) {
|
||||
c := &conn{
|
||||
expectations: []expectation{
|
||||
&expectedQuery{
|
||||
queryBasedExpectation: queryBasedExpectation{
|
||||
commonExpectation: commonExpectation{
|
||||
triggered: true,
|
||||
err: errors.New("WillReturnError"),
|
||||
},
|
||||
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("otherquery")),
|
||||
args: []driver.Value{456},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
res, err := c.Query("query", []driver.Value{123})
|
||||
if res != nil {
|
||||
t.Error("Rows should be nil")
|
||||
}
|
||||
if err == nil {
|
||||
t.Error("error should not be nil")
|
||||
}
|
||||
pattern := regexp.MustCompile(regexp.QuoteMeta("all expectations were already fulfilled, call to query"))
|
||||
if !pattern.MatchString(err.Error()) {
|
||||
t.Errorf("error should match expected error message (actual: %s)", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryExpectationMismatch(t *testing.T) {
|
||||
c := &conn{
|
||||
expectations: []expectation{
|
||||
&expectedExec{
|
||||
queryBasedExpectation: queryBasedExpectation{
|
||||
commonExpectation: commonExpectation{
|
||||
err: errors.New("WillReturnError"),
|
||||
},
|
||||
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("otherquery")),
|
||||
args: []driver.Value{456},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
res, err := c.Query("query", []driver.Value{123})
|
||||
if res != nil {
|
||||
t.Error("Rows should be nil")
|
||||
}
|
||||
if err == nil {
|
||||
t.Error("error should not be nil")
|
||||
}
|
||||
pattern := regexp.MustCompile(regexp.QuoteMeta("was not expected, next expectation is"))
|
||||
if !pattern.MatchString(err.Error()) {
|
||||
t.Errorf("error should match expected error message (actual: %s)", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryQueryMismatch(t *testing.T) {
|
||||
c := &conn{
|
||||
expectations: []expectation{
|
||||
&expectedQuery{
|
||||
queryBasedExpectation: queryBasedExpectation{
|
||||
commonExpectation: commonExpectation{
|
||||
err: errors.New("WillReturnError"),
|
||||
},
|
||||
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("otherquery")),
|
||||
args: []driver.Value{456},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
res, err := c.Query("query", []driver.Value{123})
|
||||
if res != nil {
|
||||
t.Error("Rows should be nil")
|
||||
}
|
||||
if err == nil {
|
||||
t.Error("error should not be nil")
|
||||
}
|
||||
pattern := regexp.MustCompile(regexp.QuoteMeta("does not match regex"))
|
||||
if !pattern.MatchString(err.Error()) {
|
||||
t.Errorf("error should match expected error message (actual: %s)", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryArgsMismatch(t *testing.T) {
|
||||
c := &conn{
|
||||
expectations: []expectation{
|
||||
&expectedQuery{
|
||||
queryBasedExpectation: queryBasedExpectation{
|
||||
commonExpectation: commonExpectation{
|
||||
err: errors.New("WillReturnError"),
|
||||
},
|
||||
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("query")),
|
||||
args: []driver.Value{456},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
res, err := c.Query("query", []driver.Value{123})
|
||||
if res != nil {
|
||||
t.Error("Rows should be nil")
|
||||
}
|
||||
if err == nil {
|
||||
t.Error("error should not be nil")
|
||||
}
|
||||
pattern := regexp.MustCompile(regexp.QuoteMeta("does not match expected"))
|
||||
if !pattern.MatchString(err.Error()) {
|
||||
t.Errorf("error should match expected error message (actual: %s)", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryWillReturnError(t *testing.T) {
|
||||
c := &conn{
|
||||
expectations: []expectation{
|
||||
&expectedQuery{
|
||||
queryBasedExpectation: queryBasedExpectation{
|
||||
commonExpectation: commonExpectation{
|
||||
err: errors.New("WillReturnError"),
|
||||
},
|
||||
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("query")),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
res, err := c.Query("query", []driver.Value{123})
|
||||
if res != nil {
|
||||
t.Error("Rows should be nil")
|
||||
}
|
||||
if err == nil {
|
||||
t.Error("error should not be nil")
|
||||
}
|
||||
if err.Error() != "WillReturnError" {
|
||||
t.Errorf("error should match expected error message (actual: %s)", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryMissingRows(t *testing.T) {
|
||||
c := &conn{
|
||||
expectations: []expectation{
|
||||
&expectedQuery{
|
||||
queryBasedExpectation: queryBasedExpectation{
|
||||
commonExpectation: commonExpectation{},
|
||||
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("query")),
|
||||
args: []driver.Value{123},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
res, err := c.Query("query", []driver.Value{123})
|
||||
if res != nil {
|
||||
t.Error("Rows should be nil")
|
||||
}
|
||||
if err == nil {
|
||||
t.Error("error should not be nil")
|
||||
}
|
||||
pattern := regexp.MustCompile(regexp.QuoteMeta("must return a database/sql/driver.rows, but it was not set for expectation"))
|
||||
if !pattern.MatchString(err.Error()) {
|
||||
t.Errorf("error should match expected error message (actual: %s)", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuery(t *testing.T) {
|
||||
expectedRows := driver.Rows(&rows{})
|
||||
c := &conn{
|
||||
expectations: []expectation{
|
||||
&expectedQuery{
|
||||
queryBasedExpectation: queryBasedExpectation{
|
||||
commonExpectation: commonExpectation{},
|
||||
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("query")),
|
||||
args: []driver.Value{123},
|
||||
},
|
||||
rows: expectedRows,
|
||||
},
|
||||
},
|
||||
}
|
||||
rows, err := c.Query("query", []driver.Value{123})
|
||||
if rows == nil {
|
||||
t.Error("Rows should not be nil")
|
||||
}
|
||||
if rows != expectedRows {
|
||||
t.Errorf("Rows should match expected Rows (actual %+v)", rows)
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("error should be nil (actual %s)", err.Error())
|
||||
}
|
||||
}
|
56
driver.go
Normal file
56
driver.go
Normal file
@ -0,0 +1,56 @@
|
||||
package sqlmock
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var pool *mockDriver
|
||||
|
||||
func init() {
|
||||
pool = &mockDriver{
|
||||
conns: make(map[string]*Sqlmock),
|
||||
}
|
||||
sql.Register("sqlmock", pool)
|
||||
}
|
||||
|
||||
type mockDriver struct {
|
||||
sync.Mutex
|
||||
counter int
|
||||
conns map[string]*Sqlmock
|
||||
}
|
||||
|
||||
func (d *mockDriver) Open(dsn string) (driver.Conn, error) {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
|
||||
c, ok := d.conns[dsn]
|
||||
if !ok {
|
||||
return c, fmt.Errorf("expected a connection to be available, but it is not")
|
||||
}
|
||||
|
||||
c.opened++
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// New creates sqlmock database connection
|
||||
// and a mock to manage expectations.
|
||||
// Pings db so that all expectations could be
|
||||
// asserted.
|
||||
func New() (db *sql.DB, mock *Sqlmock, err error) {
|
||||
pool.Lock()
|
||||
dsn := fmt.Sprintf("sqlmock_db_%d", pool.counter)
|
||||
pool.counter++
|
||||
|
||||
mock = &Sqlmock{dsn: dsn, drv: pool, MatchExpectationsInOrder: true}
|
||||
pool.conns[dsn] = mock
|
||||
pool.Unlock()
|
||||
|
||||
db, err = sql.Open("sqlmock", dsn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return db, mock, db.Ping()
|
||||
}
|
83
driver_test.go
Normal file
83
driver_test.go
Normal file
@ -0,0 +1,83 @@
|
||||
package sqlmock
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func ExampleNew() {
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
fmt.Println("expected no error, but got:", err)
|
||||
return
|
||||
}
|
||||
defer db.Close()
|
||||
// now we can expect operations performed on db
|
||||
mock.ExpectBegin().WillReturnError(fmt.Errorf("an error will occur on db.Begin() call"))
|
||||
}
|
||||
|
||||
func TestShouldOpenConnectionIssue15(t *testing.T) {
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but got: %s", err)
|
||||
}
|
||||
if len(pool.conns) != 1 {
|
||||
t.Errorf("expected 1 connection in pool, but there is: %d", len(pool.conns))
|
||||
}
|
||||
|
||||
if mock.opened != 1 {
|
||||
t.Errorf("expected 1 connection on mock to be opened, but there is: %d", mock.opened)
|
||||
}
|
||||
|
||||
// defer so the rows gets closed first
|
||||
defer func() {
|
||||
if mock.opened != 0 {
|
||||
t.Errorf("expected no connections on mock to be opened, but there is: %d", mock.opened)
|
||||
}
|
||||
}()
|
||||
|
||||
mock.ExpectQuery("SELECT").WillReturnRows(NewRows([]string{"one", "two"}).AddRow("val1", "val2"))
|
||||
rows, err := db.Query("SELECT")
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %s", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
mock.ExpectExec("UPDATE").WillReturnResult(NewResult(1, 1))
|
||||
if _, err = db.Exec("UPDATE"); err != nil {
|
||||
t.Errorf("unexpected error: %s", err)
|
||||
}
|
||||
|
||||
// now there should be two connections open
|
||||
if mock.opened != 2 {
|
||||
t.Errorf("expected 2 connection on mock to be opened, but there is: %d", mock.opened)
|
||||
}
|
||||
|
||||
mock.ExpectClose()
|
||||
if err = db.Close(); err != nil {
|
||||
t.Errorf("expected no error on close, but got: %s", err)
|
||||
}
|
||||
|
||||
// one is still reserved for rows
|
||||
if mock.opened != 1 {
|
||||
t.Errorf("expected 1 connection on mock to be still reserved for rows, but there is: %d", mock.opened)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTwoOpenConnectionsOnTheSameDSN(t *testing.T) {
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but got: %s", err)
|
||||
}
|
||||
db2, mock2, err := New()
|
||||
if len(pool.conns) != 2 {
|
||||
t.Errorf("expected 2 connection in pool, but there is: %d", len(pool.conns))
|
||||
}
|
||||
|
||||
if db == db2 {
|
||||
t.Errorf("expected not the same database instance, but it is the same")
|
||||
}
|
||||
if mock == mock2 {
|
||||
t.Errorf("expected not the same mock instance, but it is the same")
|
||||
}
|
||||
}
|
40
examples/basic/basic.go
Normal file
40
examples/basic/basic.go
Normal file
@ -0,0 +1,40 @@
|
||||
package main
|
||||
|
||||
import "database/sql"
|
||||
|
||||
func recordStats(db *sql.DB, userID, productID int64) (err error) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
switch err {
|
||||
case nil:
|
||||
err = tx.Commit()
|
||||
default:
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err = tx.Exec("UPDATE products SET views = views + 1"); err != nil {
|
||||
return
|
||||
}
|
||||
if _, err = tx.Exec("INSERT INTO product_viewers (user_id, product_id) VALUES (?, ?)", userID, productID); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func main() {
|
||||
// @NOTE: the real connection is not required for tests
|
||||
db, err := sql.Open("mysql", "root@/blog")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
if err = recordStats(db, 1 /*some user id*/, 5 /*some product id*/); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
58
examples/basic/basic_test.go
Normal file
58
examples/basic/basic_test.go
Normal file
@ -0,0 +1,58 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
)
|
||||
|
||||
// a successful case
|
||||
func TestShouldUpdateStats(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE products").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectExec("INSERT INTO product_viewers").WithArgs(2, 3).WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
// now we execute our method
|
||||
if err = recordStats(db, 2, 3); err != nil {
|
||||
t.Errorf("error was not expected while updating stats: %s", err)
|
||||
}
|
||||
|
||||
// we make sure that all expectations were met
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("there were unfulfilled expections: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// a failing test case
|
||||
func TestShouldRollbackStatUpdatesOnFailure(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE products").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectExec("INSERT INTO product_viewers").
|
||||
WithArgs(2, 3).
|
||||
WillReturnError(fmt.Errorf("some error"))
|
||||
mock.ExpectRollback()
|
||||
|
||||
// now we execute our method
|
||||
if err = recordStats(db, 2, 3); err == nil {
|
||||
t.Errorf("was expecting an error, but there was none")
|
||||
}
|
||||
|
||||
// we make sure that all expectations were met
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("there were unfulfilled expections: %s", err)
|
||||
}
|
||||
}
|
81
examples/blog/blog.go
Normal file
81
examples/blog/blog.go
Normal file
@ -0,0 +1,81 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type api struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
type post struct {
|
||||
ID int
|
||||
Title string
|
||||
Body string
|
||||
}
|
||||
|
||||
func (a *api) posts(w http.ResponseWriter, r *http.Request) {
|
||||
rows, err := a.db.Query("SELECT id, title, body FROM posts")
|
||||
if err != nil {
|
||||
a.fail(w, "failed to fetch posts: "+err.Error(), 500)
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var posts []*post
|
||||
for rows.Next() {
|
||||
p := &post{}
|
||||
if err := rows.Scan(&p.ID, &p.Title, &p.Body); err != nil {
|
||||
a.fail(w, "failed to scan post: "+err.Error(), 500)
|
||||
return
|
||||
}
|
||||
posts = append(posts, p)
|
||||
}
|
||||
if rows.Err() != nil {
|
||||
a.fail(w, "failed to read all posts: "+rows.Err().Error(), 500)
|
||||
return
|
||||
}
|
||||
|
||||
data := struct {
|
||||
Posts []*post
|
||||
}{posts}
|
||||
|
||||
a.ok(w, data)
|
||||
}
|
||||
|
||||
func main() {
|
||||
// @NOTE: the real connection is not required for tests
|
||||
db, err := sql.Open("mysql", "root@/blog")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
app := &api{db: db}
|
||||
http.HandleFunc("/posts", app.posts)
|
||||
http.ListenAndServe(":8080", nil)
|
||||
}
|
||||
|
||||
func (a *api) fail(w http.ResponseWriter, msg string, status int) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
data := struct {
|
||||
Error string
|
||||
}{Error: msg}
|
||||
|
||||
resp, _ := json.Marshal(data)
|
||||
w.WriteHeader(status)
|
||||
w.Write(resp)
|
||||
}
|
||||
|
||||
func (a *api) ok(w http.ResponseWriter, data interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
a.fail(w, "oops something evil has happened", 500)
|
||||
return
|
||||
}
|
||||
w.Write(resp)
|
||||
}
|
102
examples/blog/blog_test.go
Normal file
102
examples/blog/blog_test.go
Normal file
@ -0,0 +1,102 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
)
|
||||
|
||||
func (a *api) assertJSON(actual []byte, data interface{}, t *testing.T) {
|
||||
expected, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
t.Fatalf("an error '%s' was not expected when marshaling expected json data", err)
|
||||
}
|
||||
|
||||
if bytes.Compare(expected, actual) != 0 {
|
||||
t.Errorf("the expected json: %s is different from actual %s", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldGetPosts(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// create app with mocked db, request and response to test
|
||||
app := &api{db}
|
||||
req, err := http.NewRequest("GET", "http://localhost/posts", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("an error '%s' was not expected while creating request", err)
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// before we actually execute our api function, we need to expect required DB actions
|
||||
rows := sqlmock.NewRows([]string{"id", "title", "body"}).
|
||||
AddRow(1, "post 1", "hello").
|
||||
AddRow(2, "post 2", "world")
|
||||
|
||||
mock.ExpectQuery("^SELECT (.+) FROM posts$").WillReturnRows(rows)
|
||||
|
||||
// now we execute our request
|
||||
app.posts(w, req)
|
||||
|
||||
if w.Code != 200 {
|
||||
t.Fatalf("expected status code to be 200, but got: %d", w.Code)
|
||||
}
|
||||
|
||||
data := struct {
|
||||
Posts []*post
|
||||
}{Posts: []*post{
|
||||
{ID: 1, Title: "post 1", Body: "hello"},
|
||||
{ID: 2, Title: "post 2", Body: "world"},
|
||||
}}
|
||||
app.assertJSON(w.Body.Bytes(), data, t)
|
||||
|
||||
// we make sure that all expectations were met
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("there were unfulfilled expections: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldRespondWithErrorOnFailure(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// create app with mocked db, request and response to test
|
||||
app := &api{db}
|
||||
req, err := http.NewRequest("GET", "http://localhost/posts", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("an error '%s' was not expected while creating request", err)
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// before we actually execute our api function, we need to expect required DB actions
|
||||
mock.ExpectQuery("^SELECT (.+) FROM posts$").WillReturnError(fmt.Errorf("some error"))
|
||||
|
||||
// now we execute our request
|
||||
app.posts(w, req)
|
||||
|
||||
if w.Code != 500 {
|
||||
t.Fatalf("expected status code to be 500, but got: %d", w.Code)
|
||||
}
|
||||
|
||||
data := struct {
|
||||
Error string
|
||||
}{"failed to fetch posts: some error"}
|
||||
app.assertJSON(w.Body.Bytes(), data, t)
|
||||
|
||||
// we make sure that all expectations were met
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("there were unfulfilled expections: %s", err)
|
||||
}
|
||||
}
|
1
examples/doc.go
Normal file
1
examples/doc.go
Normal file
@ -0,0 +1 @@
|
||||
package examples
|
121
examples/orders/orders.go
Normal file
121
examples/orders/orders.go
Normal file
@ -0,0 +1,121 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/kisielk/sqlstruct"
|
||||
)
|
||||
|
||||
const ORDER_PENDING = 0
|
||||
const ORDER_CANCELLED = 1
|
||||
|
||||
type User struct {
|
||||
Id int `sql:"id"`
|
||||
Username string `sql:"username"`
|
||||
Balance float64 `sql:"balance"`
|
||||
}
|
||||
|
||||
type Order struct {
|
||||
Id int `sql:"id"`
|
||||
Value float64 `sql:"value"`
|
||||
ReservedFee float64 `sql:"reserved_fee"`
|
||||
Status int `sql:"status"`
|
||||
}
|
||||
|
||||
func cancelOrder(id int, db *sql.DB) (err error) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var order Order
|
||||
var user User
|
||||
sql := fmt.Sprintf(`
|
||||
SELECT %s, %s
|
||||
FROM orders AS o
|
||||
INNER JOIN users AS u ON o.buyer_id = u.id
|
||||
WHERE o.id = ?
|
||||
FOR UPDATE`,
|
||||
sqlstruct.ColumnsAliased(order, "o"),
|
||||
sqlstruct.ColumnsAliased(user, "u"))
|
||||
|
||||
// fetch order to cancel
|
||||
rows, err := tx.Query(sql, id)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
// no rows, nothing to do
|
||||
if !rows.Next() {
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
|
||||
// read order
|
||||
err = sqlstruct.ScanAliased(&order, rows, "o")
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
|
||||
// ensure order status
|
||||
if order.Status != ORDER_PENDING {
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
|
||||
// read user
|
||||
err = sqlstruct.ScanAliased(&user, rows, "u")
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
rows.Close() // manually close before other prepared statements
|
||||
|
||||
// refund order value
|
||||
sql = "UPDATE users SET balance = balance + ? WHERE id = ?"
|
||||
refundStmt, err := tx.Prepare(sql)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
defer refundStmt.Close()
|
||||
_, err = refundStmt.Exec(order.Value+order.ReservedFee, user.Id)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
|
||||
// update order status
|
||||
order.Status = ORDER_CANCELLED
|
||||
sql = "UPDATE orders SET status = ?, updated = NOW() WHERE id = ?"
|
||||
orderUpdStmt, err := tx.Prepare(sql)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
defer orderUpdStmt.Close()
|
||||
_, err = orderUpdStmt.Exec(order.Status, order.Id)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func main() {
|
||||
// @NOTE: the real connection is not required for tests
|
||||
db, err := sql.Open("mysql", "root:@/orders")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
err = cancelOrder(1, db)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
108
examples/orders/orders_test.go
Normal file
108
examples/orders/orders_test.go
Normal file
@ -0,0 +1,108 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
)
|
||||
|
||||
// will test that order with a different status, cannot be cancelled
|
||||
func TestShouldNotCancelOrderWithNonPendingStatus(t *testing.T) {
|
||||
// open database stub
|
||||
db, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Errorf("An error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// columns are prefixed with "o" since we used sqlstruct to generate them
|
||||
columns := []string{"o_id", "o_status"}
|
||||
// expect transaction begin
|
||||
mock.ExpectBegin()
|
||||
// expect query to fetch order and user, match it with regexp
|
||||
mock.ExpectQuery("SELECT (.+) FROM orders AS o INNER JOIN users AS u (.+) FOR UPDATE").
|
||||
WithArgs(1).
|
||||
WillReturnRows(sqlmock.NewRows(columns).FromCSVString("1,1"))
|
||||
// expect transaction rollback, since order status is "cancelled"
|
||||
mock.ExpectRollback()
|
||||
|
||||
// run the cancel order function
|
||||
err = cancelOrder(1, db)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, but got %s instead", err)
|
||||
}
|
||||
// we make sure that all expectations were met
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("there were unfulfilled expections: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// will test order cancellation
|
||||
func TestShouldRefundUserWhenOrderIsCancelled(t *testing.T) {
|
||||
// open database stub
|
||||
db, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Errorf("An error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// columns are prefixed with "o" since we used sqlstruct to generate them
|
||||
columns := []string{"o_id", "o_status", "o_value", "o_reserved_fee", "u_id", "u_balance"}
|
||||
// expect transaction begin
|
||||
mock.ExpectBegin()
|
||||
// expect query to fetch order and user, match it with regexp
|
||||
mock.ExpectQuery("SELECT (.+) FROM orders AS o INNER JOIN users AS u (.+) FOR UPDATE").
|
||||
WithArgs(1).
|
||||
WillReturnRows(sqlmock.NewRows(columns).AddRow(1, 0, 25.75, 3.25, 2, 10.00))
|
||||
// expect user balance update
|
||||
mock.ExpectPrepare("UPDATE users SET balance").ExpectExec().
|
||||
WithArgs(25.75+3.25, 2). // refund amount, user id
|
||||
WillReturnResult(sqlmock.NewResult(0, 1)) // no insert id, 1 affected row
|
||||
// expect order status update
|
||||
mock.ExpectPrepare("UPDATE orders SET status").ExpectExec().
|
||||
WithArgs(ORDER_CANCELLED, 1). // status, id
|
||||
WillReturnResult(sqlmock.NewResult(0, 1)) // no insert id, 1 affected row
|
||||
// expect a transaction commit
|
||||
mock.ExpectCommit()
|
||||
|
||||
// run the cancel order function
|
||||
err = cancelOrder(1, db)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, but got %s instead", err)
|
||||
}
|
||||
// we make sure that all expectations were met
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("there were unfulfilled expections: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// will test order cancellation
|
||||
func TestShouldRollbackOnError(t *testing.T) {
|
||||
// open database stub
|
||||
db, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Errorf("An error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// expect transaction begin
|
||||
mock.ExpectBegin()
|
||||
// expect query to fetch order and user, match it with regexp
|
||||
mock.ExpectQuery("SELECT (.+) FROM orders AS o INNER JOIN users AS u (.+) FOR UPDATE").
|
||||
WithArgs(1).
|
||||
WillReturnError(fmt.Errorf("Some error"))
|
||||
// should rollback since error was returned from query execution
|
||||
mock.ExpectRollback()
|
||||
|
||||
// run the cancel order function
|
||||
err = cancelOrder(1, db)
|
||||
// error should return back
|
||||
if err == nil {
|
||||
t.Error("Expected error, but got none")
|
||||
}
|
||||
// we make sure that all expectations were met
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("there were unfulfilled expections: %s", err)
|
||||
}
|
||||
}
|
321
expectations.go
321
expectations.go
@ -2,12 +2,16 @@ package sqlmock
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Argument interface allows to match
|
||||
// any argument in specific way
|
||||
// any argument in specific way when used with
|
||||
// ExpectedQuery and ExpectedExec expectations.
|
||||
type Argument interface {
|
||||
Match(driver.Value) bool
|
||||
}
|
||||
@ -15,12 +19,15 @@ type Argument interface {
|
||||
// an expectation interface
|
||||
type expectation interface {
|
||||
fulfilled() bool
|
||||
setError(err error)
|
||||
Lock()
|
||||
Unlock()
|
||||
String() string
|
||||
}
|
||||
|
||||
// common expectation struct
|
||||
// satisfies the expectation interface
|
||||
type commonExpectation struct {
|
||||
sync.Mutex
|
||||
triggered bool
|
||||
err error
|
||||
}
|
||||
@ -29,8 +36,267 @@ func (e *commonExpectation) fulfilled() bool {
|
||||
return e.triggered
|
||||
}
|
||||
|
||||
func (e *commonExpectation) setError(err error) {
|
||||
// ExpectedClose is used to manage *sql.DB.Close expectation
|
||||
// returned by *Sqlmock.ExpectClose.
|
||||
type ExpectedClose struct {
|
||||
commonExpectation
|
||||
}
|
||||
|
||||
// WillReturnError allows to set an error for *sql.DB.Close action
|
||||
func (e *ExpectedClose) WillReturnError(err error) *ExpectedClose {
|
||||
e.err = err
|
||||
return e
|
||||
}
|
||||
|
||||
// String returns string representation
|
||||
func (e *ExpectedClose) String() string {
|
||||
msg := "ExpectedClose => expecting database Close"
|
||||
if e.err != nil {
|
||||
msg += fmt.Sprintf(", which should return error: %s", e.err)
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
// ExpectedBegin is used to manage *sql.DB.Begin expectation
|
||||
// returned by *Sqlmock.ExpectBegin.
|
||||
type ExpectedBegin struct {
|
||||
commonExpectation
|
||||
}
|
||||
|
||||
// WillReturnError allows to set an error for *sql.DB.Begin action
|
||||
func (e *ExpectedBegin) WillReturnError(err error) *ExpectedBegin {
|
||||
e.err = err
|
||||
return e
|
||||
}
|
||||
|
||||
// String returns string representation
|
||||
func (e *ExpectedBegin) String() string {
|
||||
msg := "ExpectedBegin => expecting database transaction Begin"
|
||||
if e.err != nil {
|
||||
msg += fmt.Sprintf(", which should return error: %s", e.err)
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
// ExpectedCommit is used to manage *sql.Tx.Commit expectation
|
||||
// returned by *Sqlmock.ExpectCommit.
|
||||
type ExpectedCommit struct {
|
||||
commonExpectation
|
||||
}
|
||||
|
||||
// WillReturnError allows to set an error for *sql.Tx.Close action
|
||||
func (e *ExpectedCommit) WillReturnError(err error) *ExpectedCommit {
|
||||
e.err = err
|
||||
return e
|
||||
}
|
||||
|
||||
// String returns string representation
|
||||
func (e *ExpectedCommit) String() string {
|
||||
msg := "ExpectedCommit => expecting transaction Commit"
|
||||
if e.err != nil {
|
||||
msg += fmt.Sprintf(", which should return error: %s", e.err)
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
// ExpectedRollback is used to manage *sql.Tx.Rollback expectation
|
||||
// returned by *Sqlmock.ExpectRollback.
|
||||
type ExpectedRollback struct {
|
||||
commonExpectation
|
||||
}
|
||||
|
||||
// WillReturnError allows to set an error for *sql.Tx.Rollback action
|
||||
func (e *ExpectedRollback) WillReturnError(err error) *ExpectedRollback {
|
||||
e.err = err
|
||||
return e
|
||||
}
|
||||
|
||||
// String returns string representation
|
||||
func (e *ExpectedRollback) String() string {
|
||||
msg := "ExpectedRollback => expecting transaction Rollback"
|
||||
if e.err != nil {
|
||||
msg += fmt.Sprintf(", which should return error: %s", e.err)
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
// ExpectedQuery is used to manage *sql.DB.Query, *dql.DB.QueryRow, *sql.Tx.Query,
|
||||
// *sql.Tx.QueryRow, *sql.Stmt.Query or *sql.Stmt.QueryRow expectations.
|
||||
// Returned by *Sqlmock.ExpectQuery.
|
||||
type ExpectedQuery struct {
|
||||
queryBasedExpectation
|
||||
rows driver.Rows
|
||||
}
|
||||
|
||||
// WithArgs will match given expected args to actual database query arguments.
|
||||
// if at least one argument does not match, it will return an error. For specific
|
||||
// arguments an sqlmock.Argument interface can be used to match an argument.
|
||||
func (e *ExpectedQuery) WithArgs(args ...driver.Value) *ExpectedQuery {
|
||||
e.args = args
|
||||
return e
|
||||
}
|
||||
|
||||
// WillReturnError allows to set an error for expected database query
|
||||
func (e *ExpectedQuery) WillReturnError(err error) *ExpectedQuery {
|
||||
e.err = err
|
||||
return e
|
||||
}
|
||||
|
||||
// WillReturnRows specifies the set of resulting rows that will be returned
|
||||
// by the triggered query
|
||||
func (e *ExpectedQuery) WillReturnRows(rows driver.Rows) *ExpectedQuery {
|
||||
e.rows = rows
|
||||
return e
|
||||
}
|
||||
|
||||
// String returns string representation
|
||||
func (e *ExpectedQuery) String() string {
|
||||
msg := "ExpectedQuery => expecting Query or QueryRow which:"
|
||||
msg += "\n - matches sql: '" + e.sqlRegex.String() + "'"
|
||||
|
||||
if len(e.args) == 0 {
|
||||
msg += "\n - is without arguments"
|
||||
} else {
|
||||
msg += "\n - is with arguments:\n"
|
||||
for i, arg := range e.args {
|
||||
msg += fmt.Sprintf(" %d - %+v\n", i, arg)
|
||||
}
|
||||
msg = strings.TrimSpace(msg)
|
||||
}
|
||||
|
||||
if e.rows != nil {
|
||||
msg += "\n - should return rows:\n"
|
||||
rs, _ := e.rows.(*rows)
|
||||
for i, row := range rs.rows {
|
||||
msg += fmt.Sprintf(" %d - %+v\n", i, row)
|
||||
}
|
||||
msg = strings.TrimSpace(msg)
|
||||
}
|
||||
|
||||
if e.err != nil {
|
||||
msg += fmt.Sprintf("\n - should return error: %s", e.err)
|
||||
}
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
// ExpectedExec is used to manage *sql.DB.Exec, *sql.Tx.Exec or *sql.Stmt.Exec expectations.
|
||||
// Returned by *Sqlmock.ExpectExec.
|
||||
type ExpectedExec struct {
|
||||
queryBasedExpectation
|
||||
result driver.Result
|
||||
}
|
||||
|
||||
// WithArgs will match given expected args to actual database exec operation arguments.
|
||||
// if at least one argument does not match, it will return an error. For specific
|
||||
// arguments an sqlmock.Argument interface can be used to match an argument.
|
||||
func (e *ExpectedExec) WithArgs(args ...driver.Value) *ExpectedExec {
|
||||
e.args = args
|
||||
return e
|
||||
}
|
||||
|
||||
// WillReturnError allows to set an error for expected database exec action
|
||||
func (e *ExpectedExec) WillReturnError(err error) *ExpectedExec {
|
||||
e.err = err
|
||||
return e
|
||||
}
|
||||
|
||||
// String returns string representation
|
||||
func (e *ExpectedExec) String() string {
|
||||
msg := "ExpectedExec => expecting Exec which:"
|
||||
msg += "\n - matches sql: '" + e.sqlRegex.String() + "'"
|
||||
|
||||
if len(e.args) == 0 {
|
||||
msg += "\n - is without arguments"
|
||||
} else {
|
||||
msg += "\n - is with arguments:\n"
|
||||
var margs []string
|
||||
for i, arg := range e.args {
|
||||
margs = append(margs, fmt.Sprintf(" %d - %+v", i, arg))
|
||||
}
|
||||
msg += strings.Join(margs, "\n")
|
||||
}
|
||||
|
||||
if e.result != nil {
|
||||
res, _ := e.result.(*result)
|
||||
msg += "\n - should return Result having:"
|
||||
msg += fmt.Sprintf("\n LastInsertId: %d", res.insertID)
|
||||
msg += fmt.Sprintf("\n RowsAffected: %d", res.rowsAffected)
|
||||
if res.err != nil {
|
||||
msg += fmt.Sprintf("\n Error: %s", res.err)
|
||||
}
|
||||
}
|
||||
|
||||
if e.err != nil {
|
||||
msg += fmt.Sprintf("\n - should return error: %s", e.err)
|
||||
}
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
// WillReturnResult arranges for an expected Exec() to return a particular
|
||||
// result, there is sqlmock.NewResult(lastInsertID int64, affectedRows int64) method
|
||||
// to build a corresponding result. Or if actions needs to be tested against errors
|
||||
// sqlmock.NewErrorResult(err error) to return a given error.
|
||||
func (e *ExpectedExec) WillReturnResult(result driver.Result) *ExpectedExec {
|
||||
e.result = result
|
||||
return e
|
||||
}
|
||||
|
||||
// ExpectedPrepare is used to manage *sql.DB.Prepare or *sql.Tx.Prepare expectations.
|
||||
// Returned by *Sqlmock.ExpectPrepare.
|
||||
type ExpectedPrepare struct {
|
||||
commonExpectation
|
||||
mock *Sqlmock
|
||||
sqlRegex *regexp.Regexp
|
||||
statement driver.Stmt
|
||||
closeErr error
|
||||
}
|
||||
|
||||
// WillReturnError allows to set an error for the expected *sql.DB.Prepare or *sql.Tx.Prepare action.
|
||||
func (e *ExpectedPrepare) WillReturnError(err error) *ExpectedPrepare {
|
||||
e.err = err
|
||||
return e
|
||||
}
|
||||
|
||||
// WillReturnCloseError allows to set an error for this prapared statement Close action
|
||||
func (e *ExpectedPrepare) WillReturnCloseError(err error) *ExpectedPrepare {
|
||||
e.closeErr = err
|
||||
return e
|
||||
}
|
||||
|
||||
// ExpectQuery allows to expect Query() or QueryRow() on this prepared statement.
|
||||
// this method is convenient in order to prevent duplicating sql query string matching.
|
||||
func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery {
|
||||
eq := &ExpectedQuery{}
|
||||
eq.sqlRegex = e.sqlRegex
|
||||
e.mock.expected = append(e.mock.expected, eq)
|
||||
return eq
|
||||
}
|
||||
|
||||
// ExpectExec allows to expect Exec() on this prepared statement.
|
||||
// this method is convenient in order to prevent duplicating sql query string matching.
|
||||
func (e *ExpectedPrepare) ExpectExec() *ExpectedExec {
|
||||
eq := &ExpectedExec{}
|
||||
eq.sqlRegex = e.sqlRegex
|
||||
e.mock.expected = append(e.mock.expected, eq)
|
||||
return eq
|
||||
}
|
||||
|
||||
// String returns string representation
|
||||
func (e *ExpectedPrepare) String() string {
|
||||
msg := "ExpectedPrepare => expecting Prepare statement which:"
|
||||
msg += "\n - matches sql: '" + e.sqlRegex.String() + "'"
|
||||
|
||||
if e.err != nil {
|
||||
msg += fmt.Sprintf("\n - should return error: %s", e.err)
|
||||
}
|
||||
|
||||
if e.closeErr != nil {
|
||||
msg += fmt.Sprintf("\n - should return error on Close: %s", e.closeErr)
|
||||
}
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
// query based expectation
|
||||
@ -41,6 +307,19 @@ type queryBasedExpectation struct {
|
||||
args []driver.Value
|
||||
}
|
||||
|
||||
func (e *queryBasedExpectation) attemptMatch(sql string, args []driver.Value) (ret bool) {
|
||||
if !e.queryMatches(sql) {
|
||||
return
|
||||
}
|
||||
|
||||
defer recover() // ignore panic since we attempt a match
|
||||
|
||||
if e.argsMatches(args) {
|
||||
return true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (e *queryBasedExpectation) queryMatches(sql string) bool {
|
||||
return e.sqlRegex.MatchString(sql)
|
||||
}
|
||||
@ -88,39 +367,3 @@ func (e *queryBasedExpectation) argsMatches(args []driver.Value) bool {
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// begin transaction
|
||||
type expectedBegin struct {
|
||||
commonExpectation
|
||||
}
|
||||
|
||||
// tx commit
|
||||
type expectedCommit struct {
|
||||
commonExpectation
|
||||
}
|
||||
|
||||
// tx rollback
|
||||
type expectedRollback struct {
|
||||
commonExpectation
|
||||
}
|
||||
|
||||
// query expectation
|
||||
type expectedQuery struct {
|
||||
queryBasedExpectation
|
||||
|
||||
rows driver.Rows
|
||||
}
|
||||
|
||||
// exec query expectation
|
||||
type expectedExec struct {
|
||||
queryBasedExpectation
|
||||
|
||||
result driver.Result
|
||||
}
|
||||
|
||||
// Prepare expectation
|
||||
type expectedPrepare struct {
|
||||
commonExpectation
|
||||
|
||||
statement driver.Stmt
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package sqlmock
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
@ -60,7 +61,7 @@ func TestQueryExpectationArgComparison(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestQueryExpectationSqlMatch(t *testing.T) {
|
||||
e := &expectedExec{}
|
||||
e := &ExpectedExec{}
|
||||
e.sqlRegex = regexp.MustCompile("SELECT x FROM")
|
||||
if !e.queryMatches("SELECT x FROM someting") {
|
||||
t.Errorf("Sql must have matched the query")
|
||||
@ -71,3 +72,13 @@ func TestQueryExpectationSqlMatch(t *testing.T) {
|
||||
t.Errorf("Sql must have matched the query")
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleExpectExec() {
|
||||
db, mock, _ := New()
|
||||
result := NewErrorResult(fmt.Errorf("some error"))
|
||||
mock.ExpectExec("^INSERT (.+)").WillReturnResult(result)
|
||||
res, _ := db.Exec("INSERT something")
|
||||
_, err := res.LastInsertId()
|
||||
fmt.Println(err)
|
||||
// Output: some error
|
||||
}
|
||||
|
@ -5,6 +5,32 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// used for examples
|
||||
var mock = &Sqlmock{}
|
||||
|
||||
func ExampleNewErrorResult() {
|
||||
db, mock, _ := New()
|
||||
result := NewErrorResult(fmt.Errorf("some error"))
|
||||
mock.ExpectExec("^INSERT (.+)").WillReturnResult(result)
|
||||
res, _ := db.Exec("INSERT something")
|
||||
_, err := res.LastInsertId()
|
||||
fmt.Println(err)
|
||||
// Output: some error
|
||||
}
|
||||
|
||||
func ExampleNewResult() {
|
||||
var lastInsertID, affected int64
|
||||
result := NewResult(lastInsertID, affected)
|
||||
mock.ExpectExec("^INSERT (.+)").WillReturnResult(result)
|
||||
fmt.Println(mock.ExpectationsWereMet())
|
||||
// Output: there is a remaining expectation which was not matched: ExpectedExec => expecting Exec which:
|
||||
// - matches sql: '^INSERT (.+)'
|
||||
// - is without arguments
|
||||
// - should return Result having:
|
||||
// LastInsertId: 0
|
||||
// RowsAffected: 0
|
||||
}
|
||||
|
||||
func TestShouldReturnValidSqlDriverResult(t *testing.T) {
|
||||
result := NewResult(1, 2)
|
||||
id, err := result.LastInsertId()
|
||||
|
106
rows.go
106
rows.go
@ -7,19 +7,56 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// CSVColumnParser is a function which converts trimmed csv
|
||||
// column string to a []byte representation. currently
|
||||
// transforms NULL to nil
|
||||
var CSVColumnParser = func(s string) []byte {
|
||||
switch {
|
||||
case strings.ToLower(s) == "null":
|
||||
return nil
|
||||
}
|
||||
return []byte(s)
|
||||
}
|
||||
|
||||
// Rows interface allows to construct rows
|
||||
// which also satisfies database/sql/driver.Rows interface
|
||||
type Rows interface {
|
||||
driver.Rows // composed interface, supports sql driver.Rows
|
||||
AddRow(...driver.Value) Rows
|
||||
// composed interface, supports sql driver.Rows
|
||||
driver.Rows
|
||||
|
||||
// AddRow composed from database driver.Value slice
|
||||
// return the same instance to perform subsequent actions.
|
||||
// Note that the number of values must match the number
|
||||
// of columns
|
||||
AddRow(columns ...driver.Value) Rows
|
||||
|
||||
// FromCSVString build rows from csv string.
|
||||
// return the same instance to perform subsequent actions.
|
||||
// Note that the number of values must match the number
|
||||
// of columns
|
||||
FromCSVString(s string) Rows
|
||||
|
||||
// RowError allows to set an error
|
||||
// which will be returned when a given
|
||||
// row number is read
|
||||
RowError(row int, err error) Rows
|
||||
|
||||
// CloseError allows to set an error
|
||||
// which will be returned by rows.Close
|
||||
// function.
|
||||
//
|
||||
// The close error will be triggered only in cases
|
||||
// when rows.Next() EOF was not yet reached, that is
|
||||
// a default sql library behavior
|
||||
CloseError(err error) Rows
|
||||
}
|
||||
|
||||
// a struct which implements database/sql/driver.Rows
|
||||
type rows struct {
|
||||
cols []string
|
||||
rows [][]driver.Value
|
||||
pos int
|
||||
cols []string
|
||||
rows [][]driver.Value
|
||||
pos int
|
||||
nextErr map[int]error
|
||||
closeErr error
|
||||
}
|
||||
|
||||
func (r *rows) Columns() []string {
|
||||
@ -27,11 +64,7 @@ func (r *rows) Columns() []string {
|
||||
}
|
||||
|
||||
func (r *rows) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *rows) Err() error {
|
||||
return nil
|
||||
return r.closeErr
|
||||
}
|
||||
|
||||
// advances to next row
|
||||
@ -45,19 +78,26 @@ func (r *rows) Next(dest []driver.Value) error {
|
||||
dest[i] = col
|
||||
}
|
||||
|
||||
return nil
|
||||
return r.nextErr[r.pos-1]
|
||||
}
|
||||
|
||||
// NewRows allows Rows to be created from a group of
|
||||
// sql driver.Value or from the CSV string and
|
||||
// NewRows allows Rows to be created from a
|
||||
// sql driver.Value slice or from the CSV string and
|
||||
// to be used as sql driver.Rows
|
||||
func NewRows(columns []string) Rows {
|
||||
return &rows{cols: columns}
|
||||
return &rows{cols: columns, nextErr: make(map[int]error)}
|
||||
}
|
||||
|
||||
func (r *rows) CloseError(err error) Rows {
|
||||
r.closeErr = err
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *rows) RowError(row int, err error) Rows {
|
||||
r.nextErr[row] = err
|
||||
return r
|
||||
}
|
||||
|
||||
// AddRow adds a row which is built from arguments
|
||||
// in the same column order, returns sql driver.Rows
|
||||
// compatible interface
|
||||
func (r *rows) AddRow(values ...driver.Value) Rows {
|
||||
if len(values) != len(r.cols) {
|
||||
panic("Expected number of values to match number of columns")
|
||||
@ -72,8 +112,6 @@ func (r *rows) AddRow(values ...driver.Value) Rows {
|
||||
return r
|
||||
}
|
||||
|
||||
// FromCSVString adds rows from CSV string.
|
||||
// Returns sql driver.Rows compatible interface
|
||||
func (r *rows) FromCSVString(s string) Rows {
|
||||
res := strings.NewReader(strings.TrimSpace(s))
|
||||
csvReader := csv.NewReader(res)
|
||||
@ -86,35 +124,9 @@ func (r *rows) FromCSVString(s string) Rows {
|
||||
|
||||
row := make([]driver.Value, len(r.cols))
|
||||
for i, v := range res {
|
||||
row[i] = []byte(strings.TrimSpace(v))
|
||||
row[i] = CSVColumnParser(strings.TrimSpace(v))
|
||||
}
|
||||
r.rows = append(r.rows, row)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// RowsFromCSVString creates Rows from CSV string
|
||||
// to be used for mocked queries. Returns sql driver Rows interface
|
||||
// ** DEPRECATED ** will be removed in the future, use Rows.FromCSVString
|
||||
func RowsFromCSVString(columns []string, s string) driver.Rows {
|
||||
rs := &rows{}
|
||||
rs.cols = columns
|
||||
|
||||
r := strings.NewReader(strings.TrimSpace(s))
|
||||
csvReader := csv.NewReader(r)
|
||||
|
||||
for {
|
||||
r, err := csvReader.Read()
|
||||
if err != nil || r == nil {
|
||||
break
|
||||
}
|
||||
|
||||
row := make([]driver.Value, len(columns))
|
||||
for i, v := range r {
|
||||
v := strings.TrimSpace(v)
|
||||
row[i] = []byte(v)
|
||||
}
|
||||
rs.rows = append(rs.rows, row)
|
||||
}
|
||||
return rs
|
||||
}
|
||||
|
248
rows_test.go
Normal file
248
rows_test.go
Normal file
@ -0,0 +1,248 @@
|
||||
package sqlmock
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func ExampleRows() {
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
fmt.Println("failed to open sqlmock database:", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
rows := NewRows([]string{"id", "title"}).
|
||||
AddRow(1, "one").
|
||||
AddRow(2, "two")
|
||||
|
||||
mock.ExpectQuery("SELECT").WillReturnRows(rows)
|
||||
|
||||
rs, _ := db.Query("SELECT")
|
||||
defer rs.Close()
|
||||
|
||||
for rs.Next() {
|
||||
var id int
|
||||
var title string
|
||||
rs.Scan(&id, &title)
|
||||
fmt.Println("scanned id:", id, "and title:", title)
|
||||
}
|
||||
|
||||
if rs.Err() != nil {
|
||||
fmt.Println("got rows error:", rs.Err())
|
||||
}
|
||||
// Output: scanned id: 1 and title: one
|
||||
// scanned id: 2 and title: two
|
||||
}
|
||||
|
||||
func ExampleRows_rowError() {
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
fmt.Println("failed to open sqlmock database:", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
rows := NewRows([]string{"id", "title"}).
|
||||
AddRow(0, "one").
|
||||
AddRow(1, "two").
|
||||
RowError(1, fmt.Errorf("row error"))
|
||||
mock.ExpectQuery("SELECT").WillReturnRows(rows)
|
||||
|
||||
rs, _ := db.Query("SELECT")
|
||||
defer rs.Close()
|
||||
|
||||
for rs.Next() {
|
||||
var id int
|
||||
var title string
|
||||
rs.Scan(&id, &title)
|
||||
fmt.Println("scanned id:", id, "and title:", title)
|
||||
}
|
||||
|
||||
if rs.Err() != nil {
|
||||
fmt.Println("got rows error:", rs.Err())
|
||||
}
|
||||
// Output: scanned id: 0 and title: one
|
||||
// got rows error: row error
|
||||
}
|
||||
|
||||
func ExampleRows_closeError() {
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
fmt.Println("failed to open sqlmock database:", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
rows := NewRows([]string{"id", "title"}).CloseError(fmt.Errorf("close error"))
|
||||
mock.ExpectQuery("SELECT").WillReturnRows(rows)
|
||||
|
||||
rs, _ := db.Query("SELECT")
|
||||
|
||||
// Note: that close will return error only before rows EOF
|
||||
// that is a default sql package behavior. If you run rs.Next()
|
||||
// it will handle the error internally and return nil bellow
|
||||
if err := rs.Close(); err != nil {
|
||||
fmt.Println("got error:", err)
|
||||
}
|
||||
|
||||
// Output: got error: close error
|
||||
}
|
||||
|
||||
func TestAllowsToSetRowsErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
rows := NewRows([]string{"id", "title"}).
|
||||
AddRow(0, "one").
|
||||
AddRow(1, "two").
|
||||
RowError(1, fmt.Errorf("error"))
|
||||
mock.ExpectQuery("SELECT").WillReturnRows(rows)
|
||||
|
||||
rs, err := db.Query("SELECT")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
defer rs.Close()
|
||||
|
||||
if !rs.Next() {
|
||||
t.Fatal("expected the first row to be available")
|
||||
}
|
||||
if rs.Err() != nil {
|
||||
t.Fatalf("unexpected error: %s", rs.Err())
|
||||
}
|
||||
|
||||
if rs.Next() {
|
||||
t.Fatal("was not expecting the second row, since there should be an error")
|
||||
}
|
||||
if rs.Err() == nil {
|
||||
t.Fatal("expected an error, but got none")
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRowsCloseError(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
rows := NewRows([]string{"id"}).CloseError(fmt.Errorf("close error"))
|
||||
mock.ExpectQuery("SELECT").WillReturnRows(rows)
|
||||
|
||||
rs, err := db.Query("SELECT")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
|
||||
if err := rs.Close(); err == nil {
|
||||
t.Fatal("expected a close error")
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuerySingleRow(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
rows := NewRows([]string{"id"}).
|
||||
AddRow(1).
|
||||
AddRow(2)
|
||||
mock.ExpectQuery("SELECT").WillReturnRows(rows)
|
||||
|
||||
var id int
|
||||
if err := db.QueryRow("SELECT").Scan(&id); err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
|
||||
mock.ExpectQuery("SELECT").WillReturnRows(NewRows([]string{"id"}))
|
||||
if err := db.QueryRow("SELECT").Scan(&id); err != sql.ErrNoRows {
|
||||
t.Fatal("expected sql no rows error")
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRowsScanError(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
r := NewRows([]string{"col1", "col2"}).AddRow("one", "two").AddRow("one", nil)
|
||||
mock.ExpectQuery("SELECT").WillReturnRows(r)
|
||||
|
||||
rs, err := db.Query("SELECT")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
defer rs.Close()
|
||||
|
||||
var one, two string
|
||||
if !rs.Next() || rs.Err() != nil || rs.Scan(&one, &two) != nil {
|
||||
t.Fatal("unexpected error on first row scan")
|
||||
}
|
||||
|
||||
if !rs.Next() || rs.Err() != nil {
|
||||
t.Fatal("unexpected error on second row read")
|
||||
}
|
||||
|
||||
err = rs.Scan(&one, &two)
|
||||
if err == nil {
|
||||
t.Fatal("expected an error for scan, but got none")
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCSVRowParser(t *testing.T) {
|
||||
t.Parallel()
|
||||
rs := NewRows([]string{"col1", "col2"}).FromCSVString("a,NULL")
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
mock.ExpectQuery("SELECT").WillReturnRows(rs)
|
||||
|
||||
rw, err := db.Query("SELECT")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
defer rw.Close()
|
||||
var col1 string
|
||||
var col2 []byte
|
||||
|
||||
rw.Next()
|
||||
if err = rw.Scan(&col1, &col2); err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
if col1 != "a" {
|
||||
t.Fatalf("expected col1 to be 'a', but got [%T]:%+v", col1, col1)
|
||||
}
|
||||
if col2 != nil {
|
||||
t.Fatalf("expected col2 to be nil, but got [%T]:%+v", col2, col2)
|
||||
}
|
||||
}
|
581
sqlmock.go
581
sqlmock.go
@ -1,195 +1,444 @@
|
||||
/*
|
||||
Package sqlmock provides sql driver mock connecection, which allows to test database,
|
||||
create expectations and ensure the correct execution flow of any database operations.
|
||||
It hooks into Go standard library's database/sql package.
|
||||
Package sqlmock provides sql driver connection, which allows to test database
|
||||
interactions by expected calls and simulate their results or errors.
|
||||
|
||||
The package provides convenient methods to mock database queries, transactions and
|
||||
expect the right execution flow, compare query arguments or even return error instead
|
||||
to simulate failures. See the example bellow, which illustrates how convenient it is
|
||||
to work with:
|
||||
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"testing"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// will test that order with a different status, cannot be cancelled
|
||||
func TestShouldNotCancelOrderWithNonPendingStatus(t *testing.T) {
|
||||
// open database stub
|
||||
db, err := sql.Open("mock", "")
|
||||
if err != nil {
|
||||
t.Errorf("An error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
|
||||
// columns to be used for result
|
||||
columns := []string{"id", "status"}
|
||||
// expect transaction begin
|
||||
sqlmock.ExpectBegin()
|
||||
// expect query to fetch order, match it with regexp
|
||||
sqlmock.ExpectQuery("SELECT (.+) FROM orders (.+) FOR UPDATE").
|
||||
WithArgs(1).
|
||||
WillReturnRows(sqlmock.NewRows(columns).FromCSVString("1,1"))
|
||||
// expect transaction rollback, since order status is "cancelled"
|
||||
sqlmock.ExpectRollback()
|
||||
|
||||
// run the cancel order function
|
||||
someOrderId := 1
|
||||
// call a function which executes expected database operations
|
||||
err = cancelOrder(someOrderId, db)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, but got %s instead", err)
|
||||
}
|
||||
// db.Close() ensures that all expectations have been met
|
||||
if err = db.Close(); err != nil {
|
||||
t.Errorf("Error '%s' was not expected while closing the database", err)
|
||||
}
|
||||
}
|
||||
It does not require any modifications to your source code in order to test
|
||||
and mock database operations. It does not even require a real database in order
|
||||
to test your application.
|
||||
|
||||
The driver allows to mock any sql driver method behavior. Concurrent actions
|
||||
are also supported.
|
||||
*/
|
||||
package sqlmock
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
var mock *mockDriver
|
||||
// Sqlmock type satisfies required sql.driver interfaces
|
||||
// to simulate actual database and also serves to
|
||||
// create expectations for any kind of database action
|
||||
// in order to mock and test real database behavior.
|
||||
type Sqlmock struct {
|
||||
|
||||
// Mock interface defines a mock which is returned
|
||||
// by any expectation and can be detailed further
|
||||
// with the methods this interface provides
|
||||
type Mock interface {
|
||||
WithArgs(...driver.Value) Mock
|
||||
WillReturnError(error) Mock
|
||||
WillReturnRows(driver.Rows) Mock
|
||||
WillReturnResult(driver.Result) Mock
|
||||
// MatchExpectationsInOrder gives an option whether to match all
|
||||
// expectations in the order they were set or not.
|
||||
//
|
||||
// By default it is set to - true. But if you use goroutines
|
||||
// to parallelize your query executation, that option may
|
||||
// be handy.
|
||||
MatchExpectationsInOrder bool
|
||||
|
||||
dsn string
|
||||
opened int
|
||||
drv *mockDriver
|
||||
|
||||
expected []expectation
|
||||
}
|
||||
|
||||
type mockDriver struct {
|
||||
conn *conn
|
||||
// ExpectClose queues an expectation for this database
|
||||
// action to be triggered. the *ExpectedClose allows
|
||||
// to mock database response
|
||||
func (c *Sqlmock) ExpectClose() *ExpectedClose {
|
||||
e := &ExpectedClose{}
|
||||
c.expected = append(c.expected, e)
|
||||
return e
|
||||
}
|
||||
|
||||
func (d *mockDriver) Open(dsn string) (driver.Conn, error) {
|
||||
return mock.conn, nil
|
||||
}
|
||||
// Close a mock database driver connection. It may or may not
|
||||
// be called depending on the sircumstances, but if it is called
|
||||
// there must be an *ExpectedClose expectation satisfied.
|
||||
// meets http://golang.org/pkg/database/sql/driver/#Conn interface
|
||||
func (c *Sqlmock) Close() error {
|
||||
c.drv.Lock()
|
||||
defer c.drv.Unlock()
|
||||
|
||||
func init() {
|
||||
mock = &mockDriver{&conn{}}
|
||||
sql.Register("mock", mock)
|
||||
}
|
||||
|
||||
// New creates sqlmock database connection
|
||||
// and pings it so that all expectations could be
|
||||
// asserted on Close.
|
||||
func New() (db *sql.DB, err error) {
|
||||
db, err = sql.Open("mock", "")
|
||||
if err != nil {
|
||||
return
|
||||
c.opened--
|
||||
if c.opened == 0 {
|
||||
delete(c.drv.conns, c.dsn)
|
||||
}
|
||||
// ensure open connection, otherwise Close does not assert expectations
|
||||
return db, db.Ping()
|
||||
}
|
||||
|
||||
// ExpectBegin expects transaction to be started
|
||||
func ExpectBegin() Mock {
|
||||
e := &expectedBegin{}
|
||||
mock.conn.expectations = append(mock.conn.expectations, e)
|
||||
mock.conn.active = e
|
||||
return mock.conn
|
||||
}
|
||||
|
||||
// ExpectCommit expects transaction to be commited
|
||||
func ExpectCommit() Mock {
|
||||
e := &expectedCommit{}
|
||||
mock.conn.expectations = append(mock.conn.expectations, e)
|
||||
mock.conn.active = e
|
||||
return mock.conn
|
||||
}
|
||||
|
||||
// ExpectRollback expects transaction to be rolled back
|
||||
func ExpectRollback() Mock {
|
||||
e := &expectedRollback{}
|
||||
mock.conn.expectations = append(mock.conn.expectations, e)
|
||||
mock.conn.active = e
|
||||
return mock.conn
|
||||
}
|
||||
|
||||
// ExpectPrepare expects Query to be prepared
|
||||
func ExpectPrepare() Mock {
|
||||
e := &expectedPrepare{}
|
||||
mock.conn.expectations = append(mock.conn.expectations, e)
|
||||
mock.conn.active = e
|
||||
return mock.conn
|
||||
}
|
||||
|
||||
// WillReturnError the expectation will return an error
|
||||
func (c *conn) WillReturnError(err error) Mock {
|
||||
c.active.setError(err)
|
||||
return c
|
||||
}
|
||||
|
||||
// ExpectExec expects database Exec to be triggered, which will match
|
||||
// the given query string as a regular expression
|
||||
func ExpectExec(sqlRegexStr string) Mock {
|
||||
e := &expectedExec{}
|
||||
e.sqlRegex = regexp.MustCompile(sqlRegexStr)
|
||||
mock.conn.expectations = append(mock.conn.expectations, e)
|
||||
mock.conn.active = e
|
||||
return mock.conn
|
||||
}
|
||||
|
||||
// ExpectQuery database Query to be triggered, which will match
|
||||
// the given query string as a regular expression
|
||||
func ExpectQuery(sqlRegexStr string) Mock {
|
||||
e := &expectedQuery{}
|
||||
e.sqlRegex = regexp.MustCompile(sqlRegexStr)
|
||||
|
||||
mock.conn.expectations = append(mock.conn.expectations, e)
|
||||
mock.conn.active = e
|
||||
return mock.conn
|
||||
}
|
||||
|
||||
// WithArgs expectation should be called with given arguments.
|
||||
// Works with Exec and Query expectations
|
||||
func (c *conn) WithArgs(args ...driver.Value) Mock {
|
||||
eq, ok := c.active.(*expectedQuery)
|
||||
if !ok {
|
||||
ee, ok := c.active.(*expectedExec)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("arguments may be expected only with query based expectations, current is %T", c.active))
|
||||
var expected *ExpectedClose
|
||||
var fulfilled int
|
||||
var ok bool
|
||||
for _, next := range c.expected {
|
||||
next.Lock()
|
||||
if next.fulfilled() {
|
||||
next.Unlock()
|
||||
fulfilled++
|
||||
continue
|
||||
}
|
||||
|
||||
if expected, ok = next.(*ExpectedClose); ok {
|
||||
break
|
||||
}
|
||||
|
||||
next.Unlock()
|
||||
if c.MatchExpectationsInOrder {
|
||||
return fmt.Errorf("call to database Close, was not expected, next expectation is: %s", next)
|
||||
}
|
||||
ee.args = args
|
||||
} else {
|
||||
eq.args = args
|
||||
}
|
||||
return c
|
||||
|
||||
if expected == nil {
|
||||
msg := "call to database Close was not expected"
|
||||
if fulfilled == len(c.expected) {
|
||||
msg = "all expectations were already fulfilled, " + msg
|
||||
}
|
||||
return fmt.Errorf(msg)
|
||||
}
|
||||
|
||||
expected.triggered = true
|
||||
expected.Unlock()
|
||||
return expected.err
|
||||
}
|
||||
|
||||
// WillReturnResult expectation will return a Result.
|
||||
// Works only with Exec expectations
|
||||
func (c *conn) WillReturnResult(result driver.Result) Mock {
|
||||
eq, ok := c.active.(*expectedExec)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("driver.result may be returned only by exec expectations, current is %T", c.active))
|
||||
// ExpectationsWereMet checks whether all queued expectations
|
||||
// were met in order. If any of them was not met - an error is returned.
|
||||
func (c *Sqlmock) ExpectationsWereMet() error {
|
||||
for _, e := range c.expected {
|
||||
if !e.fulfilled() {
|
||||
return fmt.Errorf("there is a remaining expectation which was not matched: %s", e)
|
||||
}
|
||||
}
|
||||
eq.result = result
|
||||
return c
|
||||
return nil
|
||||
}
|
||||
|
||||
// WillReturnRows expectation will return Rows.
|
||||
// Works only with Query expectations
|
||||
func (c *conn) WillReturnRows(rows driver.Rows) Mock {
|
||||
eq, ok := c.active.(*expectedQuery)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("driver.rows may be returned only by query expectations, current is %T", c.active))
|
||||
// Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface
|
||||
func (c *Sqlmock) Begin() (driver.Tx, error) {
|
||||
var expected *ExpectedBegin
|
||||
var ok bool
|
||||
var fulfilled int
|
||||
for _, next := range c.expected {
|
||||
next.Lock()
|
||||
if next.fulfilled() {
|
||||
next.Unlock()
|
||||
fulfilled++
|
||||
continue
|
||||
}
|
||||
|
||||
if expected, ok = next.(*ExpectedBegin); ok {
|
||||
break
|
||||
}
|
||||
|
||||
next.Unlock()
|
||||
if c.MatchExpectationsInOrder {
|
||||
return nil, fmt.Errorf("call to database transaction Begin, was not expected, next expectation is: %s", next)
|
||||
}
|
||||
}
|
||||
eq.rows = rows
|
||||
return c
|
||||
if expected == nil {
|
||||
msg := "call to database transaction Begin was not expected"
|
||||
if fulfilled == len(c.expected) {
|
||||
msg = "all expectations were already fulfilled, " + msg
|
||||
}
|
||||
return nil, fmt.Errorf(msg)
|
||||
}
|
||||
|
||||
expected.triggered = true
|
||||
expected.Unlock()
|
||||
return c, expected.err
|
||||
}
|
||||
|
||||
// ExpectBegin expects *sql.DB.Begin to be called.
|
||||
// the *ExpectedBegin allows to mock database response
|
||||
func (c *Sqlmock) ExpectBegin() *ExpectedBegin {
|
||||
e := &ExpectedBegin{}
|
||||
c.expected = append(c.expected, e)
|
||||
return e
|
||||
}
|
||||
|
||||
// Exec meets http://golang.org/pkg/database/sql/driver/#Execer
|
||||
func (c *Sqlmock) Exec(query string, args []driver.Value) (res driver.Result, err error) {
|
||||
query = stripQuery(query)
|
||||
var expected *ExpectedExec
|
||||
var fulfilled int
|
||||
var ok bool
|
||||
for _, next := range c.expected {
|
||||
next.Lock()
|
||||
if next.fulfilled() {
|
||||
next.Unlock()
|
||||
fulfilled++
|
||||
continue
|
||||
}
|
||||
|
||||
if c.MatchExpectationsInOrder {
|
||||
if expected, ok = next.(*ExpectedExec); ok {
|
||||
break
|
||||
}
|
||||
next.Unlock()
|
||||
return nil, fmt.Errorf("call to exec query '%s' with args %+v, was not expected, next expectation is: %s", query, args, next)
|
||||
}
|
||||
if exec, ok := next.(*ExpectedExec); ok {
|
||||
if exec.attemptMatch(query, args) {
|
||||
expected = exec
|
||||
break
|
||||
}
|
||||
}
|
||||
next.Unlock()
|
||||
}
|
||||
if expected == nil {
|
||||
msg := "call to exec '%s' query with args %+v was not expected"
|
||||
if fulfilled == len(c.expected) {
|
||||
msg = "all expectations were already fulfilled, " + msg
|
||||
}
|
||||
return nil, fmt.Errorf(msg, query, args)
|
||||
}
|
||||
|
||||
defer expected.Unlock()
|
||||
expected.triggered = true
|
||||
// converts panic to error in case of reflect value type mismatch
|
||||
defer func(errp *error, exp *ExpectedExec, q string, a []driver.Value) {
|
||||
if e := recover(); e != nil {
|
||||
if se, ok := e.(*reflect.ValueError); ok { // catch reflect error, failed type conversion
|
||||
msg := "exec query \"%s\", args \"%+v\" failed to match with error \"%s\" expectation: %s"
|
||||
*errp = fmt.Errorf(msg, q, a, se, exp)
|
||||
} else {
|
||||
panic(e) // overwise if unknown error panic
|
||||
}
|
||||
}
|
||||
}(&err, expected, query, args)
|
||||
|
||||
if !expected.queryMatches(query) {
|
||||
return nil, fmt.Errorf("exec query '%s', does not match regex '%s'", query, expected.sqlRegex.String())
|
||||
}
|
||||
|
||||
if !expected.argsMatches(args) {
|
||||
return nil, fmt.Errorf("exec query '%s', args %+v does not match expected %+v", query, args, expected.args)
|
||||
}
|
||||
|
||||
if expected.err != nil {
|
||||
return nil, expected.err // mocked to return error
|
||||
}
|
||||
|
||||
if expected.result == nil {
|
||||
return nil, fmt.Errorf("exec query '%s' with args %+v, must return a database/sql/driver.result, but it was not set for expectation %T as %+v", query, args, expected, expected)
|
||||
}
|
||||
return expected.result, err
|
||||
}
|
||||
|
||||
// ExpectExec expects Exec() to be called with sql query
|
||||
// which match sqlRegexStr given regexp.
|
||||
// the *ExpectedExec allows to mock database response
|
||||
func (c *Sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec {
|
||||
e := &ExpectedExec{}
|
||||
e.sqlRegex = regexp.MustCompile(sqlRegexStr)
|
||||
c.expected = append(c.expected, e)
|
||||
return e
|
||||
}
|
||||
|
||||
// Prepare meets http://golang.org/pkg/database/sql/driver/#Conn interface
|
||||
func (c *Sqlmock) Prepare(query string) (driver.Stmt, error) {
|
||||
var expected *ExpectedPrepare
|
||||
var fulfilled int
|
||||
var ok bool
|
||||
for _, next := range c.expected {
|
||||
next.Lock()
|
||||
if next.fulfilled() {
|
||||
next.Unlock()
|
||||
fulfilled++
|
||||
continue
|
||||
}
|
||||
|
||||
if expected, ok = next.(*ExpectedPrepare); ok {
|
||||
break
|
||||
}
|
||||
|
||||
next.Unlock()
|
||||
if c.MatchExpectationsInOrder {
|
||||
return nil, fmt.Errorf("call to Prepare stetement with query '%s', was not expected, next expectation is: %s", query, next)
|
||||
}
|
||||
}
|
||||
|
||||
query = stripQuery(query)
|
||||
if expected == nil {
|
||||
msg := "call to Prepare '%s' query was not expected"
|
||||
if fulfilled == len(c.expected) {
|
||||
msg = "all expectations were already fulfilled, " + msg
|
||||
}
|
||||
return nil, fmt.Errorf(msg, query)
|
||||
}
|
||||
|
||||
expected.triggered = true
|
||||
expected.Unlock()
|
||||
return &statement{c, query, expected.closeErr}, expected.err
|
||||
}
|
||||
|
||||
// ExpectPrepare expects Prepare() to be called with sql query
|
||||
// which match sqlRegexStr given regexp.
|
||||
// the *ExpectedPrepare allows to mock database response.
|
||||
// Note that you may expect Query() or Exec() on the *ExpectedPrepare
|
||||
// statement to prevent repeating sqlRegexStr
|
||||
func (c *Sqlmock) ExpectPrepare(sqlRegexStr string) *ExpectedPrepare {
|
||||
e := &ExpectedPrepare{sqlRegex: regexp.MustCompile(sqlRegexStr), mock: c}
|
||||
c.expected = append(c.expected, e)
|
||||
return e
|
||||
}
|
||||
|
||||
// Query meets http://golang.org/pkg/database/sql/driver/#Queryer
|
||||
func (c *Sqlmock) Query(query string, args []driver.Value) (rw driver.Rows, err error) {
|
||||
query = stripQuery(query)
|
||||
var expected *ExpectedQuery
|
||||
var fulfilled int
|
||||
var ok bool
|
||||
for _, next := range c.expected {
|
||||
next.Lock()
|
||||
if next.fulfilled() {
|
||||
next.Unlock()
|
||||
fulfilled++
|
||||
continue
|
||||
}
|
||||
|
||||
if c.MatchExpectationsInOrder {
|
||||
if expected, ok = next.(*ExpectedQuery); ok {
|
||||
break
|
||||
}
|
||||
next.Unlock()
|
||||
return nil, fmt.Errorf("call to query '%s' with args %+v, was not expected, next expectation is: %s", query, args, next)
|
||||
}
|
||||
if qr, ok := next.(*ExpectedQuery); ok {
|
||||
if qr.attemptMatch(query, args) {
|
||||
expected = qr
|
||||
break
|
||||
}
|
||||
}
|
||||
next.Unlock()
|
||||
}
|
||||
|
||||
if expected == nil {
|
||||
msg := "call to query '%s' with args %+v was not expected"
|
||||
if fulfilled == len(c.expected) {
|
||||
msg = "all expectations were already fulfilled, " + msg
|
||||
}
|
||||
return nil, fmt.Errorf(msg, query, args)
|
||||
}
|
||||
|
||||
defer expected.Unlock()
|
||||
expected.triggered = true
|
||||
// converts panic to error in case of reflect value type mismatch
|
||||
defer func(errp *error, exp *ExpectedQuery, q string, a []driver.Value) {
|
||||
if e := recover(); e != nil {
|
||||
if se, ok := e.(*reflect.ValueError); ok { // catch reflect error, failed type conversion
|
||||
msg := "query \"%s\", args \"%+v\" failed to match with error \"%s\" expectation: %s"
|
||||
*errp = fmt.Errorf(msg, q, a, se, exp)
|
||||
} else {
|
||||
panic(e) // overwise if unknown error panic
|
||||
}
|
||||
}
|
||||
}(&err, expected, query, args)
|
||||
|
||||
if !expected.queryMatches(query) {
|
||||
return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, expected.sqlRegex.String())
|
||||
}
|
||||
|
||||
if !expected.argsMatches(args) {
|
||||
return nil, fmt.Errorf("query '%s', args %+v does not match expected %+v", query, args, expected.args)
|
||||
}
|
||||
|
||||
if expected.err != nil {
|
||||
return nil, expected.err // mocked to return error
|
||||
}
|
||||
|
||||
if expected.rows == nil {
|
||||
return nil, fmt.Errorf("query '%s' with args %+v, must return a database/sql/driver.rows, but it was not set for expectation %T as %+v", query, args, expected, expected)
|
||||
}
|
||||
|
||||
return expected.rows, err
|
||||
}
|
||||
|
||||
// ExpectQuery expects Query() or QueryRow() to be called with sql query
|
||||
// which match sqlRegexStr given regexp.
|
||||
// the *ExpectedQuery allows to mock database response.
|
||||
func (c *Sqlmock) ExpectQuery(sqlRegexStr string) *ExpectedQuery {
|
||||
e := &ExpectedQuery{}
|
||||
e.sqlRegex = regexp.MustCompile(sqlRegexStr)
|
||||
c.expected = append(c.expected, e)
|
||||
return e
|
||||
}
|
||||
|
||||
// ExpectCommit expects *sql.Tx.Commit to be called.
|
||||
// the *ExpectedCommit allows to mock database response
|
||||
func (c *Sqlmock) ExpectCommit() *ExpectedCommit {
|
||||
e := &ExpectedCommit{}
|
||||
c.expected = append(c.expected, e)
|
||||
return e
|
||||
}
|
||||
|
||||
// ExpectRollback expects *sql.Tx.Rollback to be called.
|
||||
// the *ExpectedRollback allows to mock database response
|
||||
func (c *Sqlmock) ExpectRollback() *ExpectedRollback {
|
||||
e := &ExpectedRollback{}
|
||||
c.expected = append(c.expected, e)
|
||||
return e
|
||||
}
|
||||
|
||||
// Commit meets http://golang.org/pkg/database/sql/driver/#Tx
|
||||
func (c *Sqlmock) Commit() error {
|
||||
var expected *ExpectedCommit
|
||||
var fulfilled int
|
||||
var ok bool
|
||||
for _, next := range c.expected {
|
||||
next.Lock()
|
||||
if next.fulfilled() {
|
||||
next.Unlock()
|
||||
fulfilled++
|
||||
continue
|
||||
}
|
||||
|
||||
if expected, ok = next.(*ExpectedCommit); ok {
|
||||
break
|
||||
}
|
||||
|
||||
next.Unlock()
|
||||
if c.MatchExpectationsInOrder {
|
||||
return fmt.Errorf("call to commit transaction, was not expected, next expectation is: %s", next)
|
||||
}
|
||||
}
|
||||
if expected == nil {
|
||||
msg := "call to commit transaction was not expected"
|
||||
if fulfilled == len(c.expected) {
|
||||
msg = "all expectations were already fulfilled, " + msg
|
||||
}
|
||||
return fmt.Errorf(msg)
|
||||
}
|
||||
|
||||
expected.triggered = true
|
||||
expected.Unlock()
|
||||
return expected.err
|
||||
}
|
||||
|
||||
// Rollback meets http://golang.org/pkg/database/sql/driver/#Tx
|
||||
func (c *Sqlmock) Rollback() error {
|
||||
var expected *ExpectedRollback
|
||||
var fulfilled int
|
||||
var ok bool
|
||||
for _, next := range c.expected {
|
||||
next.Lock()
|
||||
if next.fulfilled() {
|
||||
next.Unlock()
|
||||
fulfilled++
|
||||
continue
|
||||
}
|
||||
|
||||
if expected, ok = next.(*ExpectedRollback); ok {
|
||||
break
|
||||
}
|
||||
|
||||
next.Unlock()
|
||||
if c.MatchExpectationsInOrder {
|
||||
return fmt.Errorf("call to rollback transaction, was not expected, next expectation is: %s", next)
|
||||
}
|
||||
}
|
||||
if expected == nil {
|
||||
msg := "call to rollback transaction was not expected"
|
||||
if fulfilled == len(c.expected) {
|
||||
msg = "all expectations were already fulfilled, " + msg
|
||||
}
|
||||
return fmt.Errorf(msg)
|
||||
}
|
||||
|
||||
expected.triggered = true
|
||||
expected.Unlock()
|
||||
return expected.err
|
||||
}
|
||||
|
306
sqlmock_test.go
306
sqlmock_test.go
@ -3,16 +3,60 @@ package sqlmock
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func cancelOrder(db *sql.DB, orderID int) error {
|
||||
tx, _ := db.Begin()
|
||||
_, _ = tx.Query("SELECT * FROM orders {0} FOR UPDATE", orderID)
|
||||
_ = tx.Rollback()
|
||||
return nil
|
||||
}
|
||||
|
||||
func Example() {
|
||||
// Open new mock database
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
fmt.Println("error creating mock database")
|
||||
return
|
||||
}
|
||||
// columns to be used for result
|
||||
columns := []string{"id", "status"}
|
||||
// expect transaction begin
|
||||
mock.ExpectBegin()
|
||||
// expect query to fetch order, match it with regexp
|
||||
mock.ExpectQuery("SELECT (.+) FROM orders (.+) FOR UPDATE").
|
||||
WithArgs(1).
|
||||
WillReturnRows(NewRows(columns).AddRow(1, 1))
|
||||
// expect transaction rollback, since order status is "cancelled"
|
||||
mock.ExpectRollback()
|
||||
|
||||
// run the cancel order function
|
||||
someOrderID := 1
|
||||
// call a function which executes expected database operations
|
||||
err = cancelOrder(db, someOrderID)
|
||||
if err != nil {
|
||||
fmt.Printf("unexpected error: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
// ensure all expectations have been met
|
||||
if err = mock.ExpectationsWereMet(); err != nil {
|
||||
fmt.Printf("unmet expectation error: %s", err)
|
||||
}
|
||||
// Output:
|
||||
}
|
||||
|
||||
func TestIssue14EscapeSQL(t *testing.T) {
|
||||
db, err := New()
|
||||
t.Parallel()
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
ExpectExec("INSERT INTO mytable\\(a, b\\)").
|
||||
defer db.Close()
|
||||
mock.ExpectExec("INSERT INTO mytable\\(a, b\\)").
|
||||
WithArgs("A", "B").
|
||||
WillReturnResult(NewResult(1, 1))
|
||||
|
||||
@ -21,37 +65,40 @@ func TestIssue14EscapeSQL(t *testing.T) {
|
||||
t.Errorf("error '%s' was not expected, while inserting a row", err)
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
t.Errorf("error '%s' was not expected while closing the database", err)
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("there were unfulfilled expections: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// test the case when db is not triggered and expectations
|
||||
// are not asserted on close
|
||||
func TestIssue4(t *testing.T) {
|
||||
db, err := New()
|
||||
t.Parallel()
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
ExpectQuery("some sql query which will not be called").
|
||||
defer db.Close()
|
||||
|
||||
mock.ExpectQuery("some sql query which will not be called").
|
||||
WillReturnRows(NewRows([]string{"id"}))
|
||||
|
||||
err = db.Close()
|
||||
if err == nil {
|
||||
t.Errorf("Was expecting an error, since expected query was not matched")
|
||||
if err := mock.ExpectationsWereMet(); err == nil {
|
||||
t.Errorf("was expecting an error since query was not triggered")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMockQuery(t *testing.T) {
|
||||
db, err := sql.Open("mock", "")
|
||||
t.Parallel()
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
rs := NewRows([]string{"id", "title"}).FromCSVString("5,hello world")
|
||||
|
||||
ExpectQuery("SELECT (.+) FROM articles WHERE id = ?").
|
||||
mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?").
|
||||
WithArgs(5).
|
||||
WillReturnRows(rs)
|
||||
|
||||
@ -59,11 +106,13 @@ func TestMockQuery(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Errorf("error '%s' was not expected while retrieving mock rows", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if er := rows.Close(); er != nil {
|
||||
t.Error("Unexpected error while trying to close rows")
|
||||
}
|
||||
}()
|
||||
|
||||
if !rows.Next() {
|
||||
t.Error("it must have had one row as result, but got empty result set instead")
|
||||
}
|
||||
@ -84,16 +133,18 @@ func TestMockQuery(t *testing.T) {
|
||||
t.Errorf("expected mocked title to be 'hello world', but got '%s' instead", title)
|
||||
}
|
||||
|
||||
if err = db.Close(); err != nil {
|
||||
t.Errorf("error '%s' was not expected while closing the database", err)
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("there were unfulfilled expections: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMockQueryTypes(t *testing.T) {
|
||||
db, err := sql.Open("mock", "")
|
||||
t.Parallel()
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
columns := []string{"id", "timestamp", "sold"}
|
||||
|
||||
@ -101,7 +152,7 @@ func TestMockQueryTypes(t *testing.T) {
|
||||
rs := NewRows(columns)
|
||||
rs.AddRow(5, timestamp, true)
|
||||
|
||||
ExpectQuery("SELECT (.+) FROM sales WHERE id = ?").
|
||||
mock.ExpectQuery("SELECT (.+) FROM sales WHERE id = ?").
|
||||
WithArgs(5).
|
||||
WillReturnRows(rs)
|
||||
|
||||
@ -139,20 +190,22 @@ func TestMockQueryTypes(t *testing.T) {
|
||||
t.Errorf("expected mocked boolean to be true, but got %v instead", sold)
|
||||
}
|
||||
|
||||
if err = db.Close(); err != nil {
|
||||
t.Errorf("error '%s' was not expected while closing the database", err)
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("there were unfulfilled expections: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransactionExpectations(t *testing.T) {
|
||||
db, err := sql.Open("mock", "")
|
||||
t.Parallel()
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// begin and commit
|
||||
ExpectBegin()
|
||||
ExpectCommit()
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectCommit()
|
||||
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
@ -165,8 +218,8 @@ func TestTransactionExpectations(t *testing.T) {
|
||||
}
|
||||
|
||||
// begin and rollback
|
||||
ExpectBegin()
|
||||
ExpectRollback()
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectRollback()
|
||||
|
||||
tx, err = db.Begin()
|
||||
if err != nil {
|
||||
@ -179,25 +232,28 @@ func TestTransactionExpectations(t *testing.T) {
|
||||
}
|
||||
|
||||
// begin with an error
|
||||
ExpectBegin().WillReturnError(fmt.Errorf("some err"))
|
||||
mock.ExpectBegin().WillReturnError(fmt.Errorf("some err"))
|
||||
|
||||
tx, err = db.Begin()
|
||||
if err == nil {
|
||||
t.Error("an error was expected when beginning a transaction, but got none")
|
||||
}
|
||||
|
||||
if err = db.Close(); err != nil {
|
||||
t.Errorf("error '%s' was not expected while closing the database", err)
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("there were unfulfilled expections: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareExpectations(t *testing.T) {
|
||||
db, err := sql.Open("mock", "")
|
||||
t.Parallel()
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
mock.ExpectPrepare("SELECT (.+) FROM articles WHERE id = ?")
|
||||
|
||||
// no expectations, w/o ExpectPrepare()
|
||||
stmt, err := db.Prepare("SELECT (.+) FROM articles WHERE id = ?")
|
||||
if err != nil {
|
||||
t.Errorf("error '%s' was not expected while creating a prepared statement", err)
|
||||
@ -211,36 +267,19 @@ func TestPrepareExpectations(t *testing.T) {
|
||||
var title string
|
||||
rs := NewRows([]string{"id", "title"}).FromCSVString("5,hello world")
|
||||
|
||||
ExpectQuery("SELECT (.+) FROM articles WHERE id = ?").
|
||||
mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?").
|
||||
WithArgs(5).
|
||||
WillReturnRows(rs)
|
||||
|
||||
stmt, err = db.Prepare("SELECT (.+) FROM articles WHERE id = ?")
|
||||
if err != nil {
|
||||
t.Errorf("error '%s' was not expected while creating a prepared statement", err)
|
||||
}
|
||||
if stmt == nil {
|
||||
t.Errorf("stmt was expected while creating a prepared statement")
|
||||
}
|
||||
|
||||
err = stmt.QueryRow(5).Scan(&id, &title)
|
||||
if err != nil {
|
||||
t.Errorf("error '%s' was not expected while retrieving mock rows", err)
|
||||
}
|
||||
|
||||
// expect normal result
|
||||
ExpectPrepare()
|
||||
stmt, err = db.Prepare("SELECT (.+) FROM articles WHERE id = ?")
|
||||
if err != nil {
|
||||
t.Errorf("error '%s' was not expected while creating a prepared statement", err)
|
||||
}
|
||||
if stmt == nil {
|
||||
t.Errorf("stmt was expected while creating a prepared statement")
|
||||
}
|
||||
mock.ExpectPrepare("SELECT (.+) FROM articles WHERE id = ?").
|
||||
WillReturnError(fmt.Errorf("Some DB error occurred"))
|
||||
|
||||
// expect error result
|
||||
ExpectPrepare().WillReturnError(fmt.Errorf("Some DB error occurred"))
|
||||
stmt, err = db.Prepare("SELECT (.+) FROM articles WHERE id = ?")
|
||||
stmt, err = db.Prepare("SELECT id FROM articles WHERE id = ?")
|
||||
if err == nil {
|
||||
t.Error("error was expected while creating a prepared statement")
|
||||
}
|
||||
@ -248,35 +287,38 @@ func TestPrepareExpectations(t *testing.T) {
|
||||
t.Errorf("stmt was not expected while creating a prepared statement returning error")
|
||||
}
|
||||
|
||||
if err = db.Close(); err != nil {
|
||||
t.Errorf("error '%s' was not expected while closing the database", err)
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("there were unfulfilled expections: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreparedQueryExecutions(t *testing.T) {
|
||||
db, err := sql.Open("mock", "")
|
||||
t.Parallel()
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
mock.ExpectPrepare("SELECT (.+) FROM articles WHERE id = ?")
|
||||
|
||||
rs1 := NewRows([]string{"id", "title"}).FromCSVString("5,hello world")
|
||||
ExpectQuery("SELECT (.+) FROM articles WHERE id = ?").
|
||||
mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?").
|
||||
WithArgs(5).
|
||||
WillReturnRows(rs1)
|
||||
|
||||
rs2 := NewRows([]string{"id", "title"}).FromCSVString("2,whoop")
|
||||
ExpectQuery("SELECT (.+) FROM articles WHERE id = ?").
|
||||
mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?").
|
||||
WithArgs(2).
|
||||
WillReturnRows(rs2)
|
||||
|
||||
stmt, err := db.Prepare("SELECT (.+) FROM articles WHERE id = ?")
|
||||
stmt, err := db.Prepare("SELECT id, title FROM articles WHERE id = ?")
|
||||
if err != nil {
|
||||
t.Errorf("error '%s' was not expected while creating a prepared statement", err)
|
||||
}
|
||||
|
||||
var id int
|
||||
var title string
|
||||
|
||||
err = stmt.QueryRow(5).Scan(&id, &title)
|
||||
if err != nil {
|
||||
t.Errorf("error '%s' was not expected querying row from statement and scanning", err)
|
||||
@ -303,18 +345,21 @@ func TestPreparedQueryExecutions(t *testing.T) {
|
||||
t.Errorf("expected mocked title to be 'whoop', but got '%s' instead", title)
|
||||
}
|
||||
|
||||
if err = db.Close(); err != nil {
|
||||
t.Errorf("error '%s' was not expected while closing the database", err)
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("there were unfulfilled expections: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnexpectedOperations(t *testing.T) {
|
||||
db, err := sql.Open("mock", "")
|
||||
t.Parallel()
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
stmt, err := db.Prepare("SELECT (.+) FROM articles WHERE id = ?")
|
||||
mock.ExpectPrepare("SELECT (.+) FROM articles WHERE id = ?")
|
||||
stmt, err := db.Prepare("SELECT id, title FROM articles WHERE id = ?")
|
||||
if err != nil {
|
||||
t.Errorf("error '%s' was not expected while creating a prepared statement", err)
|
||||
}
|
||||
@ -327,39 +372,35 @@ func TestUnexpectedOperations(t *testing.T) {
|
||||
t.Error("error was expected querying row, since there was no such expectation")
|
||||
}
|
||||
|
||||
ExpectRollback()
|
||||
mock.ExpectRollback()
|
||||
|
||||
err = db.Close()
|
||||
if err == nil {
|
||||
t.Error("error was expected while closing the database, expectation was not fulfilled", err)
|
||||
if err := mock.ExpectationsWereMet(); err == nil {
|
||||
t.Errorf("was expecting an error since query was not triggered")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrongExpectations(t *testing.T) {
|
||||
db, err := sql.Open("mock", "")
|
||||
t.Parallel()
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
ExpectBegin()
|
||||
mock.ExpectBegin()
|
||||
|
||||
rs1 := NewRows([]string{"id", "title"}).FromCSVString("5,hello world")
|
||||
ExpectQuery("SELECT (.+) FROM articles WHERE id = ?").
|
||||
mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?").
|
||||
WithArgs(5).
|
||||
WillReturnRows(rs1)
|
||||
|
||||
ExpectCommit().WillReturnError(fmt.Errorf("deadlock occured"))
|
||||
ExpectRollback() // won't be triggered
|
||||
|
||||
stmt, err := db.Prepare("SELECT (.+) FROM articles WHERE id = ? FOR UPDATE")
|
||||
if err != nil {
|
||||
t.Errorf("error '%s' was not expected while creating a prepared statement", err)
|
||||
}
|
||||
mock.ExpectCommit().WillReturnError(fmt.Errorf("deadlock occured"))
|
||||
mock.ExpectRollback() // won't be triggered
|
||||
|
||||
var id int
|
||||
var title string
|
||||
|
||||
err = stmt.QueryRow(5).Scan(&id, &title)
|
||||
err = db.QueryRow("SELECT id, title FROM articles WHERE id = ? FOR UPDATE", 5).Scan(&id, &title)
|
||||
if err == nil {
|
||||
t.Error("error was expected while querying row, since there begin transaction expectation is not fulfilled")
|
||||
}
|
||||
@ -370,7 +411,7 @@ func TestWrongExpectations(t *testing.T) {
|
||||
t.Errorf("an error '%s' was not expected when beginning a transaction", err)
|
||||
}
|
||||
|
||||
err = stmt.QueryRow(5).Scan(&id, &title)
|
||||
err = db.QueryRow("SELECT id, title FROM articles WHERE id = ? FOR UPDATE", 5).Scan(&id, &title)
|
||||
if err != nil {
|
||||
t.Errorf("error '%s' was not expected while querying row, since transaction was started", err)
|
||||
}
|
||||
@ -380,20 +421,21 @@ func TestWrongExpectations(t *testing.T) {
|
||||
t.Error("a deadlock error was expected when commiting a transaction", err)
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err == nil {
|
||||
t.Error("error was expected while closing the database, expectation was not fulfilled", err)
|
||||
if err := mock.ExpectationsWereMet(); err == nil {
|
||||
t.Errorf("was expecting an error since query was not triggered")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecExpectations(t *testing.T) {
|
||||
db, err := sql.Open("mock", "")
|
||||
t.Parallel()
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
result := NewResult(1, 1)
|
||||
ExpectExec("^INSERT INTO articles").
|
||||
mock.ExpectExec("^INSERT INTO articles").
|
||||
WithArgs("hello").
|
||||
WillReturnResult(result)
|
||||
|
||||
@ -420,22 +462,24 @@ func TestExecExpectations(t *testing.T) {
|
||||
t.Errorf("expected affected rows to be 1, but got %d instead", affected)
|
||||
}
|
||||
|
||||
if err = db.Close(); err != nil {
|
||||
t.Errorf("error '%s' was not expected while closing the database", err)
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("there were unfulfilled expections: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRowBuilderAndNilTypes(t *testing.T) {
|
||||
db, err := sql.Open("mock", "")
|
||||
t.Parallel()
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
rs := NewRows([]string{"id", "active", "created", "status"}).
|
||||
AddRow(1, true, time.Now(), 5).
|
||||
AddRow(2, false, nil, nil)
|
||||
|
||||
ExpectQuery("SELECT (.+) FROM sales").WillReturnRows(rs)
|
||||
mock.ExpectQuery("SELECT (.+) FROM sales").WillReturnRows(rs)
|
||||
|
||||
rows, err := db.Query("SELECT * FROM sales")
|
||||
if err != nil {
|
||||
@ -510,23 +554,107 @@ func TestRowBuilderAndNilTypes(t *testing.T) {
|
||||
t.Errorf("expected 'status' to be invalid, but it %+v is not", status)
|
||||
}
|
||||
|
||||
if err = db.Close(); err != nil {
|
||||
t.Errorf("error '%s' was not expected while closing the database", err)
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("there were unfulfilled expections: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestArgumentReflectValueTypeError(t *testing.T) {
|
||||
db, err := sql.Open("mock", "")
|
||||
t.Parallel()
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
rs := NewRows([]string{"id"}).AddRow(1)
|
||||
|
||||
ExpectQuery("SELECT (.+) FROM sales").WithArgs(5.5).WillReturnRows(rs)
|
||||
mock.ExpectQuery("SELECT (.+) FROM sales").WithArgs(5.5).WillReturnRows(rs)
|
||||
|
||||
_, err = db.Query("SELECT * FROM sales WHERE x = ?", 5)
|
||||
if err == nil {
|
||||
t.Error("Expected error, but got none")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoroutineExecutionWithUnorderedExpectationMatching(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// note this line is important for unordered expectation matching
|
||||
mock.MatchExpectationsInOrder = false
|
||||
|
||||
result := NewResult(1, 1)
|
||||
|
||||
mock.ExpectExec("^UPDATE one").WithArgs("one").WillReturnResult(result)
|
||||
mock.ExpectExec("^UPDATE two").WithArgs("one", "two").WillReturnResult(result)
|
||||
mock.ExpectExec("^UPDATE three").WithArgs("one", "two", "three").WillReturnResult(result)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
queries := map[string][]interface{}{
|
||||
"one": []interface{}{"one"},
|
||||
"two": []interface{}{"one", "two"},
|
||||
"three": []interface{}{"one", "two", "three"},
|
||||
}
|
||||
|
||||
wg.Add(len(queries))
|
||||
for table, args := range queries {
|
||||
go func(tbl string, a []interface{}) {
|
||||
if _, err := db.Exec("UPDATE "+tbl, a...); err != nil {
|
||||
t.Errorf("error was not expected: %s", err)
|
||||
}
|
||||
wg.Done()
|
||||
}(table, args)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("there were unfulfilled expections: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleSqlmock_goroutines() {
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
fmt.Println("failed to open sqlmock database:", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// note this line is important for unordered expectation matching
|
||||
mock.MatchExpectationsInOrder = false
|
||||
|
||||
result := NewResult(1, 1)
|
||||
|
||||
mock.ExpectExec("^UPDATE one").WithArgs("one").WillReturnResult(result)
|
||||
mock.ExpectExec("^UPDATE two").WithArgs("one", "two").WillReturnResult(result)
|
||||
mock.ExpectExec("^UPDATE three").WithArgs("one", "two", "three").WillReturnResult(result)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
queries := map[string][]interface{}{
|
||||
"one": []interface{}{"one"},
|
||||
"two": []interface{}{"one", "two"},
|
||||
"three": []interface{}{"one", "two", "three"},
|
||||
}
|
||||
|
||||
wg.Add(len(queries))
|
||||
for table, args := range queries {
|
||||
go func(tbl string, a []interface{}) {
|
||||
if _, err := db.Exec("UPDATE "+tbl, a...); err != nil {
|
||||
fmt.Println("error was not expected:", err)
|
||||
}
|
||||
wg.Done()
|
||||
}(table, args)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
fmt.Println("there were unfulfilled expections:", err)
|
||||
}
|
||||
// Output:
|
||||
}
|
||||
|
@ -5,12 +5,13 @@ import (
|
||||
)
|
||||
|
||||
type statement struct {
|
||||
conn *conn
|
||||
conn *Sqlmock
|
||||
query string
|
||||
err error
|
||||
}
|
||||
|
||||
func (stmt *statement) Close() error {
|
||||
return nil
|
||||
return stmt.err
|
||||
}
|
||||
|
||||
func (stmt *statement) NumInput() int {
|
||||
|
@ -1,37 +0,0 @@
|
||||
package sqlmock
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type transaction struct {
|
||||
conn *conn
|
||||
}
|
||||
|
||||
func (tx *transaction) Commit() error {
|
||||
e := tx.conn.next()
|
||||
if e == nil {
|
||||
return fmt.Errorf("all expectations were already fulfilled, call to commit transaction was not expected")
|
||||
}
|
||||
|
||||
etc, ok := e.(*expectedCommit)
|
||||
if !ok {
|
||||
return fmt.Errorf("call to commit transaction, was not expected, next expectation was %v", e)
|
||||
}
|
||||
etc.triggered = true
|
||||
return etc.err
|
||||
}
|
||||
|
||||
func (tx *transaction) Rollback() error {
|
||||
e := tx.conn.next()
|
||||
if e == nil {
|
||||
return fmt.Errorf("all expectations were already fulfilled, call to rollback transaction was not expected")
|
||||
}
|
||||
|
||||
etr, ok := e.(*expectedRollback)
|
||||
if !ok {
|
||||
return fmt.Errorf("call to rollback transaction, was not expected, next expectation was %v", e)
|
||||
}
|
||||
etr.triggered = true
|
||||
return etr.err
|
||||
}
|
Loading…
Reference in New Issue
Block a user