From 128bf5c539d9ca228e5ae3f2d02770c8ffb6dd91 Mon Sep 17 00:00:00 2001 From: gedi Date: Wed, 8 Feb 2017 17:35:32 +0200 Subject: [PATCH] implements next rows result set support --- expectations.go | 15 +-- ...ore_go18.go => expectations_before_go18.go | 7 ++ arg_matcher_go18.go => expectations_go18.go | 11 +++ rows.go | 93 +++++++++---------- rows_go18.go | 20 ++++ rows_go18_test.go | 92 ++++++++++++++++++ 6 files changed, 178 insertions(+), 60 deletions(-) rename arg_matcher_before_go18.go => expectations_before_go18.go (84%) rename arg_matcher_go18.go => expectations_go18.go (83%) create mode 100644 rows_go18.go create mode 100644 rows_go18_test.go diff --git a/expectations.go b/expectations.go index b19fbc9..adc726e 100644 --- a/expectations.go +++ b/expectations.go @@ -144,13 +144,6 @@ func (e *ExpectedQuery) WillReturnError(err error) *ExpectedQuery { return e } -// WillReturnRows specifies the set of resulting rows that will be returned -// by the triggered query -func (e *ExpectedQuery) WillReturnRows(rows driver.Rows) *ExpectedQuery { - e.rows = rows - return e -} - // WillDelayFor allows to specify duration for which it will delay // result. May be used together with Context func (e *ExpectedQuery) WillDelayFor(duration time.Duration) *ExpectedQuery { @@ -175,9 +168,11 @@ func (e *ExpectedQuery) String() string { if e.rows != nil { msg += "\n - should return rows:\n" - rs, _ := e.rows.(*rows) - for i, row := range rs.rows { - msg += fmt.Sprintf(" %d - %+v\n", i, row) + rs, _ := e.rows.(*rowSets) + for _, set := range rs.sets { + for i, row := range set.rows { + msg += fmt.Sprintf(" %d - %+v\n", i, row) + } } msg = strings.TrimSpace(msg) } diff --git a/arg_matcher_before_go18.go b/expectations_before_go18.go similarity index 84% rename from arg_matcher_before_go18.go rename to expectations_before_go18.go index 52eb369..146f240 100644 --- a/arg_matcher_before_go18.go +++ b/expectations_before_go18.go @@ -8,6 +8,13 @@ import ( "reflect" ) +// WillReturnRows specifies the set of resulting rows that will be returned +// by the triggered query +func (e *ExpectedQuery) WillReturnRows(rows *Rows) *ExpectedQuery { + e.rows = &rowSets{sets: []*Rows{rows}} + return e +} + func (e *queryBasedExpectation) argsMatches(args []namedValue) error { if nil == e.args { return nil diff --git a/arg_matcher_go18.go b/expectations_go18.go similarity index 83% rename from arg_matcher_go18.go rename to expectations_go18.go index 610eac3..29eeb30 100644 --- a/arg_matcher_go18.go +++ b/expectations_go18.go @@ -9,6 +9,17 @@ import ( "reflect" ) +// WillReturnRows specifies the set of resulting rows that will be returned +// by the triggered query +func (e *ExpectedQuery) WillReturnRows(rows ...*Rows) *ExpectedQuery { + sets := make([]*Rows, len(rows)) + for i, r := range rows { + sets[i] = r + } + e.rows = &rowSets{sets: sets} + return e +} + func (e *queryBasedExpectation) argsMatches(args []namedValue) error { if nil == e.args { return nil diff --git a/rows.go b/rows.go index 8b6beb6..43681d4 100644 --- a/rows.go +++ b/rows.go @@ -18,57 +18,22 @@ var CSVColumnParser = func(s string) []byte { return []byte(s) } -// Rows interface allows to construct rows -// which also satisfies database/sql/driver.Rows interface -type Rows interface { - // composed interface, supports sql driver.Rows - driver.Rows - - // AddRow composed from database driver.Value slice - // return the same instance to perform subsequent actions. - // Note that the number of values must match the number - // of columns - AddRow(columns ...driver.Value) Rows - - // FromCSVString build rows from csv string. - // return the same instance to perform subsequent actions. - // Note that the number of values must match the number - // of columns - FromCSVString(s string) Rows - - // RowError allows to set an error - // which will be returned when a given - // row number is read - RowError(row int, err error) Rows - - // CloseError allows to set an error - // which will be returned by rows.Close - // function. - // - // The close error will be triggered only in cases - // when rows.Next() EOF was not yet reached, that is - // a default sql library behavior - CloseError(err error) Rows +type rowSets struct { + sets []*Rows + pos int } -type rows struct { - cols []string - rows [][]driver.Value - pos int - nextErr map[int]error - closeErr error +func (rs *rowSets) Columns() []string { + return rs.sets[rs.pos].cols } -func (r *rows) Columns() []string { - return r.cols -} - -func (r *rows) Close() error { - return r.closeErr +func (rs *rowSets) Close() error { + return rs.sets[rs.pos].closeErr } // advances to next row -func (r *rows) Next(dest []driver.Value) error { +func (rs *rowSets) Next(dest []driver.Value) error { + r := rs.sets[rs.pos] r.pos++ if r.pos > len(r.rows) { return io.EOF // per interface spec @@ -81,24 +46,48 @@ func (r *rows) Next(dest []driver.Value) error { return r.nextErr[r.pos-1] } +// Rows is a mocked collection of rows to +// return for Query result +type Rows struct { + cols []string + rows [][]driver.Value + pos int + nextErr map[int]error + closeErr error +} + // NewRows allows Rows to be created from a // sql driver.Value slice or from the CSV string and // to be used as sql driver.Rows -func NewRows(columns []string) Rows { - return &rows{cols: columns, nextErr: make(map[int]error)} +func NewRows(columns []string) *Rows { + return &Rows{cols: columns, nextErr: make(map[int]error)} } -func (r *rows) CloseError(err error) Rows { +// CloseError allows to set an error +// which will be returned by rows.Close +// function. +// +// The close error will be triggered only in cases +// when rows.Next() EOF was not yet reached, that is +// a default sql library behavior +func (r *Rows) CloseError(err error) *Rows { r.closeErr = err return r } -func (r *rows) RowError(row int, err error) Rows { +// RowError allows to set an error +// which will be returned when a given +// row number is read +func (r *Rows) RowError(row int, err error) *Rows { r.nextErr[row] = err return r } -func (r *rows) AddRow(values ...driver.Value) Rows { +// AddRow composed from database driver.Value slice +// return the same instance to perform subsequent actions. +// Note that the number of values must match the number +// of columns +func (r *Rows) AddRow(values ...driver.Value) *Rows { if len(values) != len(r.cols) { panic("Expected number of values to match number of columns") } @@ -112,7 +101,11 @@ func (r *rows) AddRow(values ...driver.Value) Rows { return r } -func (r *rows) FromCSVString(s string) Rows { +// FromCSVString build rows from csv string. +// return the same instance to perform subsequent actions. +// Note that the number of values must match the number +// of columns +func (r *Rows) FromCSVString(s string) *Rows { res := strings.NewReader(strings.TrimSpace(s)) csvReader := csv.NewReader(res) diff --git a/rows_go18.go b/rows_go18.go new file mode 100644 index 0000000..4ecf84e --- /dev/null +++ b/rows_go18.go @@ -0,0 +1,20 @@ +// +build go1.8 + +package sqlmock + +import "io" + +// Implement the "RowsNextResultSet" interface +func (rs *rowSets) HasNextResultSet() bool { + return rs.pos+1 < len(rs.sets) +} + +// Implement the "RowsNextResultSet" interface +func (rs *rowSets) NextResultSet() error { + if !rs.HasNextResultSet() { + return io.EOF + } + + rs.pos++ + return nil +} diff --git a/rows_go18_test.go b/rows_go18_test.go new file mode 100644 index 0000000..297e7c0 --- /dev/null +++ b/rows_go18_test.go @@ -0,0 +1,92 @@ +// +build go1.8 + +package sqlmock + +import ( + "fmt" + "testing" +) + +func TestQueryMultiRows(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + rs1 := NewRows([]string{"id", "title"}).AddRow(5, "hello world") + rs2 := NewRows([]string{"name"}).AddRow("gopher").AddRow("john").AddRow("jane").RowError(2, fmt.Errorf("error")) + + mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = \\?;SELECT name FROM users"). + WithArgs(5). + WillReturnRows(rs1, rs2) + + rows, err := db.Query("SELECT id, title FROM articles WHERE id = ?;SELECT name FROM users", 5) + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + defer rows.Close() + + if !rows.Next() { + t.Error("expected a row to be available in first result set") + } + + var id int + var name string + + err = rows.Scan(&id, &name) + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + if id != 5 || name != "hello world" { + t.Errorf("unexpected row values id: %v name: %v", id, name) + } + + if rows.Next() { + t.Error("was not expecting next row in first result set") + } + + if !rows.NextResultSet() { + t.Error("had to have next result set") + } + + if !rows.Next() { + t.Error("expected a row to be available in second result set") + } + + err = rows.Scan(&name) + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + if name != "gopher" { + t.Errorf("unexpected row name: %v", name) + } + + if !rows.Next() { + t.Error("expected a row to be available in second result set") + } + + err = rows.Scan(&name) + if err != nil { + t.Errorf("error was not expected, but got: %v", err) + } + + if name != "john" { + t.Errorf("unexpected row name: %v", name) + } + + if rows.Next() { + t.Error("expected next row to produce error") + } + + if rows.Err() == nil { + t.Error("expected an error, but there was none") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expections: %s", err) + } +}