diff --git a/result.go b/result.go index 645b313..e6113e8 100644 --- a/result.go +++ b/result.go @@ -1,14 +1,14 @@ package sqlmock import ( - "database/sql/driver" + "database/sql/driver" ) // Result satisfies sql driver Result, which // holds last insert id and rows affected // by Exec queries type result struct { - insertID int64 + insertID int64 rowsAffected int64 } diff --git a/rows.go b/rows.go index 3000d6a..705281d 100644 --- a/rows.go +++ b/rows.go @@ -40,17 +40,24 @@ func (r *rows) Next(dest []driver.Value) error { return nil } -func RowFromInterface(columns []string, values ...interface{}) driver.Rows { - rs := &rows{} - rs.cols = columns +func (r *rows) AddRow(values ...interface{}) { + if len(values) != len(r.cols) { + panic("Expected number of values to match number of columns") + } - row := make([]driver.Value, len(columns)) + row := make([]driver.Value, len(r.cols)) for i, v := range values { row[i] = v } - rs.rows = append(rs.rows, row) + r.rows = append(r.rows, row) +} +// NewRows allows Rows to be created manually to use +// any of the types sql/driver.Value supports +func NewRows(columns []string) *rows { + rs := &rows{} + rs.cols = columns return rs } diff --git a/sqlmock_test.go b/sqlmock_test.go index f7d5450..3f798f5 100644 --- a/sqlmock_test.go +++ b/sqlmock_test.go @@ -4,6 +4,7 @@ import ( "database/sql" "fmt" "testing" + "time" ) func TestMockQuery(t *testing.T) { @@ -48,6 +49,57 @@ func TestMockQuery(t *testing.T) { } } +func TestMockQueryTypes(t *testing.T) { + db, err := sql.Open("mock", "") + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + + columns := []string{"id", "timestamp", "sold"} + + timestamp := time.Now() + rs := NewRows(columns) + rs.AddRow(5, timestamp, true) + + ExpectQuery("SELECT (.+) FROM sales WHERE id = ?"). + WithArgs(5). + WillReturnRows(rs) + + rows, err := db.Query("SELECT (.+) FROM sales WHERE id = ?", 5) + if err != nil { + t.Errorf("error '%s' was not expected while retrieving mock rows", err) + } + defer rows.Close() + if !rows.Next() { + t.Error("it must have had one row as result, but got empty result set instead") + } + + var id int + var time time.Time + var sold bool + + err = rows.Scan(&id, &time, &sold) + if err != nil { + t.Errorf("error '%s' was not expected while trying to scan row", err) + } + + if id != 5 { + t.Errorf("expected mocked id to be 5, but got %d instead", id) + } + + if time != timestamp { + t.Errorf("expected mocked time to be %s, but got '%s' instead", timestamp, time) + } + + if sold != true { + t.Errorf("expected mocked boolean to be true, but got %v instead", sold) + } + + if err = db.Close(); err != nil { + t.Errorf("error '%s' was not expected while closing the database", err) + } +} + func TestTransactionExpectations(t *testing.T) { db, err := sql.Open("mock", "") if err != nil {