mirror of
https://github.com/DATA-DOG/go-sqlmock.git
synced 2024-11-21 17:17:08 +02:00
added tests
This commit is contained in:
parent
e062dfc202
commit
5a7ddb9845
@ -161,6 +161,7 @@ func (c *sqlmock) ExpectPing() *ExpectedPing {
|
||||
}
|
||||
|
||||
// Query meets http://golang.org/pkg/database/sql/driver/#Queryer
|
||||
// Deprecated: Drivers should implement QueryerContext instead.
|
||||
func (c *sqlmock) Query(query string, args []driver.Value) (driver.Rows, error) {
|
||||
namedArgs := make([]driver.NamedValue, len(args))
|
||||
for i, v := range args {
|
||||
@ -243,6 +244,7 @@ func (c *sqlmock) query(query string, args []driver.NamedValue) (*ExpectedQuery,
|
||||
}
|
||||
|
||||
// Exec meets http://golang.org/pkg/database/sql/driver/#Execer
|
||||
// Deprecated: Drivers should implement ExecerContext instead.
|
||||
func (c *sqlmock) Exec(query string, args []driver.Value) (driver.Result, error) {
|
||||
namedArgs := make([]driver.NamedValue, len(args))
|
||||
for i, v := range args {
|
||||
|
@ -10,8 +10,6 @@ import (
|
||||
// CheckNamedValue meets https://golang.org/pkg/database/sql/driver/#NamedValueChecker
|
||||
func (c *sqlmock) CheckNamedValue(nv *driver.NamedValue) (err error) {
|
||||
switch nv.Value.(type) {
|
||||
case sql.NamedArg:
|
||||
return nil
|
||||
case sql.Out:
|
||||
return nil
|
||||
default:
|
||||
|
@ -3,6 +3,8 @@
|
||||
package sqlmock
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
@ -37,3 +39,32 @@ func TestStatementTX(t *testing.T) {
|
||||
t.Fatalf("unexpected result: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sqlmock_CheckNamedValue(t *testing.T) {
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
tests := []struct {
|
||||
name string
|
||||
arg *driver.NamedValue
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
arg: &driver.NamedValue{Name: "test", Value: "test"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
arg: &driver.NamedValue{Name: "test", Value: sql.Out{}},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := mock.(*sqlmock).CheckNamedValue(tt.arg); (err != nil) != tt.wantErr {
|
||||
t.Errorf("CheckNamedValue() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -2,8 +2,10 @@ package sqlmock
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
@ -1217,3 +1219,44 @@ func queryWithTimeout(t time.Duration, db *sql.DB, query string, args ...interfa
|
||||
return nil, fmt.Errorf("query timed out after %v", t)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sqlmock_Prepare_and_Exec(t *testing.T) {
|
||||
db, mock, err := New()
|
||||
if err != nil {
|
||||
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
query := "SELECT name, email FROM users WHERE name = ?"
|
||||
|
||||
mock.ExpectPrepare("SELECT (.+) FROM users WHERE (.+)")
|
||||
expected := NewResult(1, 1)
|
||||
mock.ExpectExec("SELECT (.+) FROM users WHERE (.+)").
|
||||
WillReturnResult(expected)
|
||||
expectedRows := mock.NewRows([]string{"id", "name", "email"}).AddRow(1, "test", "test@example.com")
|
||||
mock.ExpectQuery("SELECT (.+) FROM users WHERE (.+)").WillReturnRows(expectedRows)
|
||||
|
||||
got, err := mock.(*sqlmock).Prepare(query)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if got == nil {
|
||||
t.Error("Prepare () stmt must not be nil")
|
||||
return
|
||||
}
|
||||
result, err := got.Exec([]driver.Value{"test"})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(result, expected) {
|
||||
t.Errorf("Eesults not equal. Expected: %v, Actual: %v", expected, result)
|
||||
return
|
||||
}
|
||||
rows, err := got.Query([]driver.Value{"test"})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
}
|
||||
|
12
statement.go
12
statement.go
@ -1,9 +1,5 @@
|
||||
package sqlmock
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
)
|
||||
|
||||
type statement struct {
|
||||
conn *sqlmock
|
||||
ex *ExpectedPrepare
|
||||
@ -18,11 +14,3 @@ func (stmt *statement) Close() error {
|
||||
func (stmt *statement) NumInput() int {
|
||||
return -1
|
||||
}
|
||||
|
||||
func (stmt *statement) Exec(args []driver.Value) (driver.Result, error) {
|
||||
return stmt.conn.Exec(stmt.query, args)
|
||||
}
|
||||
|
||||
func (stmt *statement) Query(args []driver.Value) (driver.Rows, error) {
|
||||
return stmt.conn.Query(stmt.query, args)
|
||||
}
|
||||
|
17
statement_before_go18.go
Normal file
17
statement_before_go18.go
Normal file
@ -0,0 +1,17 @@
|
||||
// +build !go1.8
|
||||
|
||||
package sqlmock
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
)
|
||||
|
||||
// Deprecated: Drivers should implement ExecerContext instead.
|
||||
func (stmt *statement) Exec(args []driver.Value) (driver.Result, error) {
|
||||
return stmt.conn.Exec(stmt.query, args)
|
||||
}
|
||||
|
||||
// Deprecated: Drivers should implement StmtQueryContext instead (or additionally).
|
||||
func (stmt *statement) Query(args []driver.Value) (driver.Rows, error) {
|
||||
return stmt.conn.Query(stmt.query, args)
|
||||
}
|
26
statement_go18.go
Normal file
26
statement_go18.go
Normal file
@ -0,0 +1,26 @@
|
||||
// +build go1.8
|
||||
|
||||
package sqlmock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
)
|
||||
|
||||
// Deprecated: Drivers should implement ExecerContext instead.
|
||||
func (stmt *statement) Exec(args []driver.Value) (driver.Result, error) {
|
||||
return stmt.conn.ExecContext(context.Background(), stmt.query, convertValueToNamedValue(args))
|
||||
}
|
||||
|
||||
// Deprecated: Drivers should implement StmtQueryContext instead (or additionally).
|
||||
func (stmt *statement) Query(args []driver.Value) (driver.Rows, error) {
|
||||
return stmt.conn.QueryContext(context.Background(), stmt.query, convertValueToNamedValue(args))
|
||||
}
|
||||
|
||||
func convertValueToNamedValue(args []driver.Value) []driver.NamedValue {
|
||||
namedArgs := make([]driver.NamedValue, len(args))
|
||||
for i, v := range args {
|
||||
namedArgs[i] = driver.NamedValue{Ordinal: i + 1, Value: v}
|
||||
}
|
||||
return namedArgs
|
||||
}
|
Loading…
Reference in New Issue
Block a user