From 3e67393335d957281bb0030e4fd42635ce2c133f Mon Sep 17 00:00:00 2001 From: gedi Date: Wed, 5 Feb 2014 16:21:07 +0200 Subject: [PATCH] initial commit --- .gitignore | 1 + README.md | 10 ++ expectations.go | 101 ++++++++++++++++ expectations_test.go | 48 ++++++++ result.go | 21 ++++ rows.go | 62 ++++++++++ sqlmock.go | 213 ++++++++++++++++++++++++++++++++++ sqlmock_test.go | 268 +++++++++++++++++++++++++++++++++++++++++++ statement.go | 27 +++++ transaction.go | 38 ++++++ 10 files changed, 789 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 expectations.go create mode 100644 expectations_test.go create mode 100644 result.go create mode 100644 rows.go create mode 100644 sqlmock.go create mode 100644 sqlmock_test.go create mode 100644 statement.go create mode 100644 transaction.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..8493d1d --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +/*.test diff --git a/README.md b/README.md new file mode 100644 index 0000000..0eeaa95 --- /dev/null +++ b/README.md @@ -0,0 +1,10 @@ + +db = mock.Open("test", "") + +db.ExpectTransactionBegin() +db.ExpectTransactionBegin().WillReturnError("some error") +db.ExpectQuery("SELECT bla").With(5, 8, "stat").WillReturnNone() +db.ExpectExec("UPDATE tbl SET").With(5, "val").WillReturnResult(res /* sql.Result */) +db.ExpectExec("INSERT INTO bla").With(5, 8, "stat").WillReturnResult(res /* sql.Result */) +db.ExpectQuery("SELECT bla").With(5, 8, "stat").WillReturnRows() + diff --git a/expectations.go b/expectations.go new file mode 100644 index 0000000..7d1daf0 --- /dev/null +++ b/expectations.go @@ -0,0 +1,101 @@ +package sqlmock + +import ( + "database/sql/driver" + "reflect" + "regexp" +) + +type expectation interface { + fulfilled() bool + setError(err error) +} + +// common expectation + +type commonExpectation struct { + triggered bool + err error +} + +func (e *commonExpectation) fulfilled() bool { + return e.triggered +} + +func (e *commonExpectation) setError(err error) { + e.err = err +} + +// query based expectation +type queryBasedExpectation struct { + commonExpectation + sqlRegex *regexp.Regexp + args []driver.Value +} + +func (e *queryBasedExpectation) queryMatches(sql string) bool { + return e.sqlRegex.MatchString(sql) +} + +func (e *queryBasedExpectation) argsMatches(args []driver.Value) bool { + if len(args) != len(e.args) { + return false + } + for k, v := range e.args { + vi := reflect.ValueOf(v) + ai := reflect.ValueOf(args[k]) + switch vi.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if vi.Int() != ai.Int() { + return false + } + case reflect.Float32, reflect.Float64: + if vi.Float() != ai.Float() { + return false + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if vi.Uint() != ai.Uint() { + return false + } + case reflect.String: + if vi.String() != ai.String() { + return false + } + default: + // compare types like time.Time based on type only + if vi.Kind() != ai.Kind() { + return false + } + } + } + return true +} + +// begin transaction +type expectedBegin struct { + commonExpectation +} + +// tx commit +type expectedCommit struct { + commonExpectation +} + +// tx rollback +type expectedRollback struct { + commonExpectation +} + +// query expectation +type expectedQuery struct { + queryBasedExpectation + + rows driver.Rows +} + +// exec query expectation +type expectedExec struct { + queryBasedExpectation + + result driver.Result +} diff --git a/expectations_test.go b/expectations_test.go new file mode 100644 index 0000000..e32c8d1 --- /dev/null +++ b/expectations_test.go @@ -0,0 +1,48 @@ +package sqlmock + +import ( + "database/sql/driver" + "testing" + "time" +) + +func TestQueryExpectationArgComparison(t *testing.T) { + e := &queryBasedExpectation{} + e.args = []driver.Value{5, "str"} + + against := []driver.Value{5} + + if e.argsMatches(against) { + t.Error("Arguments should not match, since the size is not the same") + } + + against = []driver.Value{3, "str"} + if e.argsMatches(against) { + t.Error("Arguments should not match, since the first argument (int value) is different") + } + + against = []driver.Value{5, "st"} + if e.argsMatches(against) { + t.Error("Arguments should not match, since the second argument (string value) is different") + } + + against = []driver.Value{5, "str"} + if !e.argsMatches(against) { + t.Error("Arguments should match, but it did not") + } + + e.args = []driver.Value{5, time.Now()} + + const longForm = "Jan 2, 2006 at 3:04pm (MST)" + tm, _ := time.Parse(longForm, "Feb 3, 2013 at 7:54pm (PST)") + + against = []driver.Value{5, tm} + if !e.argsMatches(against) { + t.Error("Arguments should match (time will be compared only by type), but it did not") + } + + against = []driver.Value{5, 7899000} + if e.argsMatches(against) { + t.Error("Arguments should not match, but it did") + } +} diff --git a/result.go b/result.go new file mode 100644 index 0000000..7aed686 --- /dev/null +++ b/result.go @@ -0,0 +1,21 @@ +package sqlmock + +type Result struct { + lastInsertId int64 + rowsAffected int64 +} + +func NewResult(lastInsertId int64, rowsAffected int64) *Result { + return &Result{ + lastInsertId, + rowsAffected, + } +} + +func (res *Result) LastInsertId() (int64, error) { + return res.lastInsertId, nil +} + +func (res *Result) RowsAffected() (int64, error) { + return res.rowsAffected, nil +} diff --git a/rows.go b/rows.go new file mode 100644 index 0000000..55502a2 --- /dev/null +++ b/rows.go @@ -0,0 +1,62 @@ +package sqlmock + +import ( + "database/sql/driver" + "encoding/csv" + "io" + "strings" +) + +type rows struct { + cols []string + rows [][]driver.Value + pos int +} + +func (r *rows) Columns() []string { + return r.cols +} + +func (r *rows) Close() error { + return nil +} + +func (r *rows) Err() error { + return nil +} + +func (r *rows) Next(dest []driver.Value) error { + r.pos++ + if r.pos > len(r.rows) { + return io.EOF // per interface spec + } + + for i, col := range r.rows[r.pos-1] { + dest[i] = col + } + + return nil +} + +func RowsFromCSVString(columns []string, s string) driver.Rows { + rs := &rows{} + rs.cols = columns + + r := strings.NewReader(strings.TrimSpace(s)) + csvReader := csv.NewReader(r) + + for { + r, err := csvReader.Read() + if err != nil || r == nil { + break + } + + row := make([]driver.Value, len(columns)) + for i, v := range r { + v := strings.TrimSpace(v) + row[i] = v + } + rs.rows = append(rs.rows, row) + } + return rs +} diff --git a/sqlmock.go b/sqlmock.go new file mode 100644 index 0000000..e22505d --- /dev/null +++ b/sqlmock.go @@ -0,0 +1,213 @@ +package sqlmock + +import ( + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "regexp" +) + +var mock *mockDriver + +type Mock interface { + WithArgs(...driver.Value) Mock + WillReturnError(error) Mock + WillReturnRows(driver.Rows) Mock + WillReturnResult(driver.Result) Mock +} + +type mockDriver struct { + conn *conn +} + +func (d *mockDriver) Open(dsn string) (driver.Conn, error) { + return mock.conn, nil +} + +func init() { + mock = &mockDriver{&conn{}} + sql.Register("mock", mock) +} + +type conn struct { + expectations []expectation + active expectation +} + +func (c *conn) Close() (err error) { + for _, e := range mock.conn.expectations { + if !e.fulfilled() { + err = errors.New(fmt.Sprintf("There is expectation %+v which was not matched yet", e)) + break + } + } + mock.conn.expectations = []expectation{} + mock.conn.active = nil + return err +} + +func ExpectBegin() Mock { + e := &expectedBegin{} + mock.conn.expectations = append(mock.conn.expectations, e) + mock.conn.active = e + return mock.conn +} + +func ExpectCommit() Mock { + e := &expectedCommit{} + mock.conn.expectations = append(mock.conn.expectations, e) + mock.conn.active = e + return mock.conn +} + +func ExpectRollback() Mock { + e := &expectedRollback{} + mock.conn.expectations = append(mock.conn.expectations, e) + mock.conn.active = e + return mock.conn +} + +func (c *conn) WillReturnError(err error) Mock { + c.active.setError(err) + return c +} + +func (c *conn) Begin() (driver.Tx, error) { + e := c.next() + if e == nil { + return nil, errors.New("All expectations were already fulfilled, call to Begin transaction was not expected") + } + + etb, ok := e.(*expectedBegin) + if !ok { + return nil, errors.New(fmt.Sprintf("Call to Begin transaction, was not expected, next expectation is %v", e)) + } + etb.triggered = true + return &transaction{c}, etb.err +} + +// get next unfulfilled expectation +func (c *conn) next() (e expectation) { + for _, e = range c.expectations { + if !e.fulfilled() { + return + } + } + return nil // all expectations were fulfilled +} + +func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) { + e := c.next() + if e == nil { + return nil, errors.New(fmt.Sprintf("All expectations were already fulfilled, call to Exec '%s' query with args [%v] was not expected", query, args)) + } + + eq, ok := e.(*expectedExec) + if !ok { + return nil, errors.New(fmt.Sprintf("Call to Exec query '%s' with args [%v], was not expected, next expectation is %v", query, args, e)) + } + + eq.triggered = true + if eq.err != nil { + return nil, eq.err // mocked to return error + } + + if eq.result == nil { + return nil, errors.New(fmt.Sprintf("Exec query '%s' with args [%v], must return a database/sql/driver.Result, but it was not set for expectation %v", query, args, eq)) + } + + if !eq.queryMatches(query) { + return nil, errors.New(fmt.Sprintf("Exec query '%s', does not match regex [%s]", query, eq.sqlRegex.String())) + } + + if !eq.argsMatches(args) { + return nil, errors.New(fmt.Sprintf("Exec query '%s', args [%v] does not match expected [%v]", query, args, eq.args)) + } + + return eq.result, nil +} + +func ExpectExec(sqlRegexStr string) Mock { + e := &expectedExec{} + e.sqlRegex = regexp.MustCompile(sqlRegexStr) + mock.conn.expectations = append(mock.conn.expectations, e) + mock.conn.active = e + return mock.conn +} + +func ExpectQuery(sqlRegexStr string) Mock { + e := &expectedQuery{} + e.sqlRegex = regexp.MustCompile(sqlRegexStr) + + mock.conn.expectations = append(mock.conn.expectations, e) + mock.conn.active = e + return mock.conn +} + +func (c *conn) WithArgs(args ...driver.Value) Mock { + eq, ok := c.active.(*expectedQuery) + if !ok { + ee, ok := c.active.(*expectedExec) + if !ok { + panic(fmt.Sprintf("Arguments may be expected only with query based expectations, current is %T", c.active)) + } + ee.args = args + } else { + eq.args = args + } + return c +} + +func (c *conn) WillReturnResult(result driver.Result) Mock { + eq, ok := c.active.(*expectedExec) + if !ok { + panic(fmt.Sprintf("driver.Result may be returned only by Exec expectations, current is %v", c.active)) + } + eq.result = result + return c +} + +func (c *conn) WillReturnRows(rows driver.Rows) Mock { + eq, ok := c.active.(*expectedQuery) + if !ok { + panic(fmt.Sprintf("driver.Rows may be returned only by Query expectations, current is %v", c.active)) + } + eq.rows = rows + return c +} + +func (c *conn) Prepare(query string) (driver.Stmt, error) { + return &statement{c, query}, nil +} + +func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) { + e := c.next() + if e == nil { + return nil, errors.New(fmt.Sprintf("All expectations were already fulfilled, call to Query '%s' with args [%v] was not expected", query, args)) + } + + eq, ok := e.(*expectedQuery) + if !ok { + return nil, errors.New(fmt.Sprintf("Call to Query '%s' with args [%v], was not expected, next expectation is %v", query, args, e)) + } + + eq.triggered = true + if eq.err != nil { + return nil, eq.err // mocked to return error + } + + if eq.rows == nil { + return nil, errors.New(fmt.Sprintf("Query '%s' with args [%v], must return a database/sql/driver.Rows, but it was not set for expectation %v", query, args, eq)) + } + + if !eq.queryMatches(query) { + return nil, errors.New(fmt.Sprintf("Query '%s', does not match regex [%s]", query, eq.sqlRegex.String())) + } + + if !eq.argsMatches(args) { + return nil, errors.New(fmt.Sprintf("Query '%s', args [%v] does not match expected [%v]", query, args, eq.args)) + } + + return eq.rows, nil +} diff --git a/sqlmock_test.go b/sqlmock_test.go new file mode 100644 index 0000000..2a26faa --- /dev/null +++ b/sqlmock_test.go @@ -0,0 +1,268 @@ +package sqlmock + +import ( + "database/sql" + "errors" + "testing" +) + +func TestMockQuery(t *testing.T) { + db, err := sql.Open("mock", "") + if err != nil { + t.Errorf("An error '%s' was not expected when opening a stub database connection", err) + } + + rs := RowsFromCSVString([]string{"id", "title"}, "5,hello world") + + ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). + WithArgs(5). + WillReturnRows(rs) + + rows, err := db.Query("SELECT (.+) FROM articles WHERE id = ?", 5) + if err != nil { + t.Errorf("Error '%s' was not expected while retrieving mock rows", err) + } + defer rows.Close() + if !rows.Next() { + t.Error("It must have had one row as result, but got empty result set instead") + } + + var id int + var title string + + err = rows.Scan(&id, &title) + if err != nil { + t.Errorf("Error '%s' was not expected while trying to scan row", err) + } + + if id != 5 { + t.Errorf("Expected mocked id to be 5, but got %d instead", id) + } + + if title != "hello world" { + t.Errorf("Expected mocked title to be 'hello world', but got '%s' instead", title) + } + + if err = db.Close(); err != nil { + t.Errorf("Error '%s' was not expected while closing the database", err) + } +} + +func TestTransactionExpectations(t *testing.T) { + db, err := sql.Open("mock", "") + if err != nil { + t.Errorf("An error '%s' was not expected when opening a stub database connection", err) + } + + // begin and commit + ExpectBegin() + ExpectCommit() + + tx, err := db.Begin() + if err != nil { + t.Errorf("An error '%s' was not expected when beginning a transaction", err) + } + + err = tx.Commit() + if err != nil { + t.Errorf("An error '%s' was not expected when commiting a transaction", err) + } + + // begin and rollback + ExpectBegin() + ExpectRollback() + + tx, err = db.Begin() + if err != nil { + t.Errorf("An error '%s' was not expected when beginning a transaction", err) + } + + err = tx.Rollback() + if err != nil { + t.Errorf("An error '%s' was not expected when rolling back a transaction", err) + } + + // begin with an error + ExpectBegin().WillReturnError(errors.New("Some err")) + + tx, err = db.Begin() + if err == nil { + t.Error("An error was expected when beginning a transaction, but got none") + } + + if err = db.Close(); err != nil { + t.Errorf("Error '%s' was not expected while closing the database", err) + } +} + +func TestPreparedQueryExecutions(t *testing.T) { + db, err := sql.Open("mock", "") + if err != nil { + t.Errorf("An error '%s' was not expected when opening a stub database connection", err) + } + + rs1 := RowsFromCSVString([]string{"id", "title"}, "5,hello world") + ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). + WithArgs(5). + WillReturnRows(rs1) + + rs2 := RowsFromCSVString([]string{"id", "title"}, "2,whoop") + ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). + WithArgs(2). + WillReturnRows(rs2) + + stmt, err := db.Prepare("SELECT (.+) FROM articles WHERE id = ?") + if err != nil { + t.Errorf("Error '%s' was not expected while creating a prepared statement", err) + } + + var id int + var title string + + err = stmt.QueryRow(5).Scan(&id, &title) + if err != nil { + t.Errorf("Error '%s' was not expected querying row from statement and scanning", err) + } + + if id != 5 { + t.Errorf("Expected mocked id to be 5, but got %d instead", id) + } + + if title != "hello world" { + t.Errorf("Expected mocked title to be 'hello world', but got '%s' instead", title) + } + + err = stmt.QueryRow(2).Scan(&id, &title) + if err != nil { + t.Errorf("Error '%s' was not expected querying row from statement and scanning", err) + } + + if id != 2 { + t.Errorf("Expected mocked id to be 2, but got %d instead", id) + } + + if title != "whoop" { + t.Errorf("Expected mocked title to be 'whoop', but got '%s' instead", title) + } + + if err = db.Close(); err != nil { + t.Errorf("Error '%s' was not expected while closing the database", err) + } +} + +func TestUnexpectedOperations(t *testing.T) { + db, err := sql.Open("mock", "") + if err != nil { + t.Errorf("An error '%s' was not expected when opening a stub database connection", err) + } + + stmt, err := db.Prepare("SELECT (.+) FROM articles WHERE id = ?") + if err != nil { + t.Errorf("Error '%s' was not expected while creating a prepared statement", err) + } + + var id int + var title string + + err = stmt.QueryRow(5).Scan(&id, &title) + if err == nil { + t.Error("Error was expected querying row, since there was no such expectation") + } + + ExpectRollback() + + err = db.Close() + if err == nil { + t.Error("Error was expected while closing the database, expectation was not fulfilled", err) + } +} + +func TestWrongUnexpectations(t *testing.T) { + db, err := sql.Open("mock", "") + if err != nil { + t.Errorf("An error '%s' was not expected when opening a stub database connection", err) + } + + ExpectBegin() + + rs1 := RowsFromCSVString([]string{"id", "title"}, "5,hello world") + ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). + WithArgs(5). + WillReturnRows(rs1) + + ExpectCommit().WillReturnError(errors.New("Deadlock occured")) + ExpectRollback() // won't be triggered + + stmt, err := db.Prepare("SELECT (.+) FROM articles WHERE id = ? FOR UPDATE") + if err != nil { + t.Errorf("Error '%s' was not expected while creating a prepared statement", err) + } + + var id int + var title string + + err = stmt.QueryRow(5).Scan(&id, &title) + if err == nil { + t.Error("Error was expected while querying row, since there Begin transaction expectation is not fulfilled") + } + + // lets go around and start transaction + tx, err := db.Begin() + if err != nil { + t.Errorf("An error '%s' was not expected when beginning a transaction", err) + } + + err = stmt.QueryRow(5).Scan(&id, &title) + if err != nil { + t.Errorf("Error '%s' was not expected while querying row, since transaction was started", err) + } + + err = tx.Commit() + if err == nil { + t.Error("A deadlock error was expected when commiting a transaction", err) + } + + err = db.Close() + if err == nil { + t.Error("Error was expected while closing the database, expectation was not fulfilled", err) + } +} + +func TestExecExpectations(t *testing.T) { + db, err := sql.Open("mock", "") + if err != nil { + t.Errorf("An error '%s' was not expected when opening a stub database connection", err) + } + + result := NewResult(1, 1) + ExpectExec("^INSERT INTO articles"). + WithArgs("hello"). + WillReturnResult(result) + + res, err := db.Exec("INSERT INTO articles (title) VALUES (?)", "hello") + if err != nil { + t.Errorf("Error '%s' was not expected, while inserting a row", err) + } + + id, err := res.LastInsertId() + if err != nil { + t.Errorf("Error '%s' was not expected, while getting a last insert id", err) + } + + affected, err := res.RowsAffected() + if err != nil { + t.Errorf("Error '%s' was not expected, while getting affected rows", err) + } + + if id != 1 { + t.Errorf("Expected last insert id to be 1, but got %d instead", id) + } + + if affected != 1 { + t.Errorf("Expected affected rows to be 1, but got %d instead", affected) + } + + if err = db.Close(); err != nil { + t.Errorf("Error '%s' was not expected while closing the database", err) + } +} diff --git a/statement.go b/statement.go new file mode 100644 index 0000000..d862b23 --- /dev/null +++ b/statement.go @@ -0,0 +1,27 @@ +package sqlmock + +import ( + "database/sql/driver" +) + +type statement struct { + conn *conn + query string +} + +func (stmt *statement) Close() error { + stmt.conn = nil + return nil +} + +func (stmt *statement) NumInput() int { + return -1 +} + +func (stmt *statement) Exec(args []driver.Value) (driver.Result, error) { + return stmt.conn.Exec(stmt.query, args) +} + +func (stmt *statement) Query(args []driver.Value) (driver.Rows, error) { + return stmt.conn.Query(stmt.query, args) +} diff --git a/transaction.go b/transaction.go new file mode 100644 index 0000000..76a0afe --- /dev/null +++ b/transaction.go @@ -0,0 +1,38 @@ +package sqlmock + +import ( + "errors" + "fmt" +) + +type transaction struct { + conn *conn +} + +func (tx *transaction) Commit() error { + e := tx.conn.next() + if e == nil { + return errors.New("All expectations were already fulfilled, call to Commit transaction was not expected") + } + + etc, ok := e.(*expectedCommit) + if !ok { + return errors.New(fmt.Sprintf("Call to Commit transaction, was not expected, next expectation was %v", e)) + } + etc.triggered = true + return etc.err +} + +func (tx *transaction) Rollback() error { + e := tx.conn.next() + if e == nil { + return errors.New("All expectations were already fulfilled, call to Rollback transaction was not expected") + } + + etr, ok := e.(*expectedRollback) + if !ok { + return errors.New(fmt.Sprintf("Call to Rollback transaction, was not expected, next expectation was %v", e)) + } + etr.triggered = true + return etr.err +}