You've already forked go-sqlxmock
							
							
				mirror of
				https://github.com/zhashkevych/go-sqlxmock.git
				synced 2025-10-30 23:27:38 +02:00 
			
		
		
		
	initial commit
This commit is contained in:
		
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | ||||
| /*.test | ||||
							
								
								
									
										10
									
								
								README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,10 @@ | ||||
|  | ||||
| db = mock.Open("test", "") | ||||
|  | ||||
| db.ExpectTransactionBegin() | ||||
| db.ExpectTransactionBegin().WillReturnError("some error") | ||||
| db.ExpectQuery("SELECT bla").With(5, 8, "stat").WillReturnNone() | ||||
| db.ExpectExec("UPDATE tbl SET").With(5, "val").WillReturnResult(res /* sql.Result */) | ||||
| db.ExpectExec("INSERT INTO bla").With(5, 8, "stat").WillReturnResult(res /* sql.Result */) | ||||
| db.ExpectQuery("SELECT bla").With(5, 8, "stat").WillReturnRows() | ||||
|  | ||||
							
								
								
									
										101
									
								
								expectations.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										101
									
								
								expectations.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,101 @@ | ||||
| package sqlmock | ||||
|  | ||||
| import ( | ||||
| 	"database/sql/driver" | ||||
| 	"reflect" | ||||
| 	"regexp" | ||||
| ) | ||||
|  | ||||
| type expectation interface { | ||||
| 	fulfilled() bool | ||||
| 	setError(err error) | ||||
| } | ||||
|  | ||||
| // common expectation | ||||
|  | ||||
| type commonExpectation struct { | ||||
| 	triggered bool | ||||
| 	err       error | ||||
| } | ||||
|  | ||||
| func (e *commonExpectation) fulfilled() bool { | ||||
| 	return e.triggered | ||||
| } | ||||
|  | ||||
| func (e *commonExpectation) setError(err error) { | ||||
| 	e.err = err | ||||
| } | ||||
|  | ||||
| // query based expectation | ||||
| type queryBasedExpectation struct { | ||||
| 	commonExpectation | ||||
| 	sqlRegex *regexp.Regexp | ||||
| 	args     []driver.Value | ||||
| } | ||||
|  | ||||
| func (e *queryBasedExpectation) queryMatches(sql string) bool { | ||||
| 	return e.sqlRegex.MatchString(sql) | ||||
| } | ||||
|  | ||||
| func (e *queryBasedExpectation) argsMatches(args []driver.Value) bool { | ||||
| 	if len(args) != len(e.args) { | ||||
| 		return false | ||||
| 	} | ||||
| 	for k, v := range e.args { | ||||
| 		vi := reflect.ValueOf(v) | ||||
| 		ai := reflect.ValueOf(args[k]) | ||||
| 		switch vi.Kind() { | ||||
| 		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | ||||
| 			if vi.Int() != ai.Int() { | ||||
| 				return false | ||||
| 			} | ||||
| 		case reflect.Float32, reflect.Float64: | ||||
| 			if vi.Float() != ai.Float() { | ||||
| 				return false | ||||
| 			} | ||||
| 		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: | ||||
| 			if vi.Uint() != ai.Uint() { | ||||
| 				return false | ||||
| 			} | ||||
| 		case reflect.String: | ||||
| 			if vi.String() != ai.String() { | ||||
| 				return false | ||||
| 			} | ||||
| 		default: | ||||
| 			// compare types like time.Time based on type only | ||||
| 			if vi.Kind() != ai.Kind() { | ||||
| 				return false | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| // begin transaction | ||||
| type expectedBegin struct { | ||||
| 	commonExpectation | ||||
| } | ||||
|  | ||||
| // tx commit | ||||
| type expectedCommit struct { | ||||
| 	commonExpectation | ||||
| } | ||||
|  | ||||
| // tx rollback | ||||
| type expectedRollback struct { | ||||
| 	commonExpectation | ||||
| } | ||||
|  | ||||
| // query expectation | ||||
| type expectedQuery struct { | ||||
| 	queryBasedExpectation | ||||
|  | ||||
| 	rows driver.Rows | ||||
| } | ||||
|  | ||||
| // exec query expectation | ||||
| type expectedExec struct { | ||||
| 	queryBasedExpectation | ||||
|  | ||||
| 	result driver.Result | ||||
| } | ||||
							
								
								
									
										48
									
								
								expectations_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								expectations_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,48 @@ | ||||
| package sqlmock | ||||
|  | ||||
| import ( | ||||
| 	"database/sql/driver" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func TestQueryExpectationArgComparison(t *testing.T) { | ||||
| 	e := &queryBasedExpectation{} | ||||
| 	e.args = []driver.Value{5, "str"} | ||||
|  | ||||
| 	against := []driver.Value{5} | ||||
|  | ||||
| 	if e.argsMatches(against) { | ||||
| 		t.Error("Arguments should not match, since the size is not the same") | ||||
| 	} | ||||
|  | ||||
| 	against = []driver.Value{3, "str"} | ||||
| 	if e.argsMatches(against) { | ||||
| 		t.Error("Arguments should not match, since the first argument (int value) is different") | ||||
| 	} | ||||
|  | ||||
| 	against = []driver.Value{5, "st"} | ||||
| 	if e.argsMatches(against) { | ||||
| 		t.Error("Arguments should not match, since the second argument (string value) is different") | ||||
| 	} | ||||
|  | ||||
| 	against = []driver.Value{5, "str"} | ||||
| 	if !e.argsMatches(against) { | ||||
| 		t.Error("Arguments should match, but it did not") | ||||
| 	} | ||||
|  | ||||
| 	e.args = []driver.Value{5, time.Now()} | ||||
|  | ||||
| 	const longForm = "Jan 2, 2006 at 3:04pm (MST)" | ||||
| 	tm, _ := time.Parse(longForm, "Feb 3, 2013 at 7:54pm (PST)") | ||||
|  | ||||
| 	against = []driver.Value{5, tm} | ||||
| 	if !e.argsMatches(against) { | ||||
| 		t.Error("Arguments should match (time will be compared only by type), but it did not") | ||||
| 	} | ||||
|  | ||||
| 	against = []driver.Value{5, 7899000} | ||||
| 	if e.argsMatches(against) { | ||||
| 		t.Error("Arguments should not match, but it did") | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										21
									
								
								result.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								result.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,21 @@ | ||||
| package sqlmock | ||||
|  | ||||
| type Result struct { | ||||
| 	lastInsertId int64 | ||||
| 	rowsAffected int64 | ||||
| } | ||||
|  | ||||
| func NewResult(lastInsertId int64, rowsAffected int64) *Result { | ||||
| 	return &Result{ | ||||
| 		lastInsertId, | ||||
| 		rowsAffected, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (res *Result) LastInsertId() (int64, error) { | ||||
| 	return res.lastInsertId, nil | ||||
| } | ||||
|  | ||||
| func (res *Result) RowsAffected() (int64, error) { | ||||
| 	return res.rowsAffected, nil | ||||
| } | ||||
							
								
								
									
										62
									
								
								rows.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								rows.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,62 @@ | ||||
| package sqlmock | ||||
|  | ||||
| import ( | ||||
| 	"database/sql/driver" | ||||
| 	"encoding/csv" | ||||
| 	"io" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| type rows struct { | ||||
| 	cols []string | ||||
| 	rows [][]driver.Value | ||||
| 	pos  int | ||||
| } | ||||
|  | ||||
| func (r *rows) Columns() []string { | ||||
| 	return r.cols | ||||
| } | ||||
|  | ||||
| func (r *rows) Close() error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (r *rows) Err() error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (r *rows) Next(dest []driver.Value) error { | ||||
| 	r.pos++ | ||||
| 	if r.pos > len(r.rows) { | ||||
| 		return io.EOF // per interface spec | ||||
| 	} | ||||
|  | ||||
| 	for i, col := range r.rows[r.pos-1] { | ||||
| 		dest[i] = col | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func RowsFromCSVString(columns []string, s string) driver.Rows { | ||||
| 	rs := &rows{} | ||||
| 	rs.cols = columns | ||||
|  | ||||
| 	r := strings.NewReader(strings.TrimSpace(s)) | ||||
| 	csvReader := csv.NewReader(r) | ||||
|  | ||||
| 	for { | ||||
| 		r, err := csvReader.Read() | ||||
| 		if err != nil || r == nil { | ||||
| 			break | ||||
| 		} | ||||
|  | ||||
| 		row := make([]driver.Value, len(columns)) | ||||
| 		for i, v := range r { | ||||
| 			v := strings.TrimSpace(v) | ||||
| 			row[i] = v | ||||
| 		} | ||||
| 		rs.rows = append(rs.rows, row) | ||||
| 	} | ||||
| 	return rs | ||||
| } | ||||
							
								
								
									
										213
									
								
								sqlmock.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										213
									
								
								sqlmock.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,213 @@ | ||||
| package sqlmock | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"database/sql/driver" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"regexp" | ||||
| ) | ||||
|  | ||||
| var mock *mockDriver | ||||
|  | ||||
| type Mock interface { | ||||
| 	WithArgs(...driver.Value) Mock | ||||
| 	WillReturnError(error) Mock | ||||
| 	WillReturnRows(driver.Rows) Mock | ||||
| 	WillReturnResult(driver.Result) Mock | ||||
| } | ||||
|  | ||||
| type mockDriver struct { | ||||
| 	conn *conn | ||||
| } | ||||
|  | ||||
| func (d *mockDriver) Open(dsn string) (driver.Conn, error) { | ||||
| 	return mock.conn, nil | ||||
| } | ||||
|  | ||||
| func init() { | ||||
| 	mock = &mockDriver{&conn{}} | ||||
| 	sql.Register("mock", mock) | ||||
| } | ||||
|  | ||||
| type conn struct { | ||||
| 	expectations []expectation | ||||
| 	active       expectation | ||||
| } | ||||
|  | ||||
| func (c *conn) Close() (err error) { | ||||
| 	for _, e := range mock.conn.expectations { | ||||
| 		if !e.fulfilled() { | ||||
| 			err = errors.New(fmt.Sprintf("There is expectation %+v which was not matched yet", e)) | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
| 	mock.conn.expectations = []expectation{} | ||||
| 	mock.conn.active = nil | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func ExpectBegin() Mock { | ||||
| 	e := &expectedBegin{} | ||||
| 	mock.conn.expectations = append(mock.conn.expectations, e) | ||||
| 	mock.conn.active = e | ||||
| 	return mock.conn | ||||
| } | ||||
|  | ||||
| func ExpectCommit() Mock { | ||||
| 	e := &expectedCommit{} | ||||
| 	mock.conn.expectations = append(mock.conn.expectations, e) | ||||
| 	mock.conn.active = e | ||||
| 	return mock.conn | ||||
| } | ||||
|  | ||||
| func ExpectRollback() Mock { | ||||
| 	e := &expectedRollback{} | ||||
| 	mock.conn.expectations = append(mock.conn.expectations, e) | ||||
| 	mock.conn.active = e | ||||
| 	return mock.conn | ||||
| } | ||||
|  | ||||
| func (c *conn) WillReturnError(err error) Mock { | ||||
| 	c.active.setError(err) | ||||
| 	return c | ||||
| } | ||||
|  | ||||
| func (c *conn) Begin() (driver.Tx, error) { | ||||
| 	e := c.next() | ||||
| 	if e == nil { | ||||
| 		return nil, errors.New("All expectations were already fulfilled, call to Begin transaction was not expected") | ||||
| 	} | ||||
|  | ||||
| 	etb, ok := e.(*expectedBegin) | ||||
| 	if !ok { | ||||
| 		return nil, errors.New(fmt.Sprintf("Call to Begin transaction, was not expected, next expectation is %v", e)) | ||||
| 	} | ||||
| 	etb.triggered = true | ||||
| 	return &transaction{c}, etb.err | ||||
| } | ||||
|  | ||||
| // get next unfulfilled expectation | ||||
| func (c *conn) next() (e expectation) { | ||||
| 	for _, e = range c.expectations { | ||||
| 		if !e.fulfilled() { | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| 	return nil // all expectations were fulfilled | ||||
| } | ||||
|  | ||||
| func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) { | ||||
| 	e := c.next() | ||||
| 	if e == nil { | ||||
| 		return nil, errors.New(fmt.Sprintf("All expectations were already fulfilled, call to Exec '%s' query with args [%v] was not expected", query, args)) | ||||
| 	} | ||||
|  | ||||
| 	eq, ok := e.(*expectedExec) | ||||
| 	if !ok { | ||||
| 		return nil, errors.New(fmt.Sprintf("Call to Exec query '%s' with args [%v], was not expected, next expectation is %v", query, args, e)) | ||||
| 	} | ||||
|  | ||||
| 	eq.triggered = true | ||||
| 	if eq.err != nil { | ||||
| 		return nil, eq.err // mocked to return error | ||||
| 	} | ||||
|  | ||||
| 	if eq.result == nil { | ||||
| 		return nil, errors.New(fmt.Sprintf("Exec query '%s' with args [%v], must return a database/sql/driver.Result, but it was not set for expectation %v", query, args, eq)) | ||||
| 	} | ||||
|  | ||||
| 	if !eq.queryMatches(query) { | ||||
| 		return nil, errors.New(fmt.Sprintf("Exec query '%s', does not match regex [%s]", query, eq.sqlRegex.String())) | ||||
| 	} | ||||
|  | ||||
| 	if !eq.argsMatches(args) { | ||||
| 		return nil, errors.New(fmt.Sprintf("Exec query '%s', args [%v] does not match expected [%v]", query, args, eq.args)) | ||||
| 	} | ||||
|  | ||||
| 	return eq.result, nil | ||||
| } | ||||
|  | ||||
| func ExpectExec(sqlRegexStr string) Mock { | ||||
| 	e := &expectedExec{} | ||||
| 	e.sqlRegex = regexp.MustCompile(sqlRegexStr) | ||||
| 	mock.conn.expectations = append(mock.conn.expectations, e) | ||||
| 	mock.conn.active = e | ||||
| 	return mock.conn | ||||
| } | ||||
|  | ||||
| func ExpectQuery(sqlRegexStr string) Mock { | ||||
| 	e := &expectedQuery{} | ||||
| 	e.sqlRegex = regexp.MustCompile(sqlRegexStr) | ||||
|  | ||||
| 	mock.conn.expectations = append(mock.conn.expectations, e) | ||||
| 	mock.conn.active = e | ||||
| 	return mock.conn | ||||
| } | ||||
|  | ||||
| func (c *conn) WithArgs(args ...driver.Value) Mock { | ||||
| 	eq, ok := c.active.(*expectedQuery) | ||||
| 	if !ok { | ||||
| 		ee, ok := c.active.(*expectedExec) | ||||
| 		if !ok { | ||||
| 			panic(fmt.Sprintf("Arguments may be expected only with query based expectations, current is %T", c.active)) | ||||
| 		} | ||||
| 		ee.args = args | ||||
| 	} else { | ||||
| 		eq.args = args | ||||
| 	} | ||||
| 	return c | ||||
| } | ||||
|  | ||||
| func (c *conn) WillReturnResult(result driver.Result) Mock { | ||||
| 	eq, ok := c.active.(*expectedExec) | ||||
| 	if !ok { | ||||
| 		panic(fmt.Sprintf("driver.Result may be returned only by Exec expectations, current is %v", c.active)) | ||||
| 	} | ||||
| 	eq.result = result | ||||
| 	return c | ||||
| } | ||||
|  | ||||
| func (c *conn) WillReturnRows(rows driver.Rows) Mock { | ||||
| 	eq, ok := c.active.(*expectedQuery) | ||||
| 	if !ok { | ||||
| 		panic(fmt.Sprintf("driver.Rows may be returned only by Query expectations, current is %v", c.active)) | ||||
| 	} | ||||
| 	eq.rows = rows | ||||
| 	return c | ||||
| } | ||||
|  | ||||
| func (c *conn) Prepare(query string) (driver.Stmt, error) { | ||||
| 	return &statement{c, query}, nil | ||||
| } | ||||
|  | ||||
| func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) { | ||||
| 	e := c.next() | ||||
| 	if e == nil { | ||||
| 		return nil, errors.New(fmt.Sprintf("All expectations were already fulfilled, call to Query '%s' with args [%v] was not expected", query, args)) | ||||
| 	} | ||||
|  | ||||
| 	eq, ok := e.(*expectedQuery) | ||||
| 	if !ok { | ||||
| 		return nil, errors.New(fmt.Sprintf("Call to Query '%s' with args [%v], was not expected, next expectation is %v", query, args, e)) | ||||
| 	} | ||||
|  | ||||
| 	eq.triggered = true | ||||
| 	if eq.err != nil { | ||||
| 		return nil, eq.err // mocked to return error | ||||
| 	} | ||||
|  | ||||
| 	if eq.rows == nil { | ||||
| 		return nil, errors.New(fmt.Sprintf("Query '%s' with args [%v], must return a database/sql/driver.Rows, but it was not set for expectation %v", query, args, eq)) | ||||
| 	} | ||||
|  | ||||
| 	if !eq.queryMatches(query) { | ||||
| 		return nil, errors.New(fmt.Sprintf("Query '%s', does not match regex [%s]", query, eq.sqlRegex.String())) | ||||
| 	} | ||||
|  | ||||
| 	if !eq.argsMatches(args) { | ||||
| 		return nil, errors.New(fmt.Sprintf("Query '%s', args [%v] does not match expected [%v]", query, args, eq.args)) | ||||
| 	} | ||||
|  | ||||
| 	return eq.rows, nil | ||||
| } | ||||
							
								
								
									
										268
									
								
								sqlmock_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										268
									
								
								sqlmock_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,268 @@ | ||||
| package sqlmock | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"errors" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func TestMockQuery(t *testing.T) { | ||||
| 	db, err := sql.Open("mock", "") | ||||
| 	if err != nil { | ||||
| 		t.Errorf("An error '%s' was not expected when opening a stub database connection", err) | ||||
| 	} | ||||
|  | ||||
| 	rs := RowsFromCSVString([]string{"id", "title"}, "5,hello world") | ||||
|  | ||||
| 	ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). | ||||
| 		WithArgs(5). | ||||
| 		WillReturnRows(rs) | ||||
|  | ||||
| 	rows, err := db.Query("SELECT (.+) FROM articles WHERE id = ?", 5) | ||||
| 	if err != nil { | ||||
| 		t.Errorf("Error '%s' was not expected while retrieving mock rows", err) | ||||
| 	} | ||||
| 	defer rows.Close() | ||||
| 	if !rows.Next() { | ||||
| 		t.Error("It must have had one row as result, but got empty result set instead") | ||||
| 	} | ||||
|  | ||||
| 	var id int | ||||
| 	var title string | ||||
|  | ||||
| 	err = rows.Scan(&id, &title) | ||||
| 	if err != nil { | ||||
| 		t.Errorf("Error '%s' was not expected while trying to scan row", err) | ||||
| 	} | ||||
|  | ||||
| 	if id != 5 { | ||||
| 		t.Errorf("Expected mocked id to be 5, but got %d instead", id) | ||||
| 	} | ||||
|  | ||||
| 	if title != "hello world" { | ||||
| 		t.Errorf("Expected mocked title to be 'hello world', but got '%s' instead", title) | ||||
| 	} | ||||
|  | ||||
| 	if err = db.Close(); err != nil { | ||||
| 		t.Errorf("Error '%s' was not expected while closing the database", err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestTransactionExpectations(t *testing.T) { | ||||
| 	db, err := sql.Open("mock", "") | ||||
| 	if err != nil { | ||||
| 		t.Errorf("An error '%s' was not expected when opening a stub database connection", err) | ||||
| 	} | ||||
|  | ||||
| 	// begin and commit | ||||
| 	ExpectBegin() | ||||
| 	ExpectCommit() | ||||
|  | ||||
| 	tx, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		t.Errorf("An error '%s' was not expected when beginning a transaction", err) | ||||
| 	} | ||||
|  | ||||
| 	err = tx.Commit() | ||||
| 	if err != nil { | ||||
| 		t.Errorf("An error '%s' was not expected when commiting a transaction", err) | ||||
| 	} | ||||
|  | ||||
| 	// begin and rollback | ||||
| 	ExpectBegin() | ||||
| 	ExpectRollback() | ||||
|  | ||||
| 	tx, err = db.Begin() | ||||
| 	if err != nil { | ||||
| 		t.Errorf("An error '%s' was not expected when beginning a transaction", err) | ||||
| 	} | ||||
|  | ||||
| 	err = tx.Rollback() | ||||
| 	if err != nil { | ||||
| 		t.Errorf("An error '%s' was not expected when rolling back a transaction", err) | ||||
| 	} | ||||
|  | ||||
| 	// begin with an error | ||||
| 	ExpectBegin().WillReturnError(errors.New("Some err")) | ||||
|  | ||||
| 	tx, err = db.Begin() | ||||
| 	if err == nil { | ||||
| 		t.Error("An error was expected when beginning a transaction, but got none") | ||||
| 	} | ||||
|  | ||||
| 	if err = db.Close(); err != nil { | ||||
| 		t.Errorf("Error '%s' was not expected while closing the database", err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestPreparedQueryExecutions(t *testing.T) { | ||||
| 	db, err := sql.Open("mock", "") | ||||
| 	if err != nil { | ||||
| 		t.Errorf("An error '%s' was not expected when opening a stub database connection", err) | ||||
| 	} | ||||
|  | ||||
| 	rs1 := RowsFromCSVString([]string{"id", "title"}, "5,hello world") | ||||
| 	ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). | ||||
| 		WithArgs(5). | ||||
| 		WillReturnRows(rs1) | ||||
|  | ||||
| 	rs2 := RowsFromCSVString([]string{"id", "title"}, "2,whoop") | ||||
| 	ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). | ||||
| 		WithArgs(2). | ||||
| 		WillReturnRows(rs2) | ||||
|  | ||||
| 	stmt, err := db.Prepare("SELECT (.+) FROM articles WHERE id = ?") | ||||
| 	if err != nil { | ||||
| 		t.Errorf("Error '%s' was not expected while creating a prepared statement", err) | ||||
| 	} | ||||
|  | ||||
| 	var id int | ||||
| 	var title string | ||||
|  | ||||
| 	err = stmt.QueryRow(5).Scan(&id, &title) | ||||
| 	if err != nil { | ||||
| 		t.Errorf("Error '%s' was not expected querying row from statement and scanning", err) | ||||
| 	} | ||||
|  | ||||
| 	if id != 5 { | ||||
| 		t.Errorf("Expected mocked id to be 5, but got %d instead", id) | ||||
| 	} | ||||
|  | ||||
| 	if title != "hello world" { | ||||
| 		t.Errorf("Expected mocked title to be 'hello world', but got '%s' instead", title) | ||||
| 	} | ||||
|  | ||||
| 	err = stmt.QueryRow(2).Scan(&id, &title) | ||||
| 	if err != nil { | ||||
| 		t.Errorf("Error '%s' was not expected querying row from statement and scanning", err) | ||||
| 	} | ||||
|  | ||||
| 	if id != 2 { | ||||
| 		t.Errorf("Expected mocked id to be 2, but got %d instead", id) | ||||
| 	} | ||||
|  | ||||
| 	if title != "whoop" { | ||||
| 		t.Errorf("Expected mocked title to be 'whoop', but got '%s' instead", title) | ||||
| 	} | ||||
|  | ||||
| 	if err = db.Close(); err != nil { | ||||
| 		t.Errorf("Error '%s' was not expected while closing the database", err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestUnexpectedOperations(t *testing.T) { | ||||
| 	db, err := sql.Open("mock", "") | ||||
| 	if err != nil { | ||||
| 		t.Errorf("An error '%s' was not expected when opening a stub database connection", err) | ||||
| 	} | ||||
|  | ||||
| 	stmt, err := db.Prepare("SELECT (.+) FROM articles WHERE id = ?") | ||||
| 	if err != nil { | ||||
| 		t.Errorf("Error '%s' was not expected while creating a prepared statement", err) | ||||
| 	} | ||||
|  | ||||
| 	var id int | ||||
| 	var title string | ||||
|  | ||||
| 	err = stmt.QueryRow(5).Scan(&id, &title) | ||||
| 	if err == nil { | ||||
| 		t.Error("Error was expected querying row, since there was no such expectation") | ||||
| 	} | ||||
|  | ||||
| 	ExpectRollback() | ||||
|  | ||||
| 	err = db.Close() | ||||
| 	if err == nil { | ||||
| 		t.Error("Error was expected while closing the database, expectation was not fulfilled", err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestWrongUnexpectations(t *testing.T) { | ||||
| 	db, err := sql.Open("mock", "") | ||||
| 	if err != nil { | ||||
| 		t.Errorf("An error '%s' was not expected when opening a stub database connection", err) | ||||
| 	} | ||||
|  | ||||
| 	ExpectBegin() | ||||
|  | ||||
| 	rs1 := RowsFromCSVString([]string{"id", "title"}, "5,hello world") | ||||
| 	ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). | ||||
| 		WithArgs(5). | ||||
| 		WillReturnRows(rs1) | ||||
|  | ||||
| 	ExpectCommit().WillReturnError(errors.New("Deadlock occured")) | ||||
| 	ExpectRollback() // won't be triggered | ||||
|  | ||||
| 	stmt, err := db.Prepare("SELECT (.+) FROM articles WHERE id = ? FOR UPDATE") | ||||
| 	if err != nil { | ||||
| 		t.Errorf("Error '%s' was not expected while creating a prepared statement", err) | ||||
| 	} | ||||
|  | ||||
| 	var id int | ||||
| 	var title string | ||||
|  | ||||
| 	err = stmt.QueryRow(5).Scan(&id, &title) | ||||
| 	if err == nil { | ||||
| 		t.Error("Error was expected while querying row, since there Begin transaction expectation is not fulfilled") | ||||
| 	} | ||||
|  | ||||
| 	// lets go around and start transaction | ||||
| 	tx, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		t.Errorf("An error '%s' was not expected when beginning a transaction", err) | ||||
| 	} | ||||
|  | ||||
| 	err = stmt.QueryRow(5).Scan(&id, &title) | ||||
| 	if err != nil { | ||||
| 		t.Errorf("Error '%s' was not expected while querying row, since transaction was started", err) | ||||
| 	} | ||||
|  | ||||
| 	err = tx.Commit() | ||||
| 	if err == nil { | ||||
| 		t.Error("A deadlock error was expected when commiting a transaction", err) | ||||
| 	} | ||||
|  | ||||
| 	err = db.Close() | ||||
| 	if err == nil { | ||||
| 		t.Error("Error was expected while closing the database, expectation was not fulfilled", err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestExecExpectations(t *testing.T) { | ||||
| 	db, err := sql.Open("mock", "") | ||||
| 	if err != nil { | ||||
| 		t.Errorf("An error '%s' was not expected when opening a stub database connection", err) | ||||
| 	} | ||||
|  | ||||
| 	result := NewResult(1, 1) | ||||
| 	ExpectExec("^INSERT INTO articles"). | ||||
| 		WithArgs("hello"). | ||||
| 		WillReturnResult(result) | ||||
|  | ||||
| 	res, err := db.Exec("INSERT INTO articles (title) VALUES (?)", "hello") | ||||
| 	if err != nil { | ||||
| 		t.Errorf("Error '%s' was not expected, while inserting a row", err) | ||||
| 	} | ||||
|  | ||||
| 	id, err := res.LastInsertId() | ||||
| 	if err != nil { | ||||
| 		t.Errorf("Error '%s' was not expected, while getting a last insert id", err) | ||||
| 	} | ||||
|  | ||||
| 	affected, err := res.RowsAffected() | ||||
| 	if err != nil { | ||||
| 		t.Errorf("Error '%s' was not expected, while getting affected rows", err) | ||||
| 	} | ||||
|  | ||||
| 	if id != 1 { | ||||
| 		t.Errorf("Expected last insert id to be 1, but got %d instead", id) | ||||
| 	} | ||||
|  | ||||
| 	if affected != 1 { | ||||
| 		t.Errorf("Expected affected rows to be 1, but got %d instead", affected) | ||||
| 	} | ||||
|  | ||||
| 	if err = db.Close(); err != nil { | ||||
| 		t.Errorf("Error '%s' was not expected while closing the database", err) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										27
									
								
								statement.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								statement.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,27 @@ | ||||
| package sqlmock | ||||
|  | ||||
| import ( | ||||
| 	"database/sql/driver" | ||||
| ) | ||||
|  | ||||
| type statement struct { | ||||
| 	conn  *conn | ||||
| 	query string | ||||
| } | ||||
|  | ||||
| func (stmt *statement) Close() error { | ||||
| 	stmt.conn = nil | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (stmt *statement) NumInput() int { | ||||
| 	return -1 | ||||
| } | ||||
|  | ||||
| func (stmt *statement) Exec(args []driver.Value) (driver.Result, error) { | ||||
| 	return stmt.conn.Exec(stmt.query, args) | ||||
| } | ||||
|  | ||||
| func (stmt *statement) Query(args []driver.Value) (driver.Rows, error) { | ||||
| 	return stmt.conn.Query(stmt.query, args) | ||||
| } | ||||
							
								
								
									
										38
									
								
								transaction.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								transaction.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,38 @@ | ||||
| package sqlmock | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| ) | ||||
|  | ||||
| type transaction struct { | ||||
| 	conn *conn | ||||
| } | ||||
|  | ||||
| func (tx *transaction) Commit() error { | ||||
| 	e := tx.conn.next() | ||||
| 	if e == nil { | ||||
| 		return errors.New("All expectations were already fulfilled, call to Commit transaction was not expected") | ||||
| 	} | ||||
|  | ||||
| 	etc, ok := e.(*expectedCommit) | ||||
| 	if !ok { | ||||
| 		return errors.New(fmt.Sprintf("Call to Commit transaction, was not expected, next expectation was %v", e)) | ||||
| 	} | ||||
| 	etc.triggered = true | ||||
| 	return etc.err | ||||
| } | ||||
|  | ||||
| func (tx *transaction) Rollback() error { | ||||
| 	e := tx.conn.next() | ||||
| 	if e == nil { | ||||
| 		return errors.New("All expectations were already fulfilled, call to Rollback transaction was not expected") | ||||
| 	} | ||||
|  | ||||
| 	etr, ok := e.(*expectedRollback) | ||||
| 	if !ok { | ||||
| 		return errors.New(fmt.Sprintf("Call to Rollback transaction, was not expected, next expectation was %v", e)) | ||||
| 	} | ||||
| 	etr.triggered = true | ||||
| 	return etr.err | ||||
| } | ||||
		Reference in New Issue
	
	Block a user