1
0
mirror of https://github.com/zhashkevych/go-sqlxmock.git synced 2024-11-24 08:12:13 +02:00
go-sqlxmock/rows.go
2019-02-28 10:47:32 +00:00

177 lines
3.8 KiB
Go

package sqlmock
import (
"database/sql/driver"
"encoding/csv"
"fmt"
"io"
"strings"
)
// CSVColumnParser converts a trimmed CSV column string into the []byte
// value stored in a mocked row. The literal NULL, matched
// case-insensitively, becomes nil; every other input is returned as its
// raw bytes.
//
// Being a package-level variable, it can be reassigned to customise how
// CSV columns are parsed.
var CSVColumnParser = func(s string) []byte {
	// EqualFold compares case-insensitively without allocating a
	// lowered copy of s.
	if strings.EqualFold(s, "null") {
		return nil
	}
	return []byte(s)
}
// rowSets is an ordered collection of mocked result sets together with a
// cursor selecting the set that Columns, Close and Next operate on.
type rowSets struct {
// sets holds the mocked result sets in order.
sets []*Rows
// pos is the index into sets of the result set currently in use.
pos int
// ex is the query expectation these rows belong to; Close flags it
// as having had its rows closed.
ex *ExpectedQuery
}
// Columns returns the column names of the result set currently selected
// by rs.pos.
func (rs *rowSets) Columns() []string {
	current := rs.sets[rs.pos]
	return current.cols
}
// Close records on the owning expectation that the rows were closed and
// returns the close error configured for the current result set (nil
// when none was set).
func (rs *rowSets) Close() error {
	// Flag the expectation first, before the current set is looked up.
	rs.ex.rowsWereClosed = true

	current := rs.sets[rs.pos]
	return current.closeErr
}
// Next advances the current result set by one row and copies that row's
// values into dest.
//
// It returns io.EOF once every row has been consumed, as required by the
// driver interface. Otherwise it returns the error registered for the row
// (see Rows.RowError), or nil when there is none; dest is populated in
// either case.
func (rs *rowSets) Next(dest []driver.Value) error {
	current := rs.sets[rs.pos]

	// Remember the zero-based index of the row being read, then move the
	// cursor on. The cursor advances even when the set is exhausted.
	idx := current.pos
	current.pos++

	if idx >= len(current.rows) {
		return io.EOF // per interface spec
	}

	values := current.rows[idx]
	for col := range values {
		dest[col] = values[col]
	}
	return current.nextErr[idx]
}
// String renders the expected rows as a printable, multi-line description
// for debugging. When there is more than one result set, each set is
// introduced by its own "result set" header line.
func (rs *rowSets) String() string {
	if rs.empty() {
		return "with empty rows"
	}

	var out strings.Builder
	out.WriteString("should return rows:\n")

	// A lone result set is printed without the per-set header.
	withHeader := len(rs.sets) != 1
	for setIdx, set := range rs.sets {
		if withHeader {
			fmt.Fprintf(&out, " result set: %d\n", setIdx)
		}
		for rowIdx, row := range set.rows {
			fmt.Fprintf(&out, " row %d - %+v\n", rowIdx, row)
		}
	}
	return strings.TrimSpace(out.String())
}
// empty reports whether none of the result sets contains a row. It is
// also true when there are no result sets at all.
func (rs *rowSets) empty() bool {
	for i := range rs.sets {
		if len(rs.sets[i].rows) != 0 {
			return false
		}
	}
	return true
}
// Rows is a mocked collection of rows to
// return for Query result.
type Rows struct {
// converter turns the values passed to AddRow into driver.Value form.
converter driver.ValueConverter
// cols lists the column names, in order.
cols []string
// rows holds the row values in the order they were added.
rows [][]driver.Value
// pos is the read cursor; rowSets.Next increments it on every call.
pos int
// nextErr maps a zero-based row index to the error returned when
// that row is read.
nextErr map[int]error
// closeErr is the error rowSets.Close returns for this result set.
closeErr error
}
// NewRows creates an empty Rows for the given column names. Rows can then
// be filled from driver.Value slices (AddRow) or from a CSV string
// (FromCSVString) and used as sql driver.Rows.
//
// Values are converted with driver.DefaultParameterConverter; use
// Sqlmock.NewRows instead if using a custom converter.
func NewRows(columns []string) *Rows {
	r := new(Rows)
	r.cols = columns
	r.nextErr = make(map[int]error)
	r.converter = driver.DefaultParameterConverter
	return r
}
// CloseError sets the error that will be returned by the rows.Close
// function and returns the same instance so calls can be chained.
//
// The close error is only triggered when Close is reached before
// rows.Next() has hit EOF; that is the default behavior of the sql
// library.
func (r *Rows) CloseError(err error) *Rows {
	r.closeErr = err
	return r
}
// RowError registers err to be returned when the given row is read and
// returns the same instance so calls can be chained.
//
// row is the zero-based index of the row: the first row added is row 0.
// Registering a second error for the same row replaces the first.
func (r *Rows) RowError(row int, err error) *Rows {
	r.nextErr[row] = err
	return r
}
// AddRow appends one row built from the given driver.Value slice and
// returns the same instance so calls can be chained.
//
// The number of values must match the number of columns; AddRow panics
// otherwise. It also panics when the configured converter rejects one of
// the values.
func (r *Rows) AddRow(values ...driver.Value) *Rows {
	if len(values) != len(r.cols) {
		panic("Expected number of values to match number of columns")
	}

	converted := make([]driver.Value, len(r.cols))
	for idx := range values {
		// Turn user-friendly values (such as int or driver.Valuer) into
		// database/sql native ones (driver.Value such as int64).
		val, err := r.converter.ConvertValue(values[idx])
		if err != nil {
			panic(fmt.Errorf(
				"row #%d, column #%d (%q) type %T: %s",
				len(r.rows)+1, idx, r.cols[idx], values[idx], err,
			))
		}
		converted[idx] = val
	}

	r.rows = append(r.rows, converted)
	return r
}
// FromCSVString appends rows parsed from the CSV string s and returns the
// same instance so calls can be chained.
//
// Leading and trailing whitespace is trimmed from s as a whole and from
// every field, and each field is then passed through CSVColumnParser.
// A record with fewer fields than there are columns leaves the remaining
// columns nil.
//
// FromCSVString panics if s cannot be parsed as CSV; such input used to
// be dropped silently, leaving the mock with fewer rows than written. It
// also panics when a record has more fields than there are columns.
func (r *Rows) FromCSVString(s string) *Rows {
	csvReader := csv.NewReader(strings.NewReader(strings.TrimSpace(s)))

	for {
		record, err := csvReader.Read()
		if err == io.EOF {
			// csv.Reader signals a clean end of input with io.EOF.
			break
		}
		if err != nil {
			panic(fmt.Sprintf("Parsing CSV string failed: %s", err.Error()))
		}
		if len(record) > len(r.cols) {
			// Fail with a readable message instead of an index-out-of-range
			// panic from the assignment below.
			panic(fmt.Sprintf(
				"row #%d: CSV record has %d fields, but only %d columns are defined",
				len(r.rows)+1, len(record), len(r.cols),
			))
		}

		row := make([]driver.Value, len(r.cols))
		for i, field := range record {
			row[i] = CSVColumnParser(strings.TrimSpace(field))
		}
		r.rows = append(r.rows, row)
	}
	return r
}