mirror of
https://github.com/zhashkevych/go-sqlxmock.git
synced 2024-11-24 08:12:13 +02:00
213 lines
5.0 KiB
Go
213 lines
5.0 KiB
Go
package sqlmock
|
|
|
|
import (
|
|
"bytes"
|
|
"database/sql/driver"
|
|
"encoding/csv"
|
|
"fmt"
|
|
"io"
|
|
"strings"
|
|
)
|
|
|
|
// invalidate is the marker text that invalidateRaw repeatedly writes over
// byte slices previously handed out by Next, so that code which keeps using
// them after they became invalid reads obviously corrupted data.
const (
	invalidate = "☠☠☠ MEMORY OVERWRITTEN ☠☠☠ "
)
|
|
|
|
// CSVColumnParser is a function which converts a trimmed csv
// column string to a []byte representation. Currently the only special
// case is NULL, matched case-insensitively, which becomes nil; any other
// value is returned as its raw bytes.
var CSVColumnParser = func(s string) []byte {
	if strings.EqualFold(s, "null") {
		return nil
	}
	return []byte(s)
}
|
|
|
|
// rowSets implements the driver.Rows methods (Columns, Close, Next) over one
// or more mocked result sets and remembers the byte slices it has exposed so
// they can be overwritten later.
type rowSets struct {
	// sets holds the mocked result sets; sets[pos] is the one that Columns,
	// Next and Close operate on.
	sets []*Rows
	// pos is the index of the current result set. Nothing in this file
	// advances it — NOTE(review): presumably a NextResultSet method defined
	// elsewhere does; confirm.
	pos int
	// ex is the expectation these rows belong to; Close sets
	// ex.rowsWereClosed.
	ex *ExpectedQuery
	// raw holds the copies of []byte column values handed out by the most
	// recent Next call; invalidateRaw overwrites and then discards them.
	raw [][]byte
}
|
|
|
|
func (rs *rowSets) Columns() []string {
|
|
return rs.sets[rs.pos].cols
|
|
}
|
|
|
|
func (rs *rowSets) Close() error {
|
|
rs.invalidateRaw()
|
|
rs.ex.rowsWereClosed = true
|
|
return rs.sets[rs.pos].closeErr
|
|
}
|
|
|
|
// advances to next row
|
|
func (rs *rowSets) Next(dest []driver.Value) error {
|
|
r := rs.sets[rs.pos]
|
|
r.pos++
|
|
rs.invalidateRaw()
|
|
if r.pos > len(r.rows) {
|
|
return io.EOF // per interface spec
|
|
}
|
|
|
|
for i, col := range r.rows[r.pos-1] {
|
|
if b, ok := rawBytes(col); ok {
|
|
rs.raw = append(rs.raw, b)
|
|
dest[i] = b
|
|
continue
|
|
}
|
|
dest[i] = col
|
|
}
|
|
|
|
return r.nextErr[r.pos-1]
|
|
}
|
|
|
|
// transforms to debuggable printable string
|
|
func (rs *rowSets) String() string {
|
|
if rs.empty() {
|
|
return "with empty rows"
|
|
}
|
|
|
|
msg := "should return rows:\n"
|
|
if len(rs.sets) == 1 {
|
|
for n, row := range rs.sets[0].rows {
|
|
msg += fmt.Sprintf(" row %d - %+v\n", n, row)
|
|
}
|
|
return strings.TrimSpace(msg)
|
|
}
|
|
for i, set := range rs.sets {
|
|
msg += fmt.Sprintf(" result set: %d\n", i)
|
|
for n, row := range set.rows {
|
|
msg += fmt.Sprintf(" row %d - %+v\n", n, row)
|
|
}
|
|
}
|
|
return strings.TrimSpace(msg)
|
|
}
|
|
|
|
func (rs *rowSets) empty() bool {
|
|
for _, set := range rs.sets {
|
|
if len(set.rows) > 0 {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
// rawBytes returns a private copy of col when it is a non-empty []byte, and
// reports ok=false (with a nil slice) for anything else, including empty
// slices.
//
// The copy is what gets handed out and later overwritten by invalidateRaw.
// This lets values scanned into sql.RawBytes become invalid on subsequent
// calls to Next, Scan or Close, without ever clobbering the mocked row itself.
func rawBytes(col driver.Value) (_ []byte, ok bool) {
	src, isBytes := col.([]byte)
	if !isBytes || len(src) == 0 {
		return nil, false
	}

	dup := make([]byte, len(src))
	copy(dup, src)
	return dup, true
}
|
|
|
|
// Bytes that could have been scanned as sql.RawBytes are only valid until the next call to Next, Scan or Close.
|
|
// If those occur, we must replace their content to simulate the shared memory to expose misuse of sql.RawBytes
|
|
func (rs *rowSets) invalidateRaw() {
|
|
// Replace the content of slices previously returned
|
|
b := []byte(invalidate)
|
|
for _, r := range rs.raw {
|
|
copy(r, bytes.Repeat(b, len(r)/len(b)+1))
|
|
}
|
|
// Start with new slices for the next scan
|
|
rs.raw = nil
|
|
}
|
|
|
|
// Rows is a mocked collection of rows to
// return for Query result.
//
// Values are stored as appended by AddRow (after conversion) or
// FromCSVString, and are served one row at a time by rowSets.Next.
type Rows struct {
	// converter normalizes the values passed to AddRow; NewRows sets it to
	// driver.DefaultParameterConverter.
	converter driver.ValueConverter
	// cols are the column names; AddRow requires exactly one value per column.
	cols []string
	// def is not referenced in this file — NOTE(review): presumably column
	// metadata populated by a helper defined elsewhere; confirm.
	def []*Column
	// rows are the mocked rows; AddRow and FromCSVString size each to len(cols).
	rows [][]driver.Value
	// pos is incremented by every rowSets.Next call; after a successful Next,
	// rows[pos-1] is the row that was just returned.
	pos int
	// nextErr maps a 0-based row index to the error Next returns together
	// with that row (see RowError).
	nextErr map[int]error
	// closeErr is returned by rowSets.Close (see CloseError).
	closeErr error
}
|
|
|
|
// NewRows allows Rows to be created from a
|
|
// sql driver.Value slice or from the CSV string and
|
|
// to be used as sql driver.Rows.
|
|
// Use Sqlmock.NewRows instead if using a custom converter
|
|
func NewRows(columns []string) *Rows {
|
|
return &Rows{
|
|
cols: columns,
|
|
nextErr: make(map[int]error),
|
|
converter: driver.DefaultParameterConverter,
|
|
}
|
|
}
|
|
|
|
// CloseError allows to set an error
|
|
// which will be returned by rows.Close
|
|
// function.
|
|
//
|
|
// The close error will be triggered only in cases
|
|
// when rows.Next() EOF was not yet reached, that is
|
|
// a default sql library behavior
|
|
func (r *Rows) CloseError(err error) *Rows {
|
|
r.closeErr = err
|
|
return r
|
|
}
|
|
|
|
// RowError allows to set an error
|
|
// which will be returned when a given
|
|
// row number is read
|
|
func (r *Rows) RowError(row int, err error) *Rows {
|
|
r.nextErr[row] = err
|
|
return r
|
|
}
|
|
|
|
// AddRow composed from database driver.Value slice
|
|
// return the same instance to perform subsequent actions.
|
|
// Note that the number of values must match the number
|
|
// of columns
|
|
func (r *Rows) AddRow(values ...driver.Value) *Rows {
|
|
if len(values) != len(r.cols) {
|
|
panic("Expected number of values to match number of columns")
|
|
}
|
|
|
|
row := make([]driver.Value, len(r.cols))
|
|
for i, v := range values {
|
|
// Convert user-friendly values (such as int or driver.Valuer)
|
|
// to database/sql native value (driver.Value such as int64)
|
|
var err error
|
|
v, err = r.converter.ConvertValue(v)
|
|
if err != nil {
|
|
panic(fmt.Errorf(
|
|
"row #%d, column #%d (%q) type %T: %s",
|
|
len(r.rows)+1, i, r.cols[i], values[i], err,
|
|
))
|
|
}
|
|
|
|
row[i] = v
|
|
}
|
|
|
|
r.rows = append(r.rows, row)
|
|
return r
|
|
}
|
|
|
|
// FromCSVString build rows from csv string.
|
|
// return the same instance to perform subsequent actions.
|
|
// Note that the number of values must match the number
|
|
// of columns
|
|
func (r *Rows) FromCSVString(s string) *Rows {
|
|
res := strings.NewReader(strings.TrimSpace(s))
|
|
csvReader := csv.NewReader(res)
|
|
|
|
for {
|
|
res, err := csvReader.Read()
|
|
if err != nil || res == nil {
|
|
break
|
|
}
|
|
|
|
row := make([]driver.Value, len(r.cols))
|
|
for i, v := range res {
|
|
row[i] = CSVColumnParser(strings.TrimSpace(v))
|
|
}
|
|
r.rows = append(r.rows, row)
|
|
}
|
|
return r
|
|
}
|