package sqlmock

import (
	"database/sql/driver"
	"encoding/csv"
	"io"
	"strings"
)

// CSVColumnParser turns one already-trimmed CSV field into the []byte value
// stored in a mock row. The literal NULL, compared case-insensitively after
// lower-casing, yields a nil slice; any other field is returned as its raw
// bytes. Because it is an exported variable, callers may assign their own
// parser to change how fields are converted.
var CSVColumnParser = func(s string) []byte {
	if lowered := strings.ToLower(s); lowered == "null" {
		return nil
	}
	return []byte(s)
}

// Rows is a mock result set that can be assembled programmatically and then
// consumed through the database/sql/driver.Rows interface it embeds. Each
// builder method returns the same instance so calls can be chained.
type Rows interface {
	// Rows is itself a driver.Rows, so it can be handed to code that expects
	// a driver-level result set.
	driver.Rows

	// AddRow appends a single row made of the given driver.Value items and
	// returns the same instance for chaining. The number of values must
	// equal the number of columns.
	AddRow(columns ...driver.Value) Rows

	// FromCSVString appends rows parsed from the CSV text s and returns the
	// same instance for chaining. Each record is expected to carry as many
	// values as there are columns.
	FromCSVString(s string) Rows

	// RowError registers err to be returned when the row with the given
	// number is read, and returns the same instance for chaining.
	RowError(row int, err error) Rows

	// CloseError registers err to be returned by Close, and returns the
	// same instance for chaining.
	//
	// The close error only surfaces when the rows are closed before
	// rows.Next() has reached EOF; that is the default behavior of the sql
	// library.
	CloseError(err error) Rows
}

// rows is the concrete Rows implementation constructed by NewRows.
type rows struct {
	// cols holds the column names reported by Columns.
	cols     []string
	// rows holds the queued row values, appended by AddRow and FromCSVString.
	rows     [][]driver.Value
	// pos is incremented on every Next call; while pos <= len(rows),
	// rows[pos-1] is the row most recently returned.
	pos      int
	// nextErr maps a zero-based row index to the error Next returns for
	// that row; populated by RowError.
	nextErr  map[int]error
	// closeErr is the error returned by Close; set by CloseError.
	closeErr error
}

// Columns returns the column names the rows were created with. The stored
// slice is returned directly, not copied.
func (r *rows) Columns() []string {
	names := r.cols
	return names
}

// Close reports the error configured through CloseError, or nil when none
// was set. It does not modify any other state.
func (r *rows) Close() error {
	configured := r.closeErr
	return configured
}

// Next copies the values of the next queued row into dest. Once every row has
// been consumed it returns io.EOF. Otherwise it returns the error registered
// for that zero-based row index via RowError, which is nil when none was set;
// the values are copied into dest even when an error is returned.
func (r *rows) Next(dest []driver.Value) error {
	idx := r.pos
	r.pos = idx + 1
	if idx >= len(r.rows) {
		return io.EOF // per interface spec
	}

	current := r.rows[idx]
	for i := range current {
		dest[i] = current[i]
	}

	return r.nextErr[idx]
}

// NewRows creates an empty mock result set with the given column names. Rows
// can then be added with AddRow or FromCSVString, and the result satisfies
// driver.Rows.
func NewRows(columns []string) Rows {
	r := new(rows)
	r.cols = columns
	r.nextErr = make(map[int]error)
	return r
}

// CloseError stores err as the value Close will report, then hands back the
// receiver so further builder calls can be chained.
func (r *rows) CloseError(err error) Rows {
	var chained Rows = r
	r.closeErr = err
	return chained
}

// RowError records err as the error Next returns when it reaches the row at
// zero-based index row, then returns the receiver for chaining.
func (r *rows) RowError(row int, err error) Rows {
	perRow := r.nextErr
	perRow[row] = err
	return r
}

// AddRow appends one row built from values and returns the receiver for
// chaining. It panics when the number of values differs from the number of
// columns. The values are copied, so later changes to the caller's slice do
// not affect the stored row.
func (r *rows) AddRow(values ...driver.Value) Rows {
	if len(r.cols) != len(values) {
		panic("Expected number of values to match number of columns")
	}

	// len(values) == len(r.cols) here, so copy fills the whole row.
	stored := make([]driver.Value, len(values))
	copy(stored, values)

	r.rows = append(r.rows, stored)
	return r
}

// FromCSVString parses s as CSV and appends one row per record, returning the
// receiver for chaining. Surrounding whitespace is trimmed from s as a whole
// and from every field before the field is passed to CSVColumnParser. A
// record with fewer fields than columns leaves its remaining values nil.
//
// It panics when the CSV text cannot be parsed, including records whose
// field count differs from the first record's. It also panics when a record
// has more fields than there are columns. This mirrors AddRow's panic on
// mismatched input rather than silently discarding rows.
func (r *rows) FromCSVString(s string) Rows {
	csvReader := csv.NewReader(strings.NewReader(strings.TrimSpace(s)))

	for {
		record, err := csvReader.Read()
		if err == io.EOF {
			break
		}
		if err != nil {
			// Fail loudly: stopping here quietly would drop every remaining
			// record and leave the test running against partial data.
			panic("Parsing CSV string failed: " + err.Error())
		}
		if len(record) > len(r.cols) {
			// Without this guard the assignment below would fail with an
			// opaque index-out-of-range panic.
			panic("Expected number of CSV values not to exceed number of columns")
		}

		row := make([]driver.Value, len(r.cols))
		for i, v := range record {
			row[i] = CSVColumnParser(strings.TrimSpace(v))
		}
		r.rows = append(r.rows, row)
	}
	return r
}