diff --git a/stubs_test.go b/stubs_test.go index 71a8ce3..ebbb749 100644 --- a/stubs_test.go +++ b/stubs_test.go @@ -19,25 +19,18 @@ type NullInt struct { } // Satisfy sql.Scanner interface -func (ni *NullInt) Scan(value interface{}) (err error) { - if value == nil { - ni.Integer, ni.Valid = 0, false - return - } - +func (ni *NullInt) Scan(value interface{}) error { switch v := value.(type) { + case nil: + ni.Integer, ni.Valid = 0, false case int: ni.Integer, ni.Valid = v, true - return case int8: ni.Integer, ni.Valid = int(v), true - return case int16: ni.Integer, ni.Valid = int(v), true - return case int32: ni.Integer, ni.Valid = int(v), true - return case int64: const maxUint = ^uint(0) const minUint = 0 @@ -48,25 +41,23 @@ func (ni *NullInt) Scan(value interface{}) (err error) { return errors.New("value out of int range") } ni.Integer, ni.Valid = int(v), true - return case []byte: n, err := strconv.Atoi(string(v)) if err != nil { return err } ni.Integer, ni.Valid = n, true - return nil case string: n, err := strconv.Atoi(v) if err != nil { return err } ni.Integer, ni.Valid = n, true - return nil + default: + ni.Valid = false + return fmt.Errorf("Can't convert %T to integer", value) } - - ni.Valid = false - return fmt.Errorf("Can't convert %T to integer", value) + return nil } // Satisfy sql.Valuer interface. @@ -78,20 +69,17 @@ func (ni NullInt) Value() (driver.Value, error) { } // Satisfy sql.Scanner interface -func (nt *NullTime) Scan(value interface{}) (err error) { - if value == nil { - nt.Time, nt.Valid = time.Time{}, false - return - } - +func (nt *NullTime) Scan(value interface{}) error { switch v := value.(type) { + case nil: + nt.Time, nt.Valid = time.Time{}, false case time.Time: nt.Time, nt.Valid = v, true - return + default: + nt.Valid = false + return fmt.Errorf("Can't convert %T to time.Time", value) } - - nt.Valid = false - return fmt.Errorf("Can't convert %T to time.Time", value) + return nil } // Satisfy sql.Valuer interface.