1
0
mirror of https://github.com/DATA-DOG/go-sqlmock.git synced 2025-03-23 21:09:19 +02:00

initial commit

This commit is contained in:
gedi 2014-02-05 16:21:07 +02:00
commit 3e67393335
10 changed files with 789 additions and 0 deletions

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
/*.test

10
README.md Normal file
View File

@ -0,0 +1,10 @@
db = mock.Open("test", "")
db.ExpectTransactionBegin()
db.ExpectTransactionBegin().WillReturnError("some error")
db.ExpectQuery("SELECT bla").With(5, 8, "stat").WillReturnNone()
db.ExpectExec("UPDATE tbl SET").With(5, "val").WillReturnResult(res /* sql.Result */)
db.ExpectExec("INSERT INTO bla").With(5, 8, "stat").WillReturnResult(res /* sql.Result */)
db.ExpectQuery("SELECT bla").With(5, 8, "stat").WillReturnRows()

101
expectations.go Normal file
View File

@ -0,0 +1,101 @@
package sqlmock
import (
"database/sql/driver"
"reflect"
"regexp"
)
type expectation interface {
fulfilled() bool
setError(err error)
}
// common expectation
type commonExpectation struct {
triggered bool
err error
}
func (e *commonExpectation) fulfilled() bool {
return e.triggered
}
func (e *commonExpectation) setError(err error) {
e.err = err
}
// query based expectation
type queryBasedExpectation struct {
commonExpectation
sqlRegex *regexp.Regexp
args []driver.Value
}
func (e *queryBasedExpectation) queryMatches(sql string) bool {
return e.sqlRegex.MatchString(sql)
}
func (e *queryBasedExpectation) argsMatches(args []driver.Value) bool {
if len(args) != len(e.args) {
return false
}
for k, v := range e.args {
vi := reflect.ValueOf(v)
ai := reflect.ValueOf(args[k])
switch vi.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if vi.Int() != ai.Int() {
return false
}
case reflect.Float32, reflect.Float64:
if vi.Float() != ai.Float() {
return false
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if vi.Uint() != ai.Uint() {
return false
}
case reflect.String:
if vi.String() != ai.String() {
return false
}
default:
// compare types like time.Time based on type only
if vi.Kind() != ai.Kind() {
return false
}
}
}
return true
}
// begin transaction
type expectedBegin struct {
commonExpectation
}
// tx commit
type expectedCommit struct {
commonExpectation
}
// tx rollback
type expectedRollback struct {
commonExpectation
}
// query expectation
type expectedQuery struct {
queryBasedExpectation
rows driver.Rows
}
// exec query expectation
type expectedExec struct {
queryBasedExpectation
result driver.Result
}

48
expectations_test.go Normal file
View File

@ -0,0 +1,48 @@
package sqlmock
import (
"database/sql/driver"
"testing"
"time"
)
func TestQueryExpectationArgComparison(t *testing.T) {
e := &queryBasedExpectation{}
e.args = []driver.Value{5, "str"}
against := []driver.Value{5}
if e.argsMatches(against) {
t.Error("Arguments should not match, since the size is not the same")
}
against = []driver.Value{3, "str"}
if e.argsMatches(against) {
t.Error("Arguments should not match, since the first argument (int value) is different")
}
against = []driver.Value{5, "st"}
if e.argsMatches(against) {
t.Error("Arguments should not match, since the second argument (string value) is different")
}
against = []driver.Value{5, "str"}
if !e.argsMatches(against) {
t.Error("Arguments should match, but it did not")
}
e.args = []driver.Value{5, time.Now()}
const longForm = "Jan 2, 2006 at 3:04pm (MST)"
tm, _ := time.Parse(longForm, "Feb 3, 2013 at 7:54pm (PST)")
against = []driver.Value{5, tm}
if !e.argsMatches(against) {
t.Error("Arguments should match (time will be compared only by type), but it did not")
}
against = []driver.Value{5, 7899000}
if e.argsMatches(against) {
t.Error("Arguments should not match, but it did")
}
}

21
result.go Normal file
View File

@ -0,0 +1,21 @@
package sqlmock
type Result struct {
lastInsertId int64
rowsAffected int64
}
func NewResult(lastInsertId int64, rowsAffected int64) *Result {
return &Result{
lastInsertId,
rowsAffected,
}
}
func (res *Result) LastInsertId() (int64, error) {
return res.lastInsertId, nil
}
func (res *Result) RowsAffected() (int64, error) {
return res.rowsAffected, nil
}

62
rows.go Normal file
View File

@ -0,0 +1,62 @@
package sqlmock
import (
"database/sql/driver"
"encoding/csv"
"io"
"strings"
)
type rows struct {
cols []string
rows [][]driver.Value
pos int
}
func (r *rows) Columns() []string {
return r.cols
}
func (r *rows) Close() error {
return nil
}
func (r *rows) Err() error {
return nil
}
func (r *rows) Next(dest []driver.Value) error {
r.pos++
if r.pos > len(r.rows) {
return io.EOF // per interface spec
}
for i, col := range r.rows[r.pos-1] {
dest[i] = col
}
return nil
}
func RowsFromCSVString(columns []string, s string) driver.Rows {
rs := &rows{}
rs.cols = columns
r := strings.NewReader(strings.TrimSpace(s))
csvReader := csv.NewReader(r)
for {
r, err := csvReader.Read()
if err != nil || r == nil {
break
}
row := make([]driver.Value, len(columns))
for i, v := range r {
v := strings.TrimSpace(v)
row[i] = v
}
rs.rows = append(rs.rows, row)
}
return rs
}

213
sqlmock.go Normal file
View File

@ -0,0 +1,213 @@
package sqlmock
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"regexp"
)
var mock *mockDriver
type Mock interface {
WithArgs(...driver.Value) Mock
WillReturnError(error) Mock
WillReturnRows(driver.Rows) Mock
WillReturnResult(driver.Result) Mock
}
type mockDriver struct {
conn *conn
}
func (d *mockDriver) Open(dsn string) (driver.Conn, error) {
return mock.conn, nil
}
func init() {
mock = &mockDriver{&conn{}}
sql.Register("mock", mock)
}
type conn struct {
expectations []expectation
active expectation
}
func (c *conn) Close() (err error) {
for _, e := range mock.conn.expectations {
if !e.fulfilled() {
err = errors.New(fmt.Sprintf("There is expectation %+v which was not matched yet", e))
break
}
}
mock.conn.expectations = []expectation{}
mock.conn.active = nil
return err
}
func ExpectBegin() Mock {
e := &expectedBegin{}
mock.conn.expectations = append(mock.conn.expectations, e)
mock.conn.active = e
return mock.conn
}
func ExpectCommit() Mock {
e := &expectedCommit{}
mock.conn.expectations = append(mock.conn.expectations, e)
mock.conn.active = e
return mock.conn
}
func ExpectRollback() Mock {
e := &expectedRollback{}
mock.conn.expectations = append(mock.conn.expectations, e)
mock.conn.active = e
return mock.conn
}
func (c *conn) WillReturnError(err error) Mock {
c.active.setError(err)
return c
}
func (c *conn) Begin() (driver.Tx, error) {
e := c.next()
if e == nil {
return nil, errors.New("All expectations were already fulfilled, call to Begin transaction was not expected")
}
etb, ok := e.(*expectedBegin)
if !ok {
return nil, errors.New(fmt.Sprintf("Call to Begin transaction, was not expected, next expectation is %v", e))
}
etb.triggered = true
return &transaction{c}, etb.err
}
// get next unfulfilled expectation
func (c *conn) next() (e expectation) {
for _, e = range c.expectations {
if !e.fulfilled() {
return
}
}
return nil // all expectations were fulfilled
}
func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) {
e := c.next()
if e == nil {
return nil, errors.New(fmt.Sprintf("All expectations were already fulfilled, call to Exec '%s' query with args [%v] was not expected", query, args))
}
eq, ok := e.(*expectedExec)
if !ok {
return nil, errors.New(fmt.Sprintf("Call to Exec query '%s' with args [%v], was not expected, next expectation is %v", query, args, e))
}
eq.triggered = true
if eq.err != nil {
return nil, eq.err // mocked to return error
}
if eq.result == nil {
return nil, errors.New(fmt.Sprintf("Exec query '%s' with args [%v], must return a database/sql/driver.Result, but it was not set for expectation %v", query, args, eq))
}
if !eq.queryMatches(query) {
return nil, errors.New(fmt.Sprintf("Exec query '%s', does not match regex [%s]", query, eq.sqlRegex.String()))
}
if !eq.argsMatches(args) {
return nil, errors.New(fmt.Sprintf("Exec query '%s', args [%v] does not match expected [%v]", query, args, eq.args))
}
return eq.result, nil
}
func ExpectExec(sqlRegexStr string) Mock {
e := &expectedExec{}
e.sqlRegex = regexp.MustCompile(sqlRegexStr)
mock.conn.expectations = append(mock.conn.expectations, e)
mock.conn.active = e
return mock.conn
}
func ExpectQuery(sqlRegexStr string) Mock {
e := &expectedQuery{}
e.sqlRegex = regexp.MustCompile(sqlRegexStr)
mock.conn.expectations = append(mock.conn.expectations, e)
mock.conn.active = e
return mock.conn
}
func (c *conn) WithArgs(args ...driver.Value) Mock {
eq, ok := c.active.(*expectedQuery)
if !ok {
ee, ok := c.active.(*expectedExec)
if !ok {
panic(fmt.Sprintf("Arguments may be expected only with query based expectations, current is %T", c.active))
}
ee.args = args
} else {
eq.args = args
}
return c
}
func (c *conn) WillReturnResult(result driver.Result) Mock {
eq, ok := c.active.(*expectedExec)
if !ok {
panic(fmt.Sprintf("driver.Result may be returned only by Exec expectations, current is %v", c.active))
}
eq.result = result
return c
}
func (c *conn) WillReturnRows(rows driver.Rows) Mock {
eq, ok := c.active.(*expectedQuery)
if !ok {
panic(fmt.Sprintf("driver.Rows may be returned only by Query expectations, current is %v", c.active))
}
eq.rows = rows
return c
}
func (c *conn) Prepare(query string) (driver.Stmt, error) {
return &statement{c, query}, nil
}
func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
e := c.next()
if e == nil {
return nil, errors.New(fmt.Sprintf("All expectations were already fulfilled, call to Query '%s' with args [%v] was not expected", query, args))
}
eq, ok := e.(*expectedQuery)
if !ok {
return nil, errors.New(fmt.Sprintf("Call to Query '%s' with args [%v], was not expected, next expectation is %v", query, args, e))
}
eq.triggered = true
if eq.err != nil {
return nil, eq.err // mocked to return error
}
if eq.rows == nil {
return nil, errors.New(fmt.Sprintf("Query '%s' with args [%v], must return a database/sql/driver.Rows, but it was not set for expectation %v", query, args, eq))
}
if !eq.queryMatches(query) {
return nil, errors.New(fmt.Sprintf("Query '%s', does not match regex [%s]", query, eq.sqlRegex.String()))
}
if !eq.argsMatches(args) {
return nil, errors.New(fmt.Sprintf("Query '%s', args [%v] does not match expected [%v]", query, args, eq.args))
}
return eq.rows, nil
}

268
sqlmock_test.go Normal file
View File

@ -0,0 +1,268 @@
package sqlmock
import (
"database/sql"
"errors"
"testing"
)
func TestMockQuery(t *testing.T) {
db, err := sql.Open("mock", "")
if err != nil {
t.Errorf("An error '%s' was not expected when opening a stub database connection", err)
}
rs := RowsFromCSVString([]string{"id", "title"}, "5,hello world")
ExpectQuery("SELECT (.+) FROM articles WHERE id = ?").
WithArgs(5).
WillReturnRows(rs)
rows, err := db.Query("SELECT (.+) FROM articles WHERE id = ?", 5)
if err != nil {
t.Errorf("Error '%s' was not expected while retrieving mock rows", err)
}
defer rows.Close()
if !rows.Next() {
t.Error("It must have had one row as result, but got empty result set instead")
}
var id int
var title string
err = rows.Scan(&id, &title)
if err != nil {
t.Errorf("Error '%s' was not expected while trying to scan row", err)
}
if id != 5 {
t.Errorf("Expected mocked id to be 5, but got %d instead", id)
}
if title != "hello world" {
t.Errorf("Expected mocked title to be 'hello world', but got '%s' instead", title)
}
if err = db.Close(); err != nil {
t.Errorf("Error '%s' was not expected while closing the database", err)
}
}
func TestTransactionExpectations(t *testing.T) {
db, err := sql.Open("mock", "")
if err != nil {
t.Errorf("An error '%s' was not expected when opening a stub database connection", err)
}
// begin and commit
ExpectBegin()
ExpectCommit()
tx, err := db.Begin()
if err != nil {
t.Errorf("An error '%s' was not expected when beginning a transaction", err)
}
err = tx.Commit()
if err != nil {
t.Errorf("An error '%s' was not expected when commiting a transaction", err)
}
// begin and rollback
ExpectBegin()
ExpectRollback()
tx, err = db.Begin()
if err != nil {
t.Errorf("An error '%s' was not expected when beginning a transaction", err)
}
err = tx.Rollback()
if err != nil {
t.Errorf("An error '%s' was not expected when rolling back a transaction", err)
}
// begin with an error
ExpectBegin().WillReturnError(errors.New("Some err"))
tx, err = db.Begin()
if err == nil {
t.Error("An error was expected when beginning a transaction, but got none")
}
if err = db.Close(); err != nil {
t.Errorf("Error '%s' was not expected while closing the database", err)
}
}
func TestPreparedQueryExecutions(t *testing.T) {
db, err := sql.Open("mock", "")
if err != nil {
t.Errorf("An error '%s' was not expected when opening a stub database connection", err)
}
rs1 := RowsFromCSVString([]string{"id", "title"}, "5,hello world")
ExpectQuery("SELECT (.+) FROM articles WHERE id = ?").
WithArgs(5).
WillReturnRows(rs1)
rs2 := RowsFromCSVString([]string{"id", "title"}, "2,whoop")
ExpectQuery("SELECT (.+) FROM articles WHERE id = ?").
WithArgs(2).
WillReturnRows(rs2)
stmt, err := db.Prepare("SELECT (.+) FROM articles WHERE id = ?")
if err != nil {
t.Errorf("Error '%s' was not expected while creating a prepared statement", err)
}
var id int
var title string
err = stmt.QueryRow(5).Scan(&id, &title)
if err != nil {
t.Errorf("Error '%s' was not expected querying row from statement and scanning", err)
}
if id != 5 {
t.Errorf("Expected mocked id to be 5, but got %d instead", id)
}
if title != "hello world" {
t.Errorf("Expected mocked title to be 'hello world', but got '%s' instead", title)
}
err = stmt.QueryRow(2).Scan(&id, &title)
if err != nil {
t.Errorf("Error '%s' was not expected querying row from statement and scanning", err)
}
if id != 2 {
t.Errorf("Expected mocked id to be 2, but got %d instead", id)
}
if title != "whoop" {
t.Errorf("Expected mocked title to be 'whoop', but got '%s' instead", title)
}
if err = db.Close(); err != nil {
t.Errorf("Error '%s' was not expected while closing the database", err)
}
}
func TestUnexpectedOperations(t *testing.T) {
db, err := sql.Open("mock", "")
if err != nil {
t.Errorf("An error '%s' was not expected when opening a stub database connection", err)
}
stmt, err := db.Prepare("SELECT (.+) FROM articles WHERE id = ?")
if err != nil {
t.Errorf("Error '%s' was not expected while creating a prepared statement", err)
}
var id int
var title string
err = stmt.QueryRow(5).Scan(&id, &title)
if err == nil {
t.Error("Error was expected querying row, since there was no such expectation")
}
ExpectRollback()
err = db.Close()
if err == nil {
t.Error("Error was expected while closing the database, expectation was not fulfilled", err)
}
}
func TestWrongUnexpectations(t *testing.T) {
db, err := sql.Open("mock", "")
if err != nil {
t.Errorf("An error '%s' was not expected when opening a stub database connection", err)
}
ExpectBegin()
rs1 := RowsFromCSVString([]string{"id", "title"}, "5,hello world")
ExpectQuery("SELECT (.+) FROM articles WHERE id = ?").
WithArgs(5).
WillReturnRows(rs1)
ExpectCommit().WillReturnError(errors.New("Deadlock occured"))
ExpectRollback() // won't be triggered
stmt, err := db.Prepare("SELECT (.+) FROM articles WHERE id = ? FOR UPDATE")
if err != nil {
t.Errorf("Error '%s' was not expected while creating a prepared statement", err)
}
var id int
var title string
err = stmt.QueryRow(5).Scan(&id, &title)
if err == nil {
t.Error("Error was expected while querying row, since there Begin transaction expectation is not fulfilled")
}
// lets go around and start transaction
tx, err := db.Begin()
if err != nil {
t.Errorf("An error '%s' was not expected when beginning a transaction", err)
}
err = stmt.QueryRow(5).Scan(&id, &title)
if err != nil {
t.Errorf("Error '%s' was not expected while querying row, since transaction was started", err)
}
err = tx.Commit()
if err == nil {
t.Error("A deadlock error was expected when commiting a transaction", err)
}
err = db.Close()
if err == nil {
t.Error("Error was expected while closing the database, expectation was not fulfilled", err)
}
}
func TestExecExpectations(t *testing.T) {
db, err := sql.Open("mock", "")
if err != nil {
t.Errorf("An error '%s' was not expected when opening a stub database connection", err)
}
result := NewResult(1, 1)
ExpectExec("^INSERT INTO articles").
WithArgs("hello").
WillReturnResult(result)
res, err := db.Exec("INSERT INTO articles (title) VALUES (?)", "hello")
if err != nil {
t.Errorf("Error '%s' was not expected, while inserting a row", err)
}
id, err := res.LastInsertId()
if err != nil {
t.Errorf("Error '%s' was not expected, while getting a last insert id", err)
}
affected, err := res.RowsAffected()
if err != nil {
t.Errorf("Error '%s' was not expected, while getting affected rows", err)
}
if id != 1 {
t.Errorf("Expected last insert id to be 1, but got %d instead", id)
}
if affected != 1 {
t.Errorf("Expected affected rows to be 1, but got %d instead", affected)
}
if err = db.Close(); err != nil {
t.Errorf("Error '%s' was not expected while closing the database", err)
}
}

27
statement.go Normal file
View File

@ -0,0 +1,27 @@
package sqlmock
import (
"database/sql/driver"
)
type statement struct {
conn *conn
query string
}
func (stmt *statement) Close() error {
stmt.conn = nil
return nil
}
func (stmt *statement) NumInput() int {
return -1
}
func (stmt *statement) Exec(args []driver.Value) (driver.Result, error) {
return stmt.conn.Exec(stmt.query, args)
}
func (stmt *statement) Query(args []driver.Value) (driver.Rows, error) {
return stmt.conn.Query(stmt.query, args)
}

38
transaction.go Normal file
View File

@ -0,0 +1,38 @@
package sqlmock
import (
"errors"
"fmt"
)
type transaction struct {
conn *conn
}
func (tx *transaction) Commit() error {
e := tx.conn.next()
if e == nil {
return errors.New("All expectations were already fulfilled, call to Commit transaction was not expected")
}
etc, ok := e.(*expectedCommit)
if !ok {
return errors.New(fmt.Sprintf("Call to Commit transaction, was not expected, next expectation was %v", e))
}
etc.triggered = true
return etc.err
}
func (tx *transaction) Rollback() error {
e := tx.conn.next()
if e == nil {
return errors.New("All expectations were already fulfilled, call to Rollback transaction was not expected")
}
etr, ok := e.(*expectedRollback)
if !ok {
return errors.New(fmt.Sprintf("Call to Rollback transaction, was not expected, next expectation was %v", e))
}
etr.triggered = true
return etr.err
}