diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 0000000..e73ed94 --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1 @@ +github: [l3pp4rd, gliptak, dolmen, IvoGoman] diff --git a/driver.go b/driver.go index 802f8fb..0e98358 100644 --- a/driver.go +++ b/driver.go @@ -40,7 +40,7 @@ func (d *mockDriver) Open(dsn string) (driver.Conn, error) { // a specific driver. // Pings db so that all expectations could be // asserted. -func New(options ...func(*sqlmock) error) (*sql.DB, Sqlmock, error) { +func New(options ...SqlMockOption) (*sql.DB, Sqlmock, error) { pool.Lock() dsn := fmt.Sprintf("sqlmock_db_%d", pool.counter) pool.counter++ @@ -67,7 +67,7 @@ func New(options ...func(*sqlmock) error) (*sql.DB, Sqlmock, error) { // // It is not recommended to use this method, unless you // really need it and there is no other way around. -func NewWithDSN(dsn string, options ...func(*sqlmock) error) (*sql.DB, Sqlmock, error) { +func NewWithDSN(dsn string, options ...SqlMockOption) (*sql.DB, Sqlmock, error) { pool.Lock() if _, ok := pool.conns[dsn]; ok { pool.Unlock() diff --git a/expectations.go b/expectations.go index 5adf608..8a6cd44 100644 --- a/expectations.go +++ b/expectations.go @@ -134,11 +134,27 @@ type ExpectedQuery struct { // WithArgs will match given expected args to actual database query arguments. // if at least one argument does not match, it will return an error. For specific // arguments an sqlmock.Argument interface can be used to match an argument. +// Must not be used together with WithoutArgs() func (e *ExpectedQuery) WithArgs(args ...driver.Value) *ExpectedQuery { + if e.noArgs { + panic("WithArgs() and WithoutArgs() must not be used together") + } e.args = args return e } +// WithoutArgs will ensure that no arguments are passed for this query. +// if at least one argument is passed, it will return an error. This allows +// for stricter validation of the query arguments. +// Must no be used together with WithArgs() +func (e *ExpectedQuery) WithoutArgs() *ExpectedQuery { + if len(e.args) > 0 { + panic("WithoutArgs() and WithArgs() must not be used together") + } + e.noArgs = true + return e +} + // RowsWillBeClosed expects this query rows to be closed. func (e *ExpectedQuery) RowsWillBeClosed() *ExpectedQuery { e.rowsMustBeClosed = true @@ -195,11 +211,27 @@ type ExpectedExec struct { // WithArgs will match given expected args to actual database exec operation arguments. // if at least one argument does not match, it will return an error. For specific // arguments an sqlmock.Argument interface can be used to match an argument. +// Must not be used together with WithoutArgs() func (e *ExpectedExec) WithArgs(args ...driver.Value) *ExpectedExec { + if len(e.args) > 0 { + panic("WithArgs() and WithoutArgs() must not be used together") + } e.args = args return e } +// WithoutArgs will ensure that no args are passed for this expected database exec action. +// if at least one argument is passed, it will return an error. This allows for stricter +// validation of the query arguments. +// Must not be used together with WithArgs() +func (e *ExpectedExec) WithoutArgs() *ExpectedExec { + if len(e.args) > 0 { + panic("WithoutArgs() and WithArgs() must not be used together") + } + e.noArgs = true + return e +} + // WillReturnError allows to set an error for expected database exec action func (e *ExpectedExec) WillReturnError(err error) *ExpectedExec { e.err = err @@ -338,6 +370,7 @@ type queryBasedExpectation struct { expectSQL string converter driver.ValueConverter args []driver.Value + noArgs bool // ensure no args are passed } // ExpectedPing is used to manage *sql.DB.Ping expectations. diff --git a/expectations_before_go18.go b/expectations_before_go18.go index 0831863..67c08dc 100644 --- a/expectations_before_go18.go +++ b/expectations_before_go18.go @@ -1,3 +1,4 @@ +//go:build !go1.8 // +build !go1.8 package sqlmock @@ -17,7 +18,7 @@ func (e *ExpectedQuery) WillReturnRows(rows *Rows) *ExpectedQuery { func (e *queryBasedExpectation) argsMatches(args []namedValue) error { if nil == e.args { - if len(args) > 0 { + if e.noArgs && len(args) > 0 { return fmt.Errorf("expected 0, but got %d arguments", len(args)) } return nil diff --git a/expectations_before_go18_test.go b/expectations_before_go18_test.go index 81dc8cf..4234cd6 100644 --- a/expectations_before_go18_test.go +++ b/expectations_before_go18_test.go @@ -1,3 +1,4 @@ +//go:build !go1.8 // +build !go1.8 package sqlmock @@ -9,10 +10,15 @@ import ( ) func TestQueryExpectationArgComparison(t *testing.T) { - e := &queryBasedExpectation{converter: driver.DefaultParameterConverter} + e := &queryBasedExpectation{converter: driver.DefaultParameterConverter, noArgs: true} against := []namedValue{{Value: int64(5), Ordinal: 1}} if err := e.argsMatches(against); err == nil { - t.Error("arguments should not match, since no expectation was set, but argument was passed") + t.Error("arguments should not match, since argument was passed, but noArgs was set") + } + + e.noArgs = false + if err := e.argsMatches(against); err != nil { + t.Error("arguments should match, since argument was passed, but no expected args or noArgs was set") } e.args = []driver.Value{5, "str"} diff --git a/expectations_go18.go b/expectations_go18.go index ccdc2e1..5fade37 100644 --- a/expectations_go18.go +++ b/expectations_go18.go @@ -1,3 +1,4 @@ +//go:build go1.8 // +build go1.8 package sqlmock @@ -30,7 +31,7 @@ func (e *ExpectedQuery) WillReturnRows(rows ...*Rows) *ExpectedQuery { func (e *queryBasedExpectation) argsMatches(args []driver.NamedValue) error { if nil == e.args { - if len(args) > 0 { + if e.noArgs && len(args) > 0 { return fmt.Errorf("expected 0, but got %d arguments", len(args)) } return nil diff --git a/expectations_go18_test.go b/expectations_go18_test.go index d5638bc..cd633b7 100644 --- a/expectations_go18_test.go +++ b/expectations_go18_test.go @@ -1,3 +1,4 @@ +//go:build go1.8 // +build go1.8 package sqlmock @@ -10,10 +11,15 @@ import ( ) func TestQueryExpectationArgComparison(t *testing.T) { - e := &queryBasedExpectation{converter: driver.DefaultParameterConverter} + e := &queryBasedExpectation{converter: driver.DefaultParameterConverter, noArgs: true} against := []driver.NamedValue{{Value: int64(5), Ordinal: 1}} if err := e.argsMatches(against); err == nil { - t.Error("arguments should not match, since no expectation was set, but argument was passed") + t.Error("arguments should not match, since argument was passed, but noArgs was set") + } + + e.noArgs = false + if err := e.argsMatches(against); err != nil { + t.Error("arguments should match, since argument was passed, but no expected args or noArgs was set") } e.args = []driver.Value{5, "str"} @@ -102,10 +108,15 @@ func TestQueryExpectationArgComparisonBool(t *testing.T) { } func TestQueryExpectationNamedArgComparison(t *testing.T) { - e := &queryBasedExpectation{converter: driver.DefaultParameterConverter} + e := &queryBasedExpectation{converter: driver.DefaultParameterConverter, noArgs: true} against := []driver.NamedValue{{Value: int64(5), Name: "id"}} if err := e.argsMatches(against); err == nil { - t.Errorf("arguments should not match, since no expectation was set, but argument was passed") + t.Error("arguments should not match, since argument was passed, but noArgs was set") + } + + e.noArgs = false + if err := e.argsMatches(against); err != nil { + t.Error("arguments should match, since argument was passed, but no expected args or noArgs was set") } e.args = []driver.Value{ diff --git a/expectations_test.go b/expectations_test.go index d99f1ab..0d1d5d1 100644 --- a/expectations_test.go +++ b/expectations_test.go @@ -102,6 +102,29 @@ func TestCustomValueConverterQueryScan(t *testing.T) { } } +func TestQueryWithNoArgsAndWithArgsPanic(t *testing.T) { + defer func() { + if r := recover(); r != nil { + return + } + t.Error("Expected panic for using WithArgs and ExpectNoArgs together") + }() + mock := &sqlmock{} + mock.ExpectQuery("SELECT (.+) FROM user").WithArgs("John").WithoutArgs() +} + +func TestExecWithNoArgsAndWithArgsPanic(t *testing.T) { + defer func() { + if r := recover(); r != nil { + return + } + t.Error("Expected panic for using WithArgs and ExpectNoArgs together") + }() + mock := &sqlmock{} + mock.ExpectExec("^INSERT INTO user").WithArgs("John").WithoutArgs() +} + + func TestQueryWillReturnsNil(t *testing.T) { t.Parallel() @@ -122,5 +145,4 @@ func TestQueryWillReturnsNil(t *testing.T) { _, err = mock.(*sqlmock).Query(query, []driver.Value{"test"}) if err != nil { t.Error(err) - } } diff --git a/options.go b/options.go index 00c9837..a57ae26 100644 --- a/options.go +++ b/options.go @@ -2,9 +2,12 @@ package sqlmock import "database/sql/driver" +// SqlMockOption is the type defining an option used to configure an SqlMock at creation +type SqlMockOption func(*sqlmock) error + // ValueConverterOption allows to create a sqlmock connection // with a custom ValueConverter to support drivers with special data types. -func ValueConverterOption(converter driver.ValueConverter) func(*sqlmock) error { +func ValueConverterOption(converter driver.ValueConverter) SqlMockOption { return func(s *sqlmock) error { s.converter = converter return nil @@ -14,7 +17,7 @@ func ValueConverterOption(converter driver.ValueConverter) func(*sqlmock) error // QueryMatcherOption allows to customize SQL query matcher // and match SQL query strings in more sophisticated ways. // The default QueryMatcher is QueryMatcherRegexp. -func QueryMatcherOption(queryMatcher QueryMatcher) func(*sqlmock) error { +func QueryMatcherOption(queryMatcher QueryMatcher) SqlMockOption { return func(s *sqlmock) error { s.queryMatcher = queryMatcher return nil @@ -30,7 +33,7 @@ func QueryMatcherOption(queryMatcher QueryMatcher) func(*sqlmock) error { // If false is passed or this option is omitted, calls to Ping will not be // considered when determining expectations and calls to ExpectPing will have // no effect. -func MonitorPingsOption(monitorPings bool) func(*sqlmock) error { +func MonitorPingsOption(monitorPings bool) SqlMockOption { return func(s *sqlmock) error { s.monitorPings = monitorPings return nil diff --git a/query.go b/query.go index 47d3796..54341a3 100644 --- a/query.go +++ b/query.go @@ -42,9 +42,12 @@ func (f QueryMatcherFunc) Match(expectedSQL, actualSQL string) error { // QueryMatcherRegexp is the default SQL query matcher // used by sqlmock. It parses expectedSQL to a regular // expression and attempts to match actualSQL. -var QueryMatcherRegexp QueryMatcher = QueryMatcherFunc(func(expectedSQL, actualSQL string) error { +var QueryMatcherRegexp QueryMatcher = QueryMatcherFunc(func(expectedSQL, actualSQL string) error { expect := stripQuery(expectedSQL) actual := stripQuery(actualSQL) + if actual != "" && expect == "" { + return fmt.Errorf("expectedSQL can't be empty") + } re, err := regexp.Compile(expect) if err != nil { return err diff --git a/query_test.go b/query_test.go index 0ba7bdc..514a6ad 100644 --- a/query_test.go +++ b/query_test.go @@ -69,6 +69,7 @@ func TestQueryMatcherRegexp(t *testing.T) { {"SELECT (.+) FROM users", "SELECT name, email FROM users WHERE id = ?", nil}, {"Select (.+) FROM users", "SELECT name, email FROM users WHERE id = ?", fmt.Errorf(`could not match actual sql: "SELECT name, email FROM users WHERE id = ?" with expected regexp "Select (.+) FROM users"`)}, {"SELECT (.+) FROM\nusers", "SELECT name, email\n FROM users\n WHERE id = ?", nil}, + {"","SELECT from table", fmt.Errorf(`expectedSQL can't be empty`)}, } for i, c := range cases { diff --git a/rows.go b/rows.go index 941544b..01ea811 100644 --- a/rows.go +++ b/rows.go @@ -4,6 +4,7 @@ import ( "bytes" "database/sql/driver" "encoding/csv" + "errors" "fmt" "io" "strings" @@ -14,7 +15,7 @@ const invalidate = "☠☠☠ MEMORY OVERWRITTEN ☠☠☠ " // CSVColumnParser is a function which converts trimmed csv // column string to a []byte representation. Currently // transforms NULL to nil -var CSVColumnParser = func(s string) []byte { +var CSVColumnParser = func(s string) interface{} { switch { case strings.ToLower(s) == "null": return nil @@ -165,7 +166,7 @@ func (r *Rows) RowError(row int, err error) *Rows { // of columns func (r *Rows) AddRow(values ...driver.Value) *Rows { if len(values) != len(r.cols) { - panic("Expected number of values to match number of columns") + panic(fmt.Sprintf("Expected number of values to match number of columns: expected %d, actual %d", len(values), len(r.cols))) } row := make([]driver.Value, len(r.cols)) @@ -208,8 +209,11 @@ func (r *Rows) FromCSVString(s string) *Rows { for { res, err := csvReader.Read() - if err != nil || res == nil { - break + if err != nil { + if errors.Is(err, io.EOF) { + break + } + panic(fmt.Sprintf("Parsing CSV string failed: %s", err.Error())) } row := make([]driver.Value, len(r.cols)) diff --git a/rows_test.go b/rows_test.go index ef17521..80f1476 100644 --- a/rows_test.go +++ b/rows_test.go @@ -432,7 +432,7 @@ func TestRowsScanError(t *testing.T) { func TestCSVRowParser(t *testing.T) { t.Parallel() - rs := NewRows([]string{"col1", "col2"}).FromCSVString("a,NULL") + rs := NewRows([]string{"col1", "col2", "col3"}).FromCSVString("a,NULL,NULL") db, mock, err := New() if err != nil { t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) @@ -448,9 +448,10 @@ func TestCSVRowParser(t *testing.T) { defer rw.Close() var col1 string var col2 []byte + var col3 *string rw.Next() - if err = rw.Scan(&col1, &col2); err != nil { + if err = rw.Scan(&col1, &col2, &col3); err != nil { t.Fatalf("unexpected error: %s", err) } if col1 != "a" { @@ -459,6 +460,18 @@ func TestCSVRowParser(t *testing.T) { if col2 != nil { t.Fatalf("expected col2 to be nil, but got [%T]:%+v", col2, col2) } + if col3 != nil { + t.Fatalf("expected col3 to be nil, but got [%T]:%+v", col3, col3) + } +} + +func TestCSVParserInvalidInput(t *testing.T) { + defer func() { + recover() + }() + _ = NewRows([]string{"col1", "col2"}).FromCSVString("a,\"NULL\"\"") + // shouldn't reach here + t.Error("expected panic from parsing invalid CSV") } func TestWrongNumberOfValues(t *testing.T) { @@ -717,6 +730,31 @@ func TestAddRows(t *testing.T) { // scanned id: 4 and title: Emily } +func TestAddRowExpectPanic(t *testing.T) { + t.Parallel() + + const expectedPanic = "Expected number of values to match number of columns: expected 1, actual 2" + values := []driver.Value{ + "John", + "Jane", + } + + defer func() { + if r := recover(); r != nil { + if r != expectedPanic { + t.Fatalf("panic message did not match expected: expected '%s', actual '%s'", r, expectedPanic) + } + + return + } + t.Fatalf("expected panic: %s", expectedPanic) + }() + + rows := NewRows([]string{"id", "name"}) + // Note missing spread "..." + rows.AddRow(values) +} + func ExampleRows_AddRows() { db, mock, err := New() if err != nil { diff --git a/sqlmock.go b/sqlmock.go index d074266..3ee1256 100644 --- a/sqlmock.go +++ b/sqlmock.go @@ -98,7 +98,7 @@ type sqlmock struct { expected []expectation } -func (c *sqlmock) open(options []func(*sqlmock) error) (*sql.DB, Sqlmock, error) { +func (c *sqlmock) open(options []SqlMockOption) (*sql.DB, Sqlmock, error) { db, err := sql.Open("sqlmock", c.dsn) if err != nil { return db, c, err diff --git a/sqlmock_go18_test.go b/sqlmock_go18_test.go index cf56e67..6267f38 100644 --- a/sqlmock_go18_test.go +++ b/sqlmock_go18_test.go @@ -1,3 +1,4 @@ +//go:build go1.8 // +build go1.8 package sqlmock @@ -437,7 +438,6 @@ func TestContextExecErrorDelay(t *testing.T) { // test that return of error is delayed var delay time.Duration = 100 * time.Millisecond mock.ExpectExec("^INSERT INTO articles"). - WithArgs("hello"). WillReturnError(errors.New("slow fail")). WillDelayFor(delay) diff --git a/sqlmock_test.go b/sqlmock_test.go index 982a32a..2129a16 100644 --- a/sqlmock_test.go +++ b/sqlmock_test.go @@ -749,6 +749,16 @@ func TestRunExecsWithExpectedErrorMeetsExpectations(t *testing.T) { } } +func TestRunExecsWithNoArgsExpectedMeetsExpectations(t *testing.T) { + db, dbmock, _ := New() + dbmock.ExpectExec("THE FIRST EXEC").WithoutArgs().WillReturnResult(NewResult(0, 0)) + + _, err := db.Exec("THE FIRST EXEC", "foobar") + if err == nil { + t.Fatalf("expected error, but there wasn't any") + } +} + func TestRunQueryWithExpectedErrorMeetsExpectations(t *testing.T) { db, dbmock, _ := New() dbmock.ExpectQuery("THE FIRST QUERY").WillReturnError(fmt.Errorf("big bad bug")) @@ -959,7 +969,7 @@ func TestPrepareExec(t *testing.T) { mock.ExpectBegin() ep := mock.ExpectPrepare("INSERT INTO ORDERS\\(ID, STATUS\\) VALUES \\(\\?, \\?\\)") for i := 0; i < 3; i++ { - ep.ExpectExec().WithArgs(i, "Hello"+strconv.Itoa(i)).WillReturnResult(NewResult(1, 1)) + ep.ExpectExec().WillReturnResult(NewResult(1, 1)) } mock.ExpectCommit() tx, _ := db.Begin() @@ -1073,7 +1083,7 @@ func TestPreparedStatementCloseExpectation(t *testing.T) { defer db.Close() ep := mock.ExpectPrepare("INSERT INTO ORDERS").WillBeClosed() - ep.ExpectExec().WithArgs(1, "Hello").WillReturnResult(NewResult(1, 1)) + ep.ExpectExec().WillReturnResult(NewResult(1, 1)) stmt, err := db.Prepare("INSERT INTO ORDERS(ID, STATUS) VALUES (?, ?)") if err != nil { @@ -1104,7 +1114,6 @@ func TestExecExpectationErrorDelay(t *testing.T) { // test that return of error is delayed var delay time.Duration = 100 * time.Millisecond mock.ExpectExec("^INSERT INTO articles"). - WithArgs("hello"). WillReturnError(errors.New("slow fail")). WillDelayFor(delay) @@ -1230,10 +1239,10 @@ func Test_sqlmock_Prepare_and_Exec(t *testing.T) { mock.ExpectPrepare("SELECT (.+) FROM users WHERE (.+)") expected := NewResult(1, 1) - mock.ExpectExec("SELECT (.+) FROM users WHERE (.+)").WithArgs("test"). + mock.ExpectExec("SELECT (.+) FROM users WHERE (.+)"). WillReturnResult(expected) expectedRows := mock.NewRows([]string{"id", "name", "email"}).AddRow(1, "test", "test@example.com") - mock.ExpectQuery("SELECT (.+) FROM users WHERE (.+)").WithArgs("test").WillReturnRows(expectedRows) + mock.ExpectQuery("SELECT (.+) FROM users WHERE (.+)").WillReturnRows(expectedRows) got, err := mock.(*sqlmock).Prepare(query) if err != nil { @@ -1326,7 +1335,7 @@ func Test_sqlmock_Query(t *testing.T) { } defer db.Close() expectedRows := mock.NewRows([]string{"id", "name", "email"}).AddRow(1, "test", "test@example.com") - mock.ExpectQuery("SELECT (.+) FROM users WHERE (.+)").WithArgs("test").WillReturnRows(expectedRows) + mock.ExpectQuery("SELECT (.+) FROM users WHERE (.+)").WillReturnRows(expectedRows) query := "SELECT name, email FROM users WHERE name = ?" rows, err := mock.(*sqlmock).Query(query, []driver.Value{"test"}) if err != nil { @@ -1340,3 +1349,19 @@ func Test_sqlmock_Query(t *testing.T) { return } } + +func Test_sqlmock_QueryExpectWithoutArgs(t *testing.T) { + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + expectedRows := mock.NewRows([]string{"id", "name", "email"}).AddRow(1, "test", "test@example.com") + mock.ExpectQuery("SELECT (.+) FROM users WHERE (.+)").WillReturnRows(expectedRows).WithoutArgs() + query := "SELECT name, email FROM users WHERE name = ?" + _, err = mock.(*sqlmock).Query(query, []driver.Value{"test"}) + if err == nil { + t.Errorf("error expected") + return + } +}