2014-02-05 17:21:07 +03:00
|
|
|
package sqlmock
|
|
|
|
|
|
|
|
import (
|
|
|
|
"database/sql/driver"
|
|
|
|
"encoding/csv"
|
2017-02-16 22:33:12 +02:00
|
|
|
"fmt"
|
2014-02-05 17:21:07 +03:00
|
|
|
"io"
|
|
|
|
"strings"
|
|
|
|
)
|
|
|
|
|
2015-08-05 10:50:16 +02:00
|
|
|
// CSVColumnParser converts a trimmed CSV column string into the
// []byte value stored in a mocked row. The literal NULL, compared
// case-insensitively, is converted to nil; every other input is
// returned as its byte representation.
//
// It is a package-level variable so it can be replaced with a
// custom conversion function.
var CSVColumnParser = func(s string) []byte {
	// EqualFold compares without allocating a lowered copy of s.
	if strings.EqualFold(s, "null") {
		return nil
	}
	return []byte(s)
}
|
|
|
|
|
2017-02-08 17:35:32 +02:00
|
|
|
type rowSets struct {
|
|
|
|
sets []*Rows
|
|
|
|
pos int
|
2014-02-14 01:14:32 +03:00
|
|
|
}
|
|
|
|
|
2017-02-08 17:35:32 +02:00
|
|
|
func (rs *rowSets) Columns() []string {
|
|
|
|
return rs.sets[rs.pos].cols
|
2014-02-05 17:21:07 +03:00
|
|
|
}
|
|
|
|
|
2017-02-08 17:35:32 +02:00
|
|
|
func (rs *rowSets) Close() error {
|
|
|
|
return rs.sets[rs.pos].closeErr
|
2014-02-05 17:21:07 +03:00
|
|
|
}
|
|
|
|
|
2014-02-07 09:58:27 +03:00
|
|
|
// advances to next row
|
2017-02-08 17:35:32 +02:00
|
|
|
func (rs *rowSets) Next(dest []driver.Value) error {
|
|
|
|
r := rs.sets[rs.pos]
|
2014-02-05 17:21:07 +03:00
|
|
|
r.pos++
|
|
|
|
if r.pos > len(r.rows) {
|
|
|
|
return io.EOF // per interface spec
|
|
|
|
}
|
|
|
|
|
|
|
|
for i, col := range r.rows[r.pos-1] {
|
|
|
|
dest[i] = col
|
|
|
|
}
|
|
|
|
|
2015-08-05 12:37:58 +02:00
|
|
|
return r.nextErr[r.pos-1]
|
2014-02-05 17:21:07 +03:00
|
|
|
}
|
|
|
|
|
2017-02-16 22:33:12 +02:00
|
|
|
// transforms to debuggable printable string
|
|
|
|
func (rs *rowSets) String() string {
|
2017-04-26 08:56:02 +02:00
|
|
|
if rs.empty() {
|
|
|
|
return "with empty rows"
|
|
|
|
}
|
|
|
|
|
2017-02-16 22:33:12 +02:00
|
|
|
msg := "should return rows:\n"
|
|
|
|
if len(rs.sets) == 1 {
|
|
|
|
for n, row := range rs.sets[0].rows {
|
|
|
|
msg += fmt.Sprintf(" row %d - %+v\n", n, row)
|
|
|
|
}
|
|
|
|
return strings.TrimSpace(msg)
|
|
|
|
}
|
|
|
|
for i, set := range rs.sets {
|
|
|
|
msg += fmt.Sprintf(" result set: %d\n", i)
|
|
|
|
for n, row := range set.rows {
|
|
|
|
msg += fmt.Sprintf(" row %d - %+v\n", n, row)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return strings.TrimSpace(msg)
|
|
|
|
}
|
|
|
|
|
2017-04-26 08:56:02 +02:00
|
|
|
func (rs *rowSets) empty() bool {
|
|
|
|
for _, set := range rs.sets {
|
|
|
|
if len(set.rows) > 0 {
|
|
|
|
return false
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return true
|
|
|
|
}
|
|
|
|
|
2017-02-08 17:35:32 +02:00
|
|
|
// Rows is a mocked collection of rows to return for a Query result.
type Rows struct {
	// cols holds the column names.
	cols []string
	// rows holds the row values; each row has one value per column.
	rows [][]driver.Value
	// pos counts how many times the set has been advanced.
	pos int
	// nextErr maps a zero-based row index to the error returned
	// when that row is read.
	nextErr map[int]error
	// closeErr is the error returned on close.
	closeErr error
}
|
|
|
|
|
2015-07-17 12:14:30 +02:00
|
|
|
// NewRows allows Rows to be created from a
|
|
|
|
// sql driver.Value slice or from the CSV string and
|
2014-02-14 01:14:32 +03:00
|
|
|
// to be used as sql driver.Rows
|
2017-02-08 17:35:32 +02:00
|
|
|
func NewRows(columns []string) *Rows {
|
|
|
|
return &Rows{cols: columns, nextErr: make(map[int]error)}
|
2015-07-22 15:17:35 +02:00
|
|
|
}
|
|
|
|
|
2017-02-08 17:35:32 +02:00
|
|
|
// CloseError allows to set an error
|
|
|
|
// which will be returned by rows.Close
|
|
|
|
// function.
|
|
|
|
//
|
|
|
|
// The close error will be triggered only in cases
|
|
|
|
// when rows.Next() EOF was not yet reached, that is
|
|
|
|
// a default sql library behavior
|
|
|
|
func (r *Rows) CloseError(err error) *Rows {
|
2015-07-22 15:17:35 +02:00
|
|
|
r.closeErr = err
|
|
|
|
return r
|
|
|
|
}
|
|
|
|
|
2017-02-08 17:35:32 +02:00
|
|
|
// RowError allows to set an error
|
|
|
|
// which will be returned when a given
|
|
|
|
// row number is read
|
|
|
|
func (r *Rows) RowError(row int, err error) *Rows {
|
2015-08-05 12:37:58 +02:00
|
|
|
r.nextErr[row] = err
|
2015-07-22 15:17:35 +02:00
|
|
|
return r
|
2014-02-14 01:14:32 +03:00
|
|
|
}
|
|
|
|
|
2017-02-08 17:35:32 +02:00
|
|
|
// AddRow composed from database driver.Value slice
|
|
|
|
// return the same instance to perform subsequent actions.
|
|
|
|
// Note that the number of values must match the number
|
|
|
|
// of columns
|
|
|
|
func (r *Rows) AddRow(values ...driver.Value) *Rows {
|
2014-02-13 22:59:35 +03:00
|
|
|
if len(values) != len(r.cols) {
|
|
|
|
panic("Expected number of values to match number of columns")
|
|
|
|
}
|
2014-02-13 04:02:35 +03:00
|
|
|
|
2014-02-13 22:59:35 +03:00
|
|
|
row := make([]driver.Value, len(r.cols))
|
2014-02-13 04:02:35 +03:00
|
|
|
for i, v := range values {
|
|
|
|
row[i] = v
|
|
|
|
}
|
|
|
|
|
2014-02-13 22:59:35 +03:00
|
|
|
r.rows = append(r.rows, row)
|
2014-02-14 01:14:32 +03:00
|
|
|
return r
|
2014-02-13 22:59:35 +03:00
|
|
|
}
|
2014-02-13 04:02:35 +03:00
|
|
|
|
2017-02-08 17:35:32 +02:00
|
|
|
// FromCSVString build rows from csv string.
|
|
|
|
// return the same instance to perform subsequent actions.
|
|
|
|
// Note that the number of values must match the number
|
|
|
|
// of columns
|
|
|
|
func (r *Rows) FromCSVString(s string) *Rows {
|
2014-02-14 01:14:32 +03:00
|
|
|
res := strings.NewReader(strings.TrimSpace(s))
|
|
|
|
csvReader := csv.NewReader(res)
|
|
|
|
|
|
|
|
for {
|
|
|
|
res, err := csvReader.Read()
|
|
|
|
if err != nil || res == nil {
|
|
|
|
break
|
|
|
|
}
|
|
|
|
|
|
|
|
row := make([]driver.Value, len(r.cols))
|
|
|
|
for i, v := range res {
|
2015-08-05 10:50:16 +02:00
|
|
|
row[i] = CSVColumnParser(strings.TrimSpace(v))
|
2014-02-14 01:14:32 +03:00
|
|
|
}
|
|
|
|
r.rows = append(r.rows, row)
|
|
|
|
}
|
|
|
|
return r
|
2014-02-13 04:02:35 +03:00
|
|
|
}
|