1
0
mirror of https://github.com/DATA-DOG/go-sqlmock.git synced 2025-03-19 20:57:50 +02:00

Merge pull request #231 from bonitoo-io/pr-152-again

Add Column Metadata
This commit is contained in:
Gediminas Morkevicius 2020-06-28 18:11:42 +03:00 committed by GitHub
commit b8a63d3edf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 422 additions and 3 deletions

View File

@ -222,6 +222,7 @@ It only asserts that argument is of `time.Time` type.
## Change Log ## Change Log
- **2019-04-06** - added functionality to mock a sql MetaData request
- **2019-02-13** - added `go.mod` removed the references and suggestions using `gopkg.in`. - **2019-02-13** - added `go.mod` removed the references and suggestions using `gopkg.in`.
- **2018-12-11** - added expectation of Rows to be closed, while mocking expected query. - **2018-12-11** - added expectation of Rows to be closed, while mocking expected query.
- **2018-12-11** - introduced an option to provide **QueryMatcher** in order to customize SQL query matching. - **2018-12-11** - introduced an option to provide **QueryMatcher** in order to customize SQL query matching.

77
column.go Normal file
View File

@ -0,0 +1,77 @@
package sqlmock
import "reflect"
// Column is a mocked column Metadata for rows.ColumnTypes()
type Column struct {
name string
dbType string
nullable bool
nullableOk bool
length int64
lengthOk bool
precision int64
scale int64
psOk bool
scanType reflect.Type
}
func (c *Column) Name() string {
return c.name
}
func (c *Column) DbType() string {
return c.dbType
}
func (c *Column) IsNullable() (bool, bool) {
return c.nullable, c.nullableOk
}
func (c *Column) Length() (int64, bool) {
return c.length, c.lengthOk
}
func (c *Column) PrecisionScale() (int64, int64, bool) {
return c.precision, c.scale, c.psOk
}
func (c *Column) ScanType() reflect.Type {
return c.scanType
}
// NewColumn returns a Column with specified name
func NewColumn(name string) *Column {
return &Column{
name: name,
}
}
// Nullable returns the column with nullable metadata set
func (c *Column) Nullable(nullable bool) *Column {
c.nullable = nullable
c.nullableOk = true
return c
}
// OfType returns the column with type metadata set
func (c *Column) OfType(dbType string, sampleValue interface{}) *Column {
c.dbType = dbType
c.scanType = reflect.TypeOf(sampleValue)
return c
}
// WithLength returns the column with length metadata set.
func (c *Column) WithLength(length int64) *Column {
c.length = length
c.lengthOk = true
return c
}
// WithPrecisionAndScale returns the column with precision and scale metadata set.
func (c *Column) WithPrecisionAndScale(precision, scale int64) *Column {
c.precision = precision
c.scale = scale
c.psOk = true
return c
}

63
column_test.go Normal file
View File

@ -0,0 +1,63 @@
package sqlmock
import (
"reflect"
"testing"
"time"
)
func TestColumn(t *testing.T) {
now, _ := time.Parse(time.RFC3339, "2020-06-20T22:08:41Z")
column1 := NewColumn("test").OfType("VARCHAR", "").Nullable(true).WithLength(100)
column2 := NewColumn("number").OfType("DECIMAL", float64(0.0)).Nullable(false).WithPrecisionAndScale(10, 4)
column3 := NewColumn("when").OfType("TIMESTAMP", now)
if column1.ScanType().Kind() != reflect.String {
t.Errorf("string scanType mismatch: %v", column1.ScanType())
}
if column2.ScanType().Kind() != reflect.Float64 {
t.Errorf("float scanType mismatch: %v", column2.ScanType())
}
if column3.ScanType() != reflect.TypeOf(time.Time{}) {
t.Errorf("time scanType mismatch: %v", column3.ScanType())
}
nullable, ok := column1.IsNullable()
if !nullable || !ok {
t.Errorf("'test' column should be nullable")
}
nullable, ok = column2.IsNullable()
if nullable || !ok {
t.Errorf("'number' column should not be nullable")
}
nullable, ok = column3.IsNullable()
if ok {
t.Errorf("'when' column nullability should be unknown")
}
length, ok := column1.Length()
if length != 100 || !ok {
t.Errorf("'test' column wrong length")
}
length, ok = column2.Length()
if ok {
t.Errorf("'number' column is not of variable length type")
}
length, ok = column3.Length()
if ok {
t.Errorf("'when' column is not of variable length type")
}
_, _, ok = column1.PrecisionScale()
if ok {
t.Errorf("'test' column not applicable")
}
precision, scale, ok := column2.PrecisionScale()
if precision != 10 || scale != 4 || !ok {
t.Errorf("'number' column not applicable")
}
_, _, ok = column3.PrecisionScale()
if ok {
t.Errorf("'when' column not applicable")
}
}

View File

@ -12,11 +12,19 @@ import (
// WillReturnRows specifies the set of resulting rows that will be returned // WillReturnRows specifies the set of resulting rows that will be returned
// by the triggered query // by the triggered query
func (e *ExpectedQuery) WillReturnRows(rows ...*Rows) *ExpectedQuery { func (e *ExpectedQuery) WillReturnRows(rows ...*Rows) *ExpectedQuery {
defs := 0
sets := make([]*Rows, len(rows)) sets := make([]*Rows, len(rows))
for i, r := range rows { for i, r := range rows {
sets[i] = r sets[i] = r
if r.def != nil {
defs++
} }
}
if defs > 0 && defs == len(sets) {
e.rows = &rowSetsWithDefinition{&rowSets{sets: sets, ex: e}}
} else {
e.rows = &rowSets{sets: sets, ex: e} e.rows = &rowSets{sets: sets, ex: e}
}
return e return e
} }

View File

@ -120,6 +120,7 @@ func (rs *rowSets) invalidateRaw() {
type Rows struct { type Rows struct {
converter driver.ValueConverter converter driver.ValueConverter
cols []string cols []string
def []*Column
rows [][]driver.Value rows [][]driver.Value
pos int pos int
nextErr map[int]error nextErr map[int]error

View File

@ -2,7 +2,11 @@
package sqlmock package sqlmock
import "io" import (
"database/sql/driver"
"io"
"reflect"
)
// Implement the "RowsNextResultSet" interface // Implement the "RowsNextResultSet" interface
func (rs *rowSets) HasNextResultSet() bool { func (rs *rowSets) HasNextResultSet() bool {
@ -18,3 +22,53 @@ func (rs *rowSets) NextResultSet() error {
rs.pos++ rs.pos++
return nil return nil
} }
// type for rows with columns definition created with sqlmock.NewRowsWithColumnDefinition
type rowSetsWithDefinition struct {
*rowSets
}
// Implement the "RowsColumnTypeDatabaseTypeName" interface
func (rs *rowSetsWithDefinition) ColumnTypeDatabaseTypeName(index int) string {
return rs.getDefinition(index).DbType()
}
// Implement the "RowsColumnTypeLength" interface
func (rs *rowSetsWithDefinition) ColumnTypeLength(index int) (length int64, ok bool) {
return rs.getDefinition(index).Length()
}
// Implement the "RowsColumnTypeNullable" interface
func (rs *rowSetsWithDefinition) ColumnTypeNullable(index int) (nullable, ok bool) {
return rs.getDefinition(index).IsNullable()
}
// Implement the "RowsColumnTypePrecisionScale" interface
func (rs *rowSetsWithDefinition) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
return rs.getDefinition(index).PrecisionScale()
}
// ColumnTypeScanType is defined from driver.RowsColumnTypeScanType
func (rs *rowSetsWithDefinition) ColumnTypeScanType(index int) reflect.Type {
return rs.getDefinition(index).ScanType()
}
// return column definition from current set metadata
func (rs *rowSetsWithDefinition) getDefinition(index int) *Column {
return rs.sets[rs.pos].def[index]
}
// NewRowsWithColumnDefinition return rows with columns metadata
func NewRowsWithColumnDefinition(columns ...*Column) *Rows {
cols := make([]string, len(columns))
for i, column := range columns {
cols[i] = column.Name()
}
return &Rows{
cols: cols,
def: columns,
nextErr: make(map[int]error),
converter: driver.DefaultParameterConverter,
}
}

View File

@ -6,7 +6,9 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"reflect"
"testing" "testing"
"time"
) )
func TestQueryMultiRows(t *testing.T) { func TestQueryMultiRows(t *testing.T) {
@ -203,3 +205,183 @@ func TestQueryRowBytesNotInvalidatedByClose_jsonRawMessageIntoCustomBytes(t *tes
} }
queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`{"thing": "one", "thing2": "two"}`)) queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`{"thing": "one", "thing2": "two"}`))
} }
func TestNewColumnWithDefinition(t *testing.T) {
now, _ := time.Parse(time.RFC3339, "2020-06-20T22:08:41Z")
t.Run("with one ResultSet", func(t *testing.T) {
db, mock, _ := New()
column1 := mock.NewColumn("test").OfType("VARCHAR", "").Nullable(true).WithLength(100)
column2 := mock.NewColumn("number").OfType("DECIMAL", float64(0.0)).Nullable(false).WithPrecisionAndScale(10, 4)
column3 := mock.NewColumn("when").OfType("TIMESTAMP", now)
rows := mock.NewRowsWithColumnDefinition(column1, column2, column3)
rows.AddRow("foo.bar", float64(10.123), now)
mQuery := mock.ExpectQuery("SELECT test, number, when from dummy")
isQuery := mQuery.WillReturnRows(rows)
isQueryClosed := mQuery.RowsWillBeClosed()
isDbClosed := mock.ExpectClose()
query, _ := db.Query("SELECT test, number, when from dummy")
if false == isQuery.fulfilled() {
t.Error("Query is not executed")
}
if query.Next() {
var test string
var number float64
var when time.Time
if queryError := query.Scan(&test, &number, &when); queryError != nil {
t.Error(queryError)
} else if test != "foo.bar" {
t.Error("field test is not 'foo.bar'")
} else if number != float64(10.123) {
t.Error("field number is not '10.123'")
} else if when != now {
t.Errorf("field when is not %v", now)
}
if columnTypes, colTypErr := query.ColumnTypes(); colTypErr != nil {
t.Error(colTypErr)
} else if len(columnTypes) != 3 {
t.Error("number of columnTypes")
} else if name := columnTypes[0].Name(); name != "test" {
t.Errorf("field 'test' has a wrong name '%s'", name)
} else if dbType := columnTypes[0].DatabaseTypeName(); dbType != "VARCHAR" {
t.Errorf("field 'test' has a wrong db type '%s'", dbType)
} else if columnTypes[0].ScanType().Kind() != reflect.String {
t.Error("field 'test' has a wrong scanType")
} else if _, _, ok := columnTypes[0].DecimalSize(); ok {
t.Error("field 'test' should have not precision, scale")
} else if length, ok := columnTypes[0].Length(); length != 100 || !ok {
t.Errorf("field 'test' has a wrong length '%d'", length)
} else if name := columnTypes[1].Name(); name != "number" {
t.Errorf("field 'number' has a wrong name '%s'", name)
} else if dbType := columnTypes[1].DatabaseTypeName(); dbType != "DECIMAL" {
t.Errorf("field 'number' has a wrong db type '%s'", dbType)
} else if columnTypes[1].ScanType().Kind() != reflect.Float64 {
t.Error("field 'number' has a wrong scanType")
} else if precision, scale, ok := columnTypes[1].DecimalSize(); precision != int64(10) || scale != int64(4) || !ok {
t.Error("field 'number' has a wrong precision, scale")
} else if _, ok := columnTypes[1].Length(); ok {
t.Error("field 'number' is not variable length type")
} else if _, ok := columnTypes[2].Nullable(); ok {
t.Error("field 'when' should have nullability unknown")
}
} else {
t.Error("no result set")
}
query.Close()
if false == isQueryClosed.fulfilled() {
t.Error("Query is not executed")
}
db.Close()
if false == isDbClosed.fulfilled() {
t.Error("Db is not closed")
}
})
t.Run("with more then one ResultSet", func(t *testing.T) {
db, mock, _ := New()
column1 := mock.NewColumn("test").OfType("VARCHAR", "").Nullable(true).WithLength(100)
column2 := mock.NewColumn("number").OfType("DECIMAL", float64(0.0)).Nullable(false).WithPrecisionAndScale(10, 4)
column3 := mock.NewColumn("when").OfType("TIMESTAMP", now)
rows1 := mock.NewRowsWithColumnDefinition(column1, column2, column3)
rows1.AddRow("foo.bar", float64(10.123), now)
rows2 := mock.NewRowsWithColumnDefinition(column1, column2, column3)
rows2.AddRow("bar.foo", float64(123.10), now.Add(time.Second*10))
rows3 := mock.NewRowsWithColumnDefinition(column1, column2, column3)
rows3.AddRow("lollipop", float64(10.321), now.Add(time.Second*20))
mQuery := mock.ExpectQuery("SELECT test, number, when from dummy")
isQuery := mQuery.WillReturnRows(rows1, rows2, rows3)
isQueryClosed := mQuery.RowsWillBeClosed()
isDbClosed := mock.ExpectClose()
query, _ := db.Query("SELECT test, number, when from dummy")
if false == isQuery.fulfilled() {
t.Error("Query is not executed")
}
rowsSi := 0
for query.Next() {
var test string
var number float64
var when time.Time
if queryError := query.Scan(&test, &number, &when); queryError != nil {
t.Error(queryError)
} else if rowsSi == 0 && test != "foo.bar" {
t.Error("field test is not 'foo.bar'")
} else if rowsSi == 0 && number != float64(10.123) {
t.Error("field number is not '10.123'")
} else if rowsSi == 0 && when != now {
t.Errorf("field when is not %v", now)
} else if rowsSi == 1 && test != "bar.foo" {
t.Error("field test is not 'bar.bar'")
} else if rowsSi == 1 && number != float64(123.10) {
t.Error("field number is not '123.10'")
} else if rowsSi == 1 && when != now.Add(time.Second*10) {
t.Errorf("field when is not %v", now)
} else if rowsSi == 2 && test != "lollipop" {
t.Error("field test is not 'lollipop'")
} else if rowsSi == 2 && number != float64(10.321) {
t.Error("field number is not '10.321'")
} else if rowsSi == 2 && when != now.Add(time.Second*20) {
t.Errorf("field when is not %v", now)
}
rowsSi++
if columnTypes, colTypErr := query.ColumnTypes(); colTypErr != nil {
t.Error(colTypErr)
} else if len(columnTypes) != 3 {
t.Error("number of columnTypes")
} else if name := columnTypes[0].Name(); name != "test" {
t.Errorf("field 'test' has a wrong name '%s'", name)
} else if dbType := columnTypes[0].DatabaseTypeName(); dbType != "VARCHAR" {
t.Errorf("field 'test' has a wrong db type '%s'", dbType)
} else if columnTypes[0].ScanType().Kind() != reflect.String {
t.Error("field 'test' has a wrong scanType")
} else if _, _, ok := columnTypes[0].DecimalSize(); ok {
t.Error("field 'test' should not have precision, scale")
} else if length, ok := columnTypes[0].Length(); length != 100 || !ok {
t.Errorf("field 'test' has a wrong length '%d'", length)
} else if name := columnTypes[1].Name(); name != "number" {
t.Errorf("field 'number' has a wrong name '%s'", name)
} else if dbType := columnTypes[1].DatabaseTypeName(); dbType != "DECIMAL" {
t.Errorf("field 'number' has a wrong db type '%s'", dbType)
} else if columnTypes[1].ScanType().Kind() != reflect.Float64 {
t.Error("field 'number' has a wrong scanType")
} else if precision, scale, ok := columnTypes[1].DecimalSize(); precision != int64(10) || scale != int64(4) || !ok {
t.Error("field 'number' has a wrong precision, scale")
} else if _, ok := columnTypes[1].Length(); ok {
t.Error("field 'number' is not variable length type")
} else if _, ok := columnTypes[2].Nullable(); ok {
t.Error("field 'when' should have nullability unknown")
}
}
if rowsSi == 0 {
t.Error("no result set")
}
query.Close()
if false == isQueryClosed.fulfilled() {
t.Error("Query is not executed")
}
db.Close()
if false == isDbClosed.fulfilled() {
t.Error("Db is not closed")
}
})
}

View File

@ -20,7 +20,7 @@ import (
// Sqlmock interface serves to create expectations // Sqlmock interface serves to create expectations
// for any kind of database action in order to mock // for any kind of database action in order to mock
// and test real database behavior. // and test real database behavior.
type Sqlmock interface { type SqlmockCommon interface {
// ExpectClose queues an expectation for this database // ExpectClose queues an expectation for this database
// action to be triggered. the *ExpectedClose allows // action to be triggered. the *ExpectedClose allows
// to mock database response // to mock database response

View File

@ -9,6 +9,12 @@ import (
"time" "time"
) )
// Sqlmock interface for Go up to 1.7
type Sqlmock interface {
// Embed common methods
SqlmockCommon
}
type namedValue struct { type namedValue struct {
Name string Name string
Ordinal int Ordinal int

View File

@ -11,6 +11,19 @@ import (
"time" "time"
) )
// Sqlmock interface for Go 1.8+
type Sqlmock interface {
// Embed common methods
SqlmockCommon
// NewRowsWithColumnDefinition allows Rows to be created from a
// sql driver.Value slice with a definition of sql metadata
NewRowsWithColumnDefinition(columns ...*Column) *Rows
// New Column allows to create a Column
NewColumn(name string) *Column
}
// ErrCancelled defines an error value, which can be expected in case of // ErrCancelled defines an error value, which can be expected in case of
// such cancellation error. // such cancellation error.
var ErrCancelled = errors.New("canceling query due to user request") var ErrCancelled = errors.New("canceling query due to user request")
@ -327,3 +340,17 @@ func (c *sqlmock) exec(query string, args []driver.NamedValue) (*ExpectedExec, e
} }
// @TODO maybe add ExpectedBegin.WithOptions(driver.TxOptions) // @TODO maybe add ExpectedBegin.WithOptions(driver.TxOptions)
// NewRowsWithColumnDefinition allows Rows to be created from a
// sql driver.Value slice with a definition of sql metadata
func (c *sqlmock) NewRowsWithColumnDefinition(columns ...*Column) *Rows {
r := NewRowsWithColumnDefinition(columns...)
r.converter = c.converter
return r
}
// NewColumn allows to create a Column that can be enhanced with metadata
// using OfType/Nullable/WithLength/WithPrecisionAndScale methods.
func (c *sqlmock) NewColumn(name string) *Column {
return NewColumn(name)
}