package sqlmock import ( "database/sql/driver" "fmt" "reflect" "regexp" "strings" "sync" "time" ) // an expectation interface type expectation interface { fulfilled() bool Lock() Unlock() String() string } // common expectation struct // satisfies the expectation interface type commonExpectation struct { sync.Mutex triggered bool err error } func (e *commonExpectation) fulfilled() bool { return e.triggered } // ExpectedClose is used to manage *sql.DB.Close expectation // returned by *Sqlmock.ExpectClose. type ExpectedClose struct { commonExpectation } // WillReturnError allows to set an error for *sql.DB.Close action func (e *ExpectedClose) WillReturnError(err error) *ExpectedClose { e.err = err return e } // String returns string representation func (e *ExpectedClose) String() string { msg := "ExpectedClose => expecting database Close" if e.err != nil { msg += fmt.Sprintf(", which should return error: %s", e.err) } return msg } // ExpectedBegin is used to manage *sql.DB.Begin expectation // returned by *Sqlmock.ExpectBegin. type ExpectedBegin struct { commonExpectation delay time.Duration } // WillReturnError allows to set an error for *sql.DB.Begin action func (e *ExpectedBegin) WillReturnError(err error) *ExpectedBegin { e.err = err return e } // String returns string representation func (e *ExpectedBegin) String() string { msg := "ExpectedBegin => expecting database transaction Begin" if e.err != nil { msg += fmt.Sprintf(", which should return error: %s", e.err) } return msg } // WillDelayFor allows to specify duration for which it will delay // result. May be used together with Context func (e *ExpectedBegin) WillDelayFor(duration time.Duration) *ExpectedBegin { e.delay = duration return e } // ExpectedCommit is used to manage *sql.Tx.Commit expectation // returned by *Sqlmock.ExpectCommit. 
type ExpectedCommit struct {
	commonExpectation
}

// WillReturnError makes the expected *sql.Tx.Commit call return err.
// It returns the receiver so calls can be chained.
func (e *ExpectedCommit) WillReturnError(err error) *ExpectedCommit {
	e.err = err
	return e
}

// String describes the Commit expectation, mentioning the configured error
// when one is set.
func (e *ExpectedCommit) String() string {
	const base = "ExpectedCommit => expecting transaction Commit"
	if e.err == nil {
		return base
	}
	return base + fmt.Sprintf(", which should return error: %s", e.err)
}

// ExpectedRollback is used to manage *sql.Tx.Rollback expectation
// returned by *Sqlmock.ExpectRollback.
type ExpectedRollback struct {
	commonExpectation
}

// WillReturnError makes the expected *sql.Tx.Rollback call return err.
// It returns the receiver so calls can be chained.
func (e *ExpectedRollback) WillReturnError(err error) *ExpectedRollback {
	e.err = err
	return e
}

// String describes the Rollback expectation, mentioning the configured
// error when one is set.
func (e *ExpectedRollback) String() string {
	const base = "ExpectedRollback => expecting transaction Rollback"
	if e.err == nil {
		return base
	}
	return base + fmt.Sprintf(", which should return error: %s", e.err)
}

// ExpectedQuery is used to manage *sql.DB.Query, *sql.DB.QueryRow, *sql.Tx.Query,
// *sql.Tx.QueryRow, *sql.Stmt.Query or *sql.Stmt.QueryRow expectations.
// Returned by *Sqlmock.ExpectQuery.
type ExpectedQuery struct {
	queryBasedExpectation
	rows  driver.Rows
	delay time.Duration
}

// WithArgs sets the arguments the query is expected to receive. If at least
// one argument does not match, matching returns an error. An
// sqlmock.Argument implementation may be supplied for any position to match
// that argument with custom logic.
func (e *ExpectedQuery) WithArgs(args ...driver.Value) *ExpectedQuery { e.args = args return e } // WillReturnError allows to set an error for expected database query func (e *ExpectedQuery) WillReturnError(err error) *ExpectedQuery { e.err = err return e } // WillReturnRows specifies the set of resulting rows that will be returned // by the triggered query func (e *ExpectedQuery) WillReturnRows(rows driver.Rows) *ExpectedQuery { e.rows = rows return e } // WillDelayFor allows to specify duration for which it will delay // result. May be used together with Context func (e *ExpectedQuery) WillDelayFor(duration time.Duration) *ExpectedQuery { e.delay = duration return e } // String returns string representation func (e *ExpectedQuery) String() string { msg := "ExpectedQuery => expecting Query or QueryRow which:" msg += "\n - matches sql: '" + e.sqlRegex.String() + "'" if len(e.args) == 0 { msg += "\n - is without arguments" } else { msg += "\n - is with arguments:\n" for i, arg := range e.args { msg += fmt.Sprintf(" %d - %+v\n", i, arg) } msg = strings.TrimSpace(msg) } if e.rows != nil { msg += "\n - should return rows:\n" rs, _ := e.rows.(*rows) for i, row := range rs.rows { msg += fmt.Sprintf(" %d - %+v\n", i, row) } msg = strings.TrimSpace(msg) } if e.err != nil { msg += fmt.Sprintf("\n - should return error: %s", e.err) } return msg } // ExpectedExec is used to manage *sql.DB.Exec, *sql.Tx.Exec or *sql.Stmt.Exec expectations. // Returned by *Sqlmock.ExpectExec. type ExpectedExec struct { queryBasedExpectation result driver.Result delay time.Duration } // WithArgs will match given expected args to actual database exec operation arguments. // if at least one argument does not match, it will return an error. For specific // arguments an sqlmock.Argument interface can be used to match an argument. 
func (e *ExpectedExec) WithArgs(args ...driver.Value) *ExpectedExec { e.args = args return e } // WillReturnError allows to set an error for expected database exec action func (e *ExpectedExec) WillReturnError(err error) *ExpectedExec { e.err = err return e } // WillDelayFor allows to specify duration for which it will delay // result. May be used together with Context func (e *ExpectedExec) WillDelayFor(duration time.Duration) *ExpectedExec { e.delay = duration return e } // String returns string representation func (e *ExpectedExec) String() string { msg := "ExpectedExec => expecting Exec which:" msg += "\n - matches sql: '" + e.sqlRegex.String() + "'" if len(e.args) == 0 { msg += "\n - is without arguments" } else { msg += "\n - is with arguments:\n" var margs []string for i, arg := range e.args { margs = append(margs, fmt.Sprintf(" %d - %+v", i, arg)) } msg += strings.Join(margs, "\n") } if e.result != nil { res, _ := e.result.(*result) msg += "\n - should return Result having:" msg += fmt.Sprintf("\n LastInsertId: %d", res.insertID) msg += fmt.Sprintf("\n RowsAffected: %d", res.rowsAffected) if res.err != nil { msg += fmt.Sprintf("\n Error: %s", res.err) } } if e.err != nil { msg += fmt.Sprintf("\n - should return error: %s", e.err) } return msg } // WillReturnResult arranges for an expected Exec() to return a particular // result, there is sqlmock.NewResult(lastInsertID int64, affectedRows int64) method // to build a corresponding result. Or if actions needs to be tested against errors // sqlmock.NewErrorResult(err error) to return a given error. func (e *ExpectedExec) WillReturnResult(result driver.Result) *ExpectedExec { e.result = result return e } // ExpectedPrepare is used to manage *sql.DB.Prepare or *sql.Tx.Prepare expectations. // Returned by *Sqlmock.ExpectPrepare. 
type ExpectedPrepare struct { commonExpectation mock *sqlmock sqlRegex *regexp.Regexp statement driver.Stmt closeErr error delay time.Duration } // WillReturnError allows to set an error for the expected *sql.DB.Prepare or *sql.Tx.Prepare action. func (e *ExpectedPrepare) WillReturnError(err error) *ExpectedPrepare { e.err = err return e } // WillReturnCloseError allows to set an error for this prapared statement Close action func (e *ExpectedPrepare) WillReturnCloseError(err error) *ExpectedPrepare { e.closeErr = err return e } // WillDelayFor allows to specify duration for which it will delay // result. May be used together with Context func (e *ExpectedPrepare) WillDelayFor(duration time.Duration) *ExpectedPrepare { e.delay = duration return e } // ExpectQuery allows to expect Query() or QueryRow() on this prepared statement. // this method is convenient in order to prevent duplicating sql query string matching. func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery { eq := &ExpectedQuery{} eq.sqlRegex = e.sqlRegex e.mock.expected = append(e.mock.expected, eq) return eq } // ExpectExec allows to expect Exec() on this prepared statement. // this method is convenient in order to prevent duplicating sql query string matching. 
func (e *ExpectedPrepare) ExpectExec() *ExpectedExec { eq := &ExpectedExec{} eq.sqlRegex = e.sqlRegex e.mock.expected = append(e.mock.expected, eq) return eq } // String returns string representation func (e *ExpectedPrepare) String() string { msg := "ExpectedPrepare => expecting Prepare statement which:" msg += "\n - matches sql: '" + e.sqlRegex.String() + "'" if e.err != nil { msg += fmt.Sprintf("\n - should return error: %s", e.err) } if e.closeErr != nil { msg += fmt.Sprintf("\n - should return error on Close: %s", e.closeErr) } return msg } // query based expectation // adds a query matching logic type queryBasedExpectation struct { commonExpectation sqlRegex *regexp.Regexp args []driver.Value } func (e *queryBasedExpectation) attemptMatch(sql string, args []namedValue) (err error) { if !e.queryMatches(sql) { return fmt.Errorf(`could not match sql: "%s" with expected regexp "%s"`, sql, e.sqlRegex.String()) } // catch panic defer func() { if e := recover(); e != nil { _, ok := e.(error) if !ok { err = fmt.Errorf(e.(string)) } } }() err = e.argsMatches(args) return } func (e *queryBasedExpectation) queryMatches(sql string) bool { return e.sqlRegex.MatchString(sql) } func (e *queryBasedExpectation) argsMatches(args []namedValue) error { if nil == e.args { return nil } if len(args) != len(e.args) { return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args)) } for k, v := range args { // custom argument matcher matcher, ok := e.args[k].(Argument) if ok { // @TODO: does it make sense to pass value instead of named value? 
if !matcher.Match(v.Value) { return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k]) } continue } dval := e.args[k] if named, isNamed := dval.(namedValue); isNamed { dval = named.Value if v.Name != named.Name { return fmt.Errorf("named argument %d: name: \"%s\" does not match expected: \"%s\"", k, v.Name, named.Name) } if v.Ordinal != named.Ordinal { return fmt.Errorf("named argument %d: ordinal position: \"%d\" does not match expected: \"%d\"", k, v.Ordinal, named.Ordinal) } } // convert to driver converter darg, err := driver.DefaultParameterConverter.ConvertValue(dval) if err != nil { return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err) } if !driver.IsValue(darg) { return fmt.Errorf("argument %d: non-subset type %T returned from Value", k, darg) } if !reflect.DeepEqual(darg, v.Value) { return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v.Value, v.Value) } } return nil }