diff --git a/connection.go b/connection.go index 00cf5cb..bacbb8c 100644 --- a/connection.go +++ b/connection.go @@ -1,118 +1,118 @@ package sqlmock import ( - "database/sql/driver" - "fmt" + "database/sql/driver" + "fmt" ) type conn struct { - expectations []expectation - active expectation + expectations []expectation + active expectation } // Close a mock database driver connection. It should // be always called to ensure that all expectations // were met successfully. Returns error if there is any func (c *conn) Close() (err error) { - for _, e := range mock.conn.expectations { - if !e.fulfilled() { - err = fmt.Errorf("there is a remaining expectation %T which was not matched yet", e) - break - } - } - mock.conn.expectations = []expectation{} - mock.conn.active = nil - return err + for _, e := range mock.conn.expectations { + if !e.fulfilled() { + err = fmt.Errorf("there is a remaining expectation %T which was not matched yet", e) + break + } + } + mock.conn.expectations = []expectation{} + mock.conn.active = nil + return err } func (c *conn) Begin() (driver.Tx, error) { - e := c.next() - if e == nil { - return nil, fmt.Errorf("all expectations were already fulfilled, call to begin transaction was not expected") - } + e := c.next() + if e == nil { + return nil, fmt.Errorf("all expectations were already fulfilled, call to begin transaction was not expected") + } - etb, ok := e.(*expectedBegin) - if !ok { - return nil, fmt.Errorf("call to begin transaction, was not expected, next expectation is %T as %+v", e, e) - } - etb.triggered = true - return &transaction{c}, etb.err + etb, ok := e.(*expectedBegin) + if !ok { + return nil, fmt.Errorf("call to begin transaction, was not expected, next expectation is %T as %+v", e, e) + } + etb.triggered = true + return &transaction{c}, etb.err } // get next unfulfilled expectation func (c *conn) next() (e expectation) { - for _, e = range c.expectations { - if !e.fulfilled() { - return - } - } - return nil // all expectations were fulfilled + for _, e = range c.expectations { + if !e.fulfilled() { + return + } + } + return nil // all expectations were fulfilled } func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) { - e := c.next() - query = stripQuery(query) - if e == nil { - return nil, fmt.Errorf("all expectations were already fulfilled, call to exec '%s' query with args %+v was not expected", query, args) - } + e := c.next() + query = stripQuery(query) + if e == nil { + return nil, fmt.Errorf("all expectations were already fulfilled, call to exec '%s' query with args %+v was not expected", query, args) + } - eq, ok := e.(*expectedExec) - if !ok { - return nil, fmt.Errorf("call to exec query '%s' with args %+v, was not expected, next expectation is %T as %+v", query, args, e, e) - } + eq, ok := e.(*expectedExec) + if !ok { + return nil, fmt.Errorf("call to exec query '%s' with args %+v, was not expected, next expectation is %T as %+v", query, args, e, e) + } - eq.triggered = true - if eq.err != nil { - return nil, eq.err // mocked to return error - } + eq.triggered = true + if eq.err != nil { + return nil, eq.err // mocked to return error + } - if eq.result == nil { - return nil, fmt.Errorf("exec query '%s' with args %+v, must return a database/sql/driver.result, but it was not set for expectation %T as %+v", query, args, eq, eq) - } + if eq.result == nil { + return nil, fmt.Errorf("exec query '%s' with args %+v, must return a database/sql/driver.result, but it was not set for expectation %T as %+v", query, args, eq, eq) + } - if !eq.queryMatches(query) { - return nil, fmt.Errorf("exec query '%s', does not match regex '%s'", query, eq.sqlRegex.String()) - } + if !eq.queryMatches(query) { + return nil, fmt.Errorf("exec query '%s', does not match regex '%s'", query, eq.sqlRegex.String()) + } - if !eq.argsMatches(args) { - return nil, fmt.Errorf("exec query '%s', args %+v does not match expected %+v", query, args, eq.args) - } + if !eq.argsMatches(args) { + return nil, fmt.Errorf("exec query '%s', args %+v does not match expected %+v", query, args, eq.args) + } - return eq.result, nil + return eq.result, nil } func (c *conn) Prepare(query string) (driver.Stmt, error) { - return &statement{mock.conn, stripQuery(query)}, nil + return &statement{mock.conn, stripQuery(query)}, nil } func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) { - e := c.next() - query = stripQuery(query) - if e == nil { - return nil, fmt.Errorf("all expectations were already fulfilled, call to query '%s' with args %+v was not expected", query, args) - } + e := c.next() + query = stripQuery(query) + if e == nil { + return nil, fmt.Errorf("all expectations were already fulfilled, call to query '%s' with args %+v was not expected", query, args) + } - eq, ok := e.(*expectedQuery) - if !ok { - return nil, fmt.Errorf("call to query '%s' with args %+v, was not expected, next expectation is %T as %+v", query, args, e, e) - } + eq, ok := e.(*expectedQuery) + if !ok { + return nil, fmt.Errorf("call to query '%s' with args %+v, was not expected, next expectation is %T as %+v", query, args, e, e) + } - eq.triggered = true - if eq.err != nil { - return nil, eq.err // mocked to return error - } + eq.triggered = true + if eq.err != nil { + return nil, eq.err // mocked to return error + } - if eq.rows == nil { - return nil, fmt.Errorf("query '%s' with args %+v, must return a database/sql/driver.rows, but it was not set for expectation %T as %+v", query, args, eq, eq) - } + if eq.rows == nil { + return nil, fmt.Errorf("query '%s' with args %+v, must return a database/sql/driver.rows, but it was not set for expectation %T as %+v", query, args, eq, eq) + } - if !eq.queryMatches(query) { - return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, eq.sqlRegex.String()) - } + if !eq.queryMatches(query) { + return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, eq.sqlRegex.String()) + } - if !eq.argsMatches(args) { - return nil, fmt.Errorf("query '%s', args %+v does not match expected %+v", query, args, eq.args) - } + if !eq.argsMatches(args) { + return nil, fmt.Errorf("query '%s', args %+v does not match expected %+v", query, args, eq.args) + } - return eq.rows, nil + return eq.rows, nil } diff --git a/expectations.go b/expectations.go index 269fcba..cc85dfc 100644 --- a/expectations.go +++ b/expectations.go @@ -1,106 +1,119 @@ package sqlmock import ( - "database/sql/driver" - "reflect" - "regexp" + "database/sql/driver" + "reflect" + "regexp" ) +// Argument interface allows to match +// any argument in specific way +type Argument interface { + Match(driver.Value) bool +} + // an expectation interface type expectation interface { - fulfilled() bool - setError(err error) + fulfilled() bool + setError(err error) } // common expectation struct // satisfies the expectation interface type commonExpectation struct { - triggered bool - err error + triggered bool + err error } func (e *commonExpectation) fulfilled() bool { - return e.triggered + return e.triggered } func (e *commonExpectation) setError(err error) { - e.err = err + e.err = err } // query based expectation // adds a query matching logic type queryBasedExpectation struct { - commonExpectation - sqlRegex *regexp.Regexp - args []driver.Value + commonExpectation + sqlRegex *regexp.Regexp + args []driver.Value } func (e *queryBasedExpectation) queryMatches(sql string) bool { - return e.sqlRegex.MatchString(sql) + return e.sqlRegex.MatchString(sql) } func (e *queryBasedExpectation) argsMatches(args []driver.Value) bool { - if nil == e.args { - return true - } - if len(args) != len(e.args) { - return false - } - for k, v := range args { - vi := reflect.ValueOf(v) - ai := reflect.ValueOf(e.args[k]) - switch vi.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - if vi.Int() != ai.Int() { - return false - } - case reflect.Float32, reflect.Float64: - if vi.Float() != ai.Float() { - return false - } - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - if vi.Uint() != ai.Uint() { - return false - } - case reflect.String: - if vi.String() != ai.String() { - return false - } - default: - // compare types like time.Time based on type only - if vi.Kind() != ai.Kind() { - return false - } - } - } - return true + if nil == e.args { + return true + } + if len(args) != len(e.args) { + return false + } + for k, v := range args { + matcher, ok := e.args[k].(Argument) + if ok { + if !matcher.Match(v) { + return false + } + continue + } + vi := reflect.ValueOf(v) + ai := reflect.ValueOf(e.args[k]) + switch vi.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if vi.Int() != ai.Int() { + return false + } + case reflect.Float32, reflect.Float64: + if vi.Float() != ai.Float() { + return false + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if vi.Uint() != ai.Uint() { + return false + } + case reflect.String: + if vi.String() != ai.String() { + return false + } + default: + // compare types like time.Time based on type only + if vi.Kind() != ai.Kind() { + return false + } + } + } + return true } // begin transaction type expectedBegin struct { - commonExpectation + commonExpectation } // tx commit type expectedCommit struct { - commonExpectation + commonExpectation } // tx rollback type expectedRollback struct { - commonExpectation + commonExpectation } // query expectation type expectedQuery struct { - queryBasedExpectation + queryBasedExpectation - rows driver.Rows + rows driver.Rows } // exec query expectation type expectedExec struct { - queryBasedExpectation + queryBasedExpectation - result driver.Result + result driver.Result } diff --git a/expectations_test.go b/expectations_test.go index 699d9b5..e16b11d 100644 --- a/expectations_test.go +++ b/expectations_test.go @@ -1,47 +1,59 @@ package sqlmock import ( - "database/sql/driver" - "testing" - "time" + "database/sql/driver" + "testing" + "time" ) -func TestQueryExpectationArgComparison(t *testing.T) { - e := &queryBasedExpectation{} - against := []driver.Value{5} - if !e.argsMatches(against) { - t.Error("arguments should match, since the no expectation was set") - } - - e.args = []driver.Value{5, "str"} - - against = []driver.Value{5} - if e.argsMatches(against) { - t.Error("arguments should not match, since the size is not the same") - } - - against = []driver.Value{3, "str"} - if e.argsMatches(against) { - t.Error("arguments should not match, since the first argument (int value) is different") - } - - against = []driver.Value{5, "st"} - if e.argsMatches(against) { - t.Error("arguments should not match, since the second argument (string value) is different") - } - - against = []driver.Value{5, "str"} - if !e.argsMatches(against) { - t.Error("arguments should match, but it did not") - } - - e.args = []driver.Value{5, time.Now()} - - const longForm = "Jan 2, 2006 at 3:04pm (MST)" - tm, _ := time.Parse(longForm, "Feb 3, 2013 at 7:54pm (PST)") - - against = []driver.Value{5, tm} - if !e.argsMatches(against) { - t.Error("arguments should match (time will be compared only by type), but it did not") - } +type matcher struct { +} + +func (m matcher) Match(driver.Value) bool { + return true +} + +func TestQueryExpectationArgComparison(t *testing.T) { + e := &queryBasedExpectation{} + against := []driver.Value{5} + if !e.argsMatches(against) { + t.Error("arguments should match, since the no expectation was set") + } + + e.args = []driver.Value{5, "str"} + + against = []driver.Value{5} + if e.argsMatches(against) { + t.Error("arguments should not match, since the size is not the same") + } + + against = []driver.Value{3, "str"} + if e.argsMatches(against) { + t.Error("arguments should not match, since the first argument (int value) is different") + } + + against = []driver.Value{5, "st"} + if e.argsMatches(against) { + t.Error("arguments should not match, since the second argument (string value) is different") + } + + against = []driver.Value{5, "str"} + if !e.argsMatches(against) { + t.Error("arguments should match, but it did not") + } + + e.args = []driver.Value{5, time.Now()} + + const longForm = "Jan 2, 2006 at 3:04pm (MST)" + tm, _ := time.Parse(longForm, "Feb 3, 2013 at 7:54pm (PST)") + + against = []driver.Value{5, tm} + if !e.argsMatches(against) { + t.Error("arguments should match (time will be compared only by type), but it did not") + } + + against = []driver.Value{5, matcher{}} + if !e.argsMatches(against) { + t.Error("arguments should match, but it did not") + } }