// +build go1.8

package sqlmock

import (
	"database/sql"
	"encoding/json"
	"fmt"
	"reflect"
	"testing"
	"time"
)

// TestQueryMultiRows checks that a single expected query can hand back two
// result sets: the first is read row by row, rows.NextResultSet moves on to
// the second, and the error attached with RowError(2, ...) stops iteration
// of the second set and is surfaced through rows.Err.
func TestQueryMultiRows(t *testing.T) {
	t.Parallel()
	db, mock, err := New()
	if err != nil {
		// Without a usable db every following call would dereference nil,
		// so stop the test here instead of merely recording the failure.
		t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
	}
	defer db.Close()

	rs1 := NewRows([]string{"id", "title"}).AddRow(5, "hello world")
	rs2 := NewRows([]string{"name"}).AddRow("gopher").AddRow("john").AddRow("jane").RowError(2, fmt.Errorf("error"))

	mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = \\?;SELECT name FROM users").
		WithArgs(5).
		WillReturnRows(rs1, rs2)

	rows, err := db.Query("SELECT id, title FROM articles WHERE id = ?;SELECT name FROM users", 5)
	if err != nil {
		// rows must not be used when Query fails; abort before the deferred
		// Close and the Next calls below can touch it.
		t.Fatalf("error was not expected, but got: %v", err)
	}
	defer rows.Close()

	// First result set: exactly one (id, title) row.
	if !rows.Next() {
		t.Error("expected a row to be available in first result set")
	}

	var (
		id    int
		title string
	)
	if err := rows.Scan(&id, &title); err != nil {
		t.Errorf("error was not expected, but got: %v", err)
	}

	if id != 5 || title != "hello world" {
		t.Errorf("unexpected row values id: %v name: %v", id, title)
	}

	if rows.Next() {
		t.Error("was not expecting next row in first result set")
	}

	// Second result set: two readable names, then the injected row error.
	if !rows.NextResultSet() {
		t.Error("had to have next result set")
	}

	var name string
	for _, want := range []string{"gopher", "john"} {
		if !rows.Next() {
			t.Error("expected a row to be available in second result set")
		}

		if err := rows.Scan(&name); err != nil {
			t.Errorf("error was not expected, but got: %v", err)
		}

		if name != want {
			t.Errorf("unexpected row name: %v", name)
		}
	}

	// The third row carries the RowError, so Next must report false and the
	// error must be observable afterwards.
	if rows.Next() {
		t.Error("expected next row to produce error")
	}

	if rows.Err() == nil {
		t.Error("expected an error, but there was none")
	}

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("there were unfulfilled expectations: %s", err)
	}
}

// TestQueryRowBytesInvalidatedByNext_jsonRawMessageIntoRawBytes builds two
// json.RawMessage rows and passes them to queryRowBytesInvalidatedByNext
// together with a scanner that reads each row into a sql.RawBytes.
func TestQueryRowBytesInvalidatedByNext_jsonRawMessageIntoRawBytes(t *testing.T) {
	t.Parallel()
	replace := []byte(invalid)
	rows := NewRows([]string{"raw"}).
		AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)).
		AddRow(json.RawMessage(`{"that": "foo", "this": "bar"}`))
	scan := func(rs *sql.Rows) ([]byte, error) {
		var raw sql.RawBytes
		// Scan before building the return values: the Go spec does not
		// order a plain variable read relative to a call in the same
		// return statement, so `return raw, rs.Scan(&raw)` is unspecified.
		err := rs.Scan(&raw)
		return raw, err
	}
	// NOTE(review): the -6 / -9 offsets presumably trim the `invalid` marker
	// to the byte length of the matching Initial payload — confirm against
	// the definition of `invalid`.
	want := []struct {
		Initial  []byte
		Replaced []byte
	}{
		{Initial: []byte(`{"thing": "one", "thing2": "two"}`), Replaced: replace[:len(replace)-6]},
		{Initial: []byte(`{"that": "foo", "this": "bar"}`), Replaced: replace[:len(replace)-9]},
	}
	queryRowBytesInvalidatedByNext(t, rows, scan, want)
}

// TestQueryRowBytesNotInvalidatedByNext_jsonRawMessageIntoBytes builds two
// json.RawMessage rows and passes them to queryRowBytesNotInvalidatedByNext
// together with a scanner that reads each row into a plain []byte.
func TestQueryRowBytesNotInvalidatedByNext_jsonRawMessageIntoBytes(t *testing.T) {
	t.Parallel()
	rows := NewRows([]string{"raw"}).
		AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)).
		AddRow(json.RawMessage(`{"that": "foo", "this": "bar"}`))
	scan := func(rs *sql.Rows) ([]byte, error) {
		var b []byte
		// Scan first so that b is read only after it has been filled; the
		// order of a variable read versus a call inside one return statement
		// is not specified by the language.
		err := rs.Scan(&b)
		return b, err
	}
	want := [][]byte{
		[]byte(`{"thing": "one", "thing2": "two"}`),
		[]byte(`{"that": "foo", "this": "bar"}`),
	}
	queryRowBytesNotInvalidatedByNext(t, rows, scan, want)
}

// TestQueryRowBytesNotInvalidatedByNext_bytesIntoCustomBytes builds two
// []byte rows and passes them to queryRowBytesNotInvalidatedByNext together
// with a scanner that reads each row into a locally defined byte-slice type.
func TestQueryRowBytesNotInvalidatedByNext_bytesIntoCustomBytes(t *testing.T) {
	t.Parallel()
	rows := NewRows([]string{"raw"}).
		AddRow([]byte(`one binary value with some text!`)).
		AddRow([]byte(`two binary value with even more text than the first one`))
	scan := func(rs *sql.Rows) ([]byte, error) {
		type customBytes []byte
		var b customBytes
		// Scan first so that b is read only after it has been filled; the
		// order of a variable read versus a call inside one return statement
		// is not specified by the language.
		err := rs.Scan(&b)
		return b, err
	}
	want := [][]byte{
		[]byte(`one binary value with some text!`),
		[]byte(`two binary value with even more text than the first one`),
	}
	queryRowBytesNotInvalidatedByNext(t, rows, scan, want)
}

// TestQueryRowBytesNotInvalidatedByNext_jsonRawMessageIntoCustomBytes builds
// two json.RawMessage rows and passes them to
// queryRowBytesNotInvalidatedByNext together with a scanner that reads each
// row into a locally defined byte-slice type.
func TestQueryRowBytesNotInvalidatedByNext_jsonRawMessageIntoCustomBytes(t *testing.T) {
	t.Parallel()
	rows := NewRows([]string{"raw"}).
		AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)).
		AddRow(json.RawMessage(`{"that": "foo", "this": "bar"}`))
	scan := func(rs *sql.Rows) ([]byte, error) {
		type customBytes []byte
		var b customBytes
		// Scan first so that b is read only after it has been filled; the
		// order of a variable read versus a call inside one return statement
		// is not specified by the language.
		err := rs.Scan(&b)
		return b, err
	}
	want := [][]byte{
		[]byte(`{"thing": "one", "thing2": "two"}`),
		[]byte(`{"that": "foo", "this": "bar"}`),
	}
	queryRowBytesNotInvalidatedByNext(t, rows, scan, want)
}

// TestQueryRowBytesNotInvalidatedByClose_bytesIntoCustomBytes builds a single
// []byte row and passes it to queryRowBytesNotInvalidatedByClose together
// with a scanner that reads the row into a locally defined byte-slice type.
func TestQueryRowBytesNotInvalidatedByClose_bytesIntoCustomBytes(t *testing.T) {
	t.Parallel()
	rows := NewRows([]string{"raw"}).AddRow([]byte(`one binary value with some text!`))
	scan := func(rs *sql.Rows) ([]byte, error) {
		type customBytes []byte
		var b customBytes
		// Scan first so that b is read only after it has been filled; the
		// order of a variable read versus a call inside one return statement
		// is not specified by the language.
		err := rs.Scan(&b)
		return b, err
	}
	queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`one binary value with some text!`))
}

// TestQueryRowBytesInvalidatedByClose_jsonRawMessageIntoRawBytes builds a
// single json.RawMessage row and passes it to queryRowBytesInvalidatedByClose
// together with a scanner that reads the row into a sql.RawBytes.
func TestQueryRowBytesInvalidatedByClose_jsonRawMessageIntoRawBytes(t *testing.T) {
	t.Parallel()
	replace := []byte(invalid)
	rows := NewRows([]string{"raw"}).AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`))
	scan := func(rs *sql.Rows) ([]byte, error) {
		var raw sql.RawBytes
		// Scan before building the return values: the Go spec does not
		// order a plain variable read relative to a call in the same
		// return statement, so `return raw, rs.Scan(&raw)` is unspecified.
		err := rs.Scan(&raw)
		return raw, err
	}
	// NOTE(review): the -6 offset presumably trims the `invalid` marker to
	// the byte length of Initial — confirm against the definition of
	// `invalid`.
	want := struct {
		Initial  []byte
		Replaced []byte
	}{
		Initial:  []byte(`{"thing": "one", "thing2": "two"}`),
		Replaced: replace[:len(replace)-6],
	}
	queryRowBytesInvalidatedByClose(t, rows, scan, want)
}

// TestQueryRowBytesNotInvalidatedByClose_jsonRawMessageIntoBytes builds a
// single json.RawMessage row and passes it to
// queryRowBytesNotInvalidatedByClose together with a scanner that reads the
// row into a plain []byte.
func TestQueryRowBytesNotInvalidatedByClose_jsonRawMessageIntoBytes(t *testing.T) {
	t.Parallel()
	rows := NewRows([]string{"raw"}).AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`))
	scan := func(rs *sql.Rows) ([]byte, error) {
		var b []byte
		// Scan first so that b is read only after it has been filled; the
		// order of a variable read versus a call inside one return statement
		// is not specified by the language.
		err := rs.Scan(&b)
		return b, err
	}
	queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`{"thing": "one", "thing2": "two"}`))
}

// TestQueryRowBytesNotInvalidatedByClose_jsonRawMessageIntoCustomBytes builds
// a single json.RawMessage row and passes it to
// queryRowBytesNotInvalidatedByClose together with a scanner that reads the
// row into a locally defined byte-slice type.
func TestQueryRowBytesNotInvalidatedByClose_jsonRawMessageIntoCustomBytes(t *testing.T) {
	t.Parallel()
	rows := NewRows([]string{"raw"}).AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`))
	scan := func(rs *sql.Rows) ([]byte, error) {
		type customBytes []byte
		var b customBytes
		// Scan first so that b is read only after it has been filled; the
		// order of a variable read versus a call inside one return statement
		// is not specified by the language.
		err := rs.Scan(&b)
		return b, err
	}
	queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`{"thing": "one", "thing2": "two"}`))
}

// TestNewColumnWithDefinition exercises rows created with
// NewRowsWithColumnDefinition. For a query returning one result set, and for
// a query returning three, it checks the scanned values of every row and the
// column metadata reported by ColumnTypes (name, database type name, scan
// type kind, length, precision/scale and nullability), and finally that the
// rows-closed and database-close expectations are fulfilled.
func TestNewColumnWithDefinition(t *testing.T) {
	now, err := time.Parse(time.RFC3339, "2020-06-20T22:08:41Z")
	if err != nil {
		t.Fatalf("unable to parse reference time: %v", err)
	}

	// record is one expected (test, number, when) row.
	type record struct {
		test   string
		number float64
		when   time.Time
	}

	// checkRecord reports every field of got that differs from want.
	checkRecord := func(t *testing.T, got, want record) {
		if got.test != want.test {
			t.Errorf("field test is '%s', expected '%s'", got.test, want.test)
		}
		if got.number != want.number {
			t.Errorf("field number is '%v', expected '%v'", got.number, want.number)
		}
		// Equal compares the instant only, which is what the test cares about.
		if !got.when.Equal(want.when) {
			t.Errorf("field when is %v, expected %v", got.when, want.when)
		}
	}

	// checkColumnTypes verifies the metadata of the three columns defined in
	// both subtests against the current result set of rows.
	checkColumnTypes := func(t *testing.T, rows *sql.Rows) {
		columnTypes, err := rows.ColumnTypes()
		if err != nil {
			t.Error(err)
			return
		}
		if len(columnTypes) != 3 {
			// Indexing below would panic, so give up on the metadata checks.
			t.Error("number of columnTypes")
			return
		}
		testCol, numberCol, whenCol := columnTypes[0], columnTypes[1], columnTypes[2]

		if name := testCol.Name(); name != "test" {
			t.Errorf("field 'test' has a wrong name '%s'", name)
		}
		if dbType := testCol.DatabaseTypeName(); dbType != "VARCHAR" {
			t.Errorf("field 'test' has a wrong db type '%s'", dbType)
		}
		if testCol.ScanType().Kind() != reflect.String {
			t.Error("field 'test' has a wrong scanType")
		}
		if _, _, ok := testCol.DecimalSize(); ok {
			t.Error("field 'test' should not have precision, scale")
		}
		if length, ok := testCol.Length(); length != 100 || !ok {
			t.Errorf("field 'test' has a wrong length '%d'", length)
		}

		if name := numberCol.Name(); name != "number" {
			t.Errorf("field 'number' has a wrong name '%s'", name)
		}
		if dbType := numberCol.DatabaseTypeName(); dbType != "DECIMAL" {
			t.Errorf("field 'number' has a wrong db type '%s'", dbType)
		}
		if numberCol.ScanType().Kind() != reflect.Float64 {
			t.Error("field 'number' has a wrong scanType")
		}
		if precision, scale, ok := numberCol.DecimalSize(); precision != int64(10) || scale != int64(4) || !ok {
			t.Error("field 'number' has a wrong precision, scale")
		}
		if _, ok := numberCol.Length(); ok {
			t.Error("field 'number' is not variable length type")
		}

		// Nullable was never set on the 'when' column definition.
		if _, ok := whenCol.Nullable(); ok {
			t.Error("field 'when' should have nullability unknown")
		}
	}

	t.Run("with one ResultSet", func(t *testing.T) {
		db, mock, err := New()
		if err != nil {
			t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
		}
		column1 := mock.NewColumn("test").OfType("VARCHAR", "").Nullable(true).WithLength(100)
		column2 := mock.NewColumn("number").OfType("DECIMAL", float64(0.0)).Nullable(false).WithPrecisionAndScale(10, 4)
		column3 := mock.NewColumn("when").OfType("TIMESTAMP", now)
		rows := mock.NewRowsWithColumnDefinition(column1, column2, column3)
		rows.AddRow("foo.bar", float64(10.123), now)

		mQuery := mock.ExpectQuery("SELECT test, number, when from dummy")
		isQuery := mQuery.WillReturnRows(rows)
		isQueryClosed := mQuery.RowsWillBeClosed()
		isDbClosed := mock.ExpectClose()

		query, err := db.Query("SELECT test, number, when from dummy")
		if err != nil {
			// query must not be used after a failed Query.
			t.Fatalf("error was not expected, but got: %v", err)
		}

		if !isQuery.fulfilled() {
			t.Error("Query is not executed")
		}

		if query.Next() {
			var got record
			if err := query.Scan(&got.test, &got.number, &got.when); err != nil {
				t.Error(err)
			} else {
				checkRecord(t, got, record{test: "foo.bar", number: 10.123, when: now})
			}
			checkColumnTypes(t, query)
		} else {
			t.Error("no result set")
		}

		if err := query.Close(); err != nil {
			t.Errorf("error was not expected while closing rows, but got: %v", err)
		}
		if !isQueryClosed.fulfilled() {
			t.Error("Rows are not closed")
		}

		if err := db.Close(); err != nil {
			t.Errorf("error was not expected while closing db, but got: %v", err)
		}
		if !isDbClosed.fulfilled() {
			t.Error("Db is not closed")
		}
	})

	t.Run("with more then one ResultSet", func(t *testing.T) {
		db, mock, err := New()
		if err != nil {
			t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
		}
		column1 := mock.NewColumn("test").OfType("VARCHAR", "").Nullable(true).WithLength(100)
		column2 := mock.NewColumn("number").OfType("DECIMAL", float64(0.0)).Nullable(false).WithPrecisionAndScale(10, 4)
		column3 := mock.NewColumn("when").OfType("TIMESTAMP", now)

		// One expected row per result set, in the order the sets are returned.
		want := []record{
			{test: "foo.bar", number: 10.123, when: now},
			{test: "bar.foo", number: 123.10, when: now.Add(time.Second * 10)},
			{test: "lollipop", number: 10.321, when: now.Add(time.Second * 20)},
		}

		rows1 := mock.NewRowsWithColumnDefinition(column1, column2, column3)
		rows1.AddRow(want[0].test, want[0].number, want[0].when)
		rows2 := mock.NewRowsWithColumnDefinition(column1, column2, column3)
		rows2.AddRow(want[1].test, want[1].number, want[1].when)
		rows3 := mock.NewRowsWithColumnDefinition(column1, column2, column3)
		rows3.AddRow(want[2].test, want[2].number, want[2].when)

		mQuery := mock.ExpectQuery("SELECT test, number, when from dummy")
		isQuery := mQuery.WillReturnRows(rows1, rows2, rows3)
		isQueryClosed := mQuery.RowsWillBeClosed()
		isDbClosed := mock.ExpectClose()

		query, err := db.Query("SELECT test, number, when from dummy")
		if err != nil {
			// query must not be used after a failed Query.
			t.Fatalf("error was not expected, but got: %v", err)
		}

		if !isQuery.fulfilled() {
			t.Error("Query is not executed")
		}

		// Next only walks the rows of the current result set, so the outer
		// loop has to advance with NextResultSet to reach the second and
		// third sets.
		seen := 0
		for {
			for query.Next() {
				var got record
				switch err := query.Scan(&got.test, &got.number, &got.when); {
				case err != nil:
					t.Error(err)
				case seen >= len(want):
					t.Errorf("unexpected extra row #%d: %+v", seen, got)
				default:
					checkRecord(t, got, want[seen])
				}
				seen++

				checkColumnTypes(t, query)
			}
			if !query.NextResultSet() {
				break
			}
		}
		if err := query.Err(); err != nil {
			t.Errorf("error was not expected while iterating, but got: %v", err)
		}
		if seen == 0 {
			t.Error("no result set")
		} else if seen != len(want) {
			t.Errorf("expected %d rows across all result sets, but got %d", len(want), seen)
		}

		if err := query.Close(); err != nil {
			t.Errorf("error was not expected while closing rows, but got: %v", err)
		}
		if !isQueryClosed.fulfilled() {
			t.Error("Rows are not closed")
		}

		if err := db.Close(); err != nil {
			t.Errorf("error was not expected while closing db, but got: %v", err)
		}
		if !isDbClosed.fulfilled() {
			t.Error("Db is not closed")
		}
	})
}