1
0
mirror of https://github.com/volatiletech/authboss.git synced 2024-11-24 08:42:17 +02:00
authboss/mocks/mocks.go

745 lines
18 KiB
Go
Raw Normal View History

2015-03-16 23:42:45 +02:00
// Package mocks defines implemented interfaces for testing modules
package mocks
import (
"context"
2015-08-02 20:51:35 +02:00
"io"
"net/http"
2015-08-02 20:51:35 +02:00
"net/url"
"strings"
2015-02-11 09:03:02 +02:00
"time"
"github.com/pkg/errors"
2017-07-31 04:39:33 +02:00
"github.com/volatiletech/authboss"
)
// User represents all possible fields an authboss User may have.
//
// Getters are declared on the value receiver and setters on the pointer
// receiver, so a *User is needed wherever both are required.
type User struct {
	Username           string
	Email              string
	Password           string
	RecoverSelector    string
	RecoverVerifier    string
	RecoverTokenExpiry time.Time
	ConfirmSelector    string
	ConfirmVerifier    string
	Confirmed          bool
	AttemptCount       int
	LastAttempt        time.Time
	Locked             time.Time
	OAuth2UID          string
	OAuth2Provider     string
	OAuth2Token        string
	OAuth2Refresh      string
	OAuth2Expiry       time.Time
	OTPs               string
	TOTPSecretKey      string
	SMSPhoneNumber     string
	RecoveryCodes      string
	SMSPhoneNumberSeed string
	Arbitrary          map[string]string
}

// GetPID from user; the e-mail address doubles as the primary ID.
func (u User) GetPID() string { return u.Email }

// GetEmail from user.
func (u User) GetEmail() string { return u.Email }

// GetUsername from user.
func (u User) GetUsername() string { return u.Username }

// GetPassword from user.
func (u User) GetPassword() string { return u.Password }

// GetRecoverSelector from user.
func (u User) GetRecoverSelector() string { return u.RecoverSelector }

// GetRecoverVerifier from user.
func (u User) GetRecoverVerifier() string { return u.RecoverVerifier }

// GetRecoverExpiry from user.
func (u User) GetRecoverExpiry() time.Time { return u.RecoverTokenExpiry }

// GetConfirmSelector from user.
func (u User) GetConfirmSelector() string { return u.ConfirmSelector }

// GetConfirmVerifier from user.
func (u User) GetConfirmVerifier() string { return u.ConfirmVerifier }

// GetConfirmed from user.
func (u User) GetConfirmed() bool { return u.Confirmed }

// GetAttemptCount from user.
func (u User) GetAttemptCount() int { return u.AttemptCount }

// GetLastAttempt from user.
func (u User) GetLastAttempt() time.Time { return u.LastAttempt }

// GetLocked from user.
func (u User) GetLocked() time.Time { return u.Locked }

// IsOAuth2User returns true if the user is an oauth2 user, which is decided
// solely by a non-empty OAuth2Provider.
func (u User) IsOAuth2User() bool { return len(u.OAuth2Provider) != 0 }

// GetOAuth2UID from user.
func (u User) GetOAuth2UID() string { return u.OAuth2UID }

// GetOAuth2Provider from user.
func (u User) GetOAuth2Provider() string { return u.OAuth2Provider }

// GetOAuth2AccessToken from user.
func (u User) GetOAuth2AccessToken() string { return u.OAuth2Token }

// GetOAuth2RefreshToken from user.
func (u User) GetOAuth2RefreshToken() string { return u.OAuth2Refresh }

// GetOAuth2Expiry from user.
func (u User) GetOAuth2Expiry() time.Time { return u.OAuth2Expiry }

// GetArbitrary from user. The stored map itself is returned, not a copy.
func (u User) GetArbitrary() map[string]string { return u.Arbitrary }

// GetOTPs from user.
func (u User) GetOTPs() string { return u.OTPs }

// GetTOTPSecretKey from user.
func (u User) GetTOTPSecretKey() string { return u.TOTPSecretKey }

// GetSMSPhoneNumber from user.
func (u User) GetSMSPhoneNumber() string { return u.SMSPhoneNumber }

// GetSMSPhoneNumberSeed from user.
func (u User) GetSMSPhoneNumberSeed() string { return u.SMSPhoneNumberSeed }

// GetRecoveryCodes from user.
func (u User) GetRecoveryCodes() string { return u.RecoveryCodes }

// PutPID into user; this writes the Email field, which GetPID reads.
func (u *User) PutPID(email string) { u.Email = email }

// PutUsername into user.
func (u *User) PutUsername(username string) { u.Username = username }

// PutEmail into user.
func (u *User) PutEmail(email string) { u.Email = email }

// PutPassword into user.
func (u *User) PutPassword(password string) { u.Password = password }

// PutRecoverSelector into user.
func (u *User) PutRecoverSelector(recoverSelector string) { u.RecoverSelector = recoverSelector }

// PutRecoverVerifier into user.
func (u *User) PutRecoverVerifier(recoverVerifier string) { u.RecoverVerifier = recoverVerifier }

// PutRecoverExpiry into user.
func (u *User) PutRecoverExpiry(recoverTokenExpiry time.Time) {
	u.RecoverTokenExpiry = recoverTokenExpiry
}

// PutConfirmSelector into user.
func (u *User) PutConfirmSelector(confirmSelector string) { u.ConfirmSelector = confirmSelector }

// PutConfirmVerifier into user.
func (u *User) PutConfirmVerifier(confirmVerifier string) { u.ConfirmVerifier = confirmVerifier }

// PutConfirmed into user.
func (u *User) PutConfirmed(confirmed bool) { u.Confirmed = confirmed }

// PutAttemptCount into user.
func (u *User) PutAttemptCount(attemptCount int) { u.AttemptCount = attemptCount }

// PutLastAttempt into user.
func (u *User) PutLastAttempt(attemptTime time.Time) { u.LastAttempt = attemptTime }

// PutLocked into user.
func (u *User) PutLocked(locked time.Time) { u.Locked = locked }

// PutOAuth2UID into user.
func (u *User) PutOAuth2UID(uid string) { u.OAuth2UID = uid }

// PutOAuth2Provider into user.
func (u *User) PutOAuth2Provider(provider string) { u.OAuth2Provider = provider }

// PutOAuth2AccessToken into user.
func (u *User) PutOAuth2AccessToken(token string) { u.OAuth2Token = token }

// PutOAuth2RefreshToken into user.
func (u *User) PutOAuth2RefreshToken(refresh string) { u.OAuth2Refresh = refresh }

// PutOAuth2Expiry into user.
func (u *User) PutOAuth2Expiry(expiry time.Time) { u.OAuth2Expiry = expiry }

// PutArbitrary into user. The map is stored as given, not copied.
func (u *User) PutArbitrary(arb map[string]string) { u.Arbitrary = arb }

// PutOTPs into user.
func (u *User) PutOTPs(otps string) { u.OTPs = otps }

// PutTOTPSecretKey into user.
func (u *User) PutTOTPSecretKey(key string) { u.TOTPSecretKey = key }

// PutSMSPhoneNumber into user.
func (u *User) PutSMSPhoneNumber(number string) { u.SMSPhoneNumber = number }

// PutRecoveryCodes into user.
func (u *User) PutRecoveryCodes(codes string) { u.RecoveryCodes = codes }
// ServerStorer should be valid for any module storer defined in authboss.
type ServerStorer struct {
Users map[string]*User
RMTokens map[string][]string
}
// NewServerStorer constructor
func NewServerStorer() *ServerStorer {
return &ServerStorer{
Users: make(map[string]*User),
RMTokens: make(map[string][]string),
}
}
// New constructs a blank user to later be created
func (s *ServerStorer) New(context.Context) authboss.User {
return &User{}
}
// Create a user
func (s *ServerStorer) Create(ctx context.Context, user authboss.User) error {
u := user.(*User)
if _, ok := s.Users[u.Email]; ok {
return authboss.ErrUserFound
}
s.Users[u.Email] = u
return nil
}
2018-02-20 18:58:59 +02:00
// Load a user
func (s *ServerStorer) Load(ctx context.Context, key string) (authboss.User, error) {
user, ok := s.Users[key]
if ok {
return user, nil
}
return nil, authboss.ErrUserNotFound
}
// Save a user
func (s *ServerStorer) Save(ctx context.Context, user authboss.User) error {
u := user.(*User)
if _, ok := s.Users[u.Email]; !ok {
return authboss.ErrUserNotFound
}
2018-02-20 18:58:59 +02:00
s.Users[u.Email] = u
return nil
}
// NewFromOAuth2 finds a user with the given details, or returns a new one
func (s *ServerStorer) NewFromOAuth2(ctx context.Context, provider string, details map[string]string) (authboss.OAuth2User, error) {
uid := details["uid"]
email := details["email"]
name := details["name"]
pid := authboss.MakeOAuth2PID(provider, uid)
u, ok := s.Users[pid]
if ok {
u.Username = name
u.Email = email
return u, nil
}
return &User{
OAuth2UID: uid,
OAuth2Provider: provider,
Email: email,
Username: name,
}, nil
}
// SaveOAuth2 creates a user if not found, or updates one that exists.
func (s *ServerStorer) SaveOAuth2(ctx context.Context, user authboss.OAuth2User) error {
u := user.(*User)
pid := authboss.MakeOAuth2PID(u.OAuth2Provider, u.OAuth2UID)
2018-09-16 00:39:26 +02:00
// Since we don't have to differentiate between
// insert/update in a map, we just overwrite
s.Users[pid] = u
return nil
}
// LoadByConfirmSelector finds a user by his confirm selector
func (s *ServerStorer) LoadByConfirmSelector(ctx context.Context, selector string) (authboss.ConfirmableUser, error) {
for _, v := range s.Users {
if v.ConfirmSelector == selector {
return v, nil
}
}
return nil, authboss.ErrUserNotFound
}
// LoadByRecoverSelector finds a user by his recover token
func (s *ServerStorer) LoadByRecoverSelector(ctx context.Context, selector string) (authboss.RecoverableUser, error) {
for _, v := range s.Users {
if v.RecoverSelector == selector {
return v, nil
}
}
return nil, authboss.ErrUserNotFound
}
// AddRememberToken for remember me
func (s *ServerStorer) AddRememberToken(ctx context.Context, key, token string) error {
arr := s.RMTokens[key]
s.RMTokens[key] = append(arr, token)
return nil
}
// DelRememberTokens for a user
func (s *ServerStorer) DelRememberTokens(ctx context.Context, key string) error {
delete(s.RMTokens, key)
return nil
}
// UseRememberToken if it exists, deleting it in the process
func (s *ServerStorer) UseRememberToken(ctx context.Context, givenKey, token string) (err error) {
2018-03-09 23:11:08 +02:00
arr, ok := s.RMTokens[givenKey]
if !ok {
return authboss.ErrTokenNotFound
}
for i, tok := range arr {
if tok == token {
if len(arr) == 1 {
delete(s.RMTokens, givenKey)
return nil
}
2018-03-09 23:11:08 +02:00
arr[i] = arr[len(arr)-1]
s.RMTokens[givenKey] = arr[:len(arr)-2]
return nil
}
}
return authboss.ErrTokenNotFound
}
2018-09-16 00:39:26 +02:00
// FailStorer is used for testing module initialize functions that
// recover more than the base storer
type FailStorer struct {
User
}
2015-03-16 23:42:45 +02:00
// Create fails
func (FailStorer) Create(context.Context) error {
return errors.New("fail storer: create")
}
2015-03-16 23:42:45 +02:00
// Save fails
func (FailStorer) Save(context.Context) error {
return errors.New("fail storer: put")
}
2015-03-16 23:42:45 +02:00
// Load fails
func (FailStorer) Load(context.Context) error {
return errors.New("fail storer: get")
}
// ClientState is used for testing the client stores on context.
type ClientState struct {
	// Values is the backing key/value storage.
	Values map[string]string
	// GetShouldFail makes Get report every key as missing.
	GetShouldFail bool
}

// NewClientState constructs a ClientState from alternating key, value
// arguments. It panics when an odd number of arguments is supplied.
func NewClientState(data ...string) *ClientState {
	if len(data)%2 != 0 {
		panic("It should be a key value list of arguments.")
	}

	values := make(map[string]string, len(data)/2)
	for i := 0; i+1 < len(data); i += 2 {
		values[data[i]] = data[i+1]
	}

	return &ClientState{Values: values}
}

// Get a key's value. The second result is false when the key is absent or
// when GetShouldFail is set.
func (m *ClientState) Get(key string) (string, bool) {
	if m.GetShouldFail {
		return "", false
	}

	v, ok := m.Values[key]
	return v, ok
}

// Put a value. The map is allocated on first use so that a zero-value
// ClientState can be written to.
func (m *ClientState) Put(key, val string) {
	if m.Values == nil {
		m.Values = make(map[string]string)
	}
	m.Values[key] = val
}

// Del a key/value pair. Deleting a missing key is a no-op.
func (m *ClientState) Del(key string) { delete(m.Values, key) }
// ClientStateRW stores things that would originally
// go in a session, or a map, in memory!
type ClientStateRW struct {
ClientValues map[string]string
}
// NewClientRW takes the data from a client state
// and returns.
func NewClientRW() *ClientStateRW {
return &ClientStateRW{
ClientValues: make(map[string]string),
}
}
2015-03-16 23:42:45 +02:00
// ReadState from memory
2018-03-08 02:41:58 +02:00
func (c *ClientStateRW) ReadState(*http.Request) (authboss.ClientState, error) {
return &ClientState{Values: c.ClientValues}, nil
}
2015-03-16 23:42:45 +02:00
// WriteState to memory
func (c *ClientStateRW) WriteState(w http.ResponseWriter, cstate authboss.ClientState, cse []authboss.ClientStateEvent) error {
for _, e := range cse {
switch e.Kind {
case authboss.ClientStateEventPut:
c.ClientValues[e.Key] = e.Value
case authboss.ClientStateEventDel:
delete(c.ClientValues, e.Key)
}
}
return nil
}
// Request returns a new request to http://localhost with optional key-value
// data.
//
// postKeyValues is a flat list of alternating keys and values. For POST and
// PUT the pairs are form-encoded into the body; for any other method they are
// appended to the URL as a query string. Whenever pairs are supplied the
// Content-Type header is set to application/x-www-form-urlencoded.
//
// Request panics if postKeyValues has an odd length or if the request cannot
// be constructed.
func Request(method string, postKeyValues ...string) *http.Request {
	// Fail with a readable message instead of an index-out-of-range panic
	// from inside the pairing loop.
	if len(postKeyValues)%2 != 0 {
		panic("mocks: Request needs an even number of key/value arguments")
	}

	var body io.Reader
	location := "http://localhost"

	if len(postKeyValues) > 0 {
		urlValues := make(url.Values, len(postKeyValues)/2)
		for i := 0; i < len(postKeyValues); i += 2 {
			urlValues.Set(postKeyValues[i], postKeyValues[i+1])
		}

		encoded := urlValues.Encode()
		if method == "POST" || method == "PUT" {
			body = strings.NewReader(encoded)
		} else {
			location += "?" + encoded
		}
	}

	req, err := http.NewRequest(method, location, body)
	if err != nil {
		panic(err.Error())
	}

	if len(postKeyValues) > 0 {
		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	}

	return req
}
2015-02-09 09:08:33 +02:00
// Mailer helps simplify mailer testing by storing the last sent email
type Mailer struct {
Last authboss.Email
SendErr string
}
// NewMailer constructs a mailer
func NewMailer() *Mailer {
return &Mailer{}
2015-02-09 09:08:33 +02:00
}
2015-03-16 23:42:45 +02:00
// Send an e-mail
func (m *Mailer) Send(ctx context.Context, email authboss.Email) error {
if len(m.SendErr) > 0 {
return errors.New(m.SendErr)
}
2015-02-09 09:08:33 +02:00
m.Last = email
return nil
}
2015-02-25 01:01:56 +02:00
// AfterCallback is a callback that knows if it was called
type AfterCallback struct {
2015-02-25 01:01:56 +02:00
HasBeenCalled bool
Fn authboss.EventHandler
2015-02-25 01:01:56 +02:00
}
// NewAfterCallback constructs a new aftercallback.
func NewAfterCallback() *AfterCallback {
m := AfterCallback{}
2015-02-25 01:01:56 +02:00
m.Fn = func(http.ResponseWriter, *http.Request, bool) (bool, error) {
2015-02-25 01:01:56 +02:00
m.HasBeenCalled = true
return false, nil
2015-02-25 01:01:56 +02:00
}
return &m
}
2018-02-20 18:58:59 +02:00
// Renderer is a mock renderer that remembers which pages were loaded and the
// arguments of the latest Render call.
type Renderer struct {
	// Pages accumulates every page name passed to Load.
	Pages []string

	// Arguments captured by the most recent Render call.
	Context context.Context
	Page    string
	Data    authboss.HTMLData
}

// HasLoadedViews checks that the loaded pages equal pages exactly, in the
// same order, and describes the first difference otherwise.
func (r *Renderer) HasLoadedViews(pages ...string) error {
	wantCount, gotCount := len(pages), len(r.Pages)
	if wantCount != gotCount {
		return errors.Errorf("want: %d loaded views, got: %d", wantCount, gotCount)
	}

	for idx := range pages {
		if expected, actual := pages[idx], r.Pages[idx]; expected != actual {
			return errors.Errorf("want: %s [%d], got: %s", expected, idx, actual)
		}
	}

	return nil
}

// Load loads nothing; it only appends pages to the list of loaded pages.
func (r *Renderer) Load(pages ...string) error {
	loaded := append(r.Pages, pages...)
	r.Pages = loaded
	return nil
}

// Render renders nothing; it records its arguments on the renderer and
// reports a nil body with the "text/html" mime type.
func (r *Renderer) Render(ctx context.Context, page string, data authboss.HTMLData) ([]byte, string, error) {
	r.Context, r.Page, r.Data = ctx, page, data
	return nil, "text/html", nil
}
// Responder is a mock that remembers how the latest request was answered.
type Responder struct {
	Status int
	Page   string
	Data   authboss.HTMLData
}

// Respond captures code, page and data on the Responder. Nothing is written
// to w and req is not inspected.
func (r *Responder) Respond(w http.ResponseWriter, req *http.Request, code int, page string, data authboss.HTMLData) error {
	r.Status, r.Page, r.Data = code, page, data
	return nil
}
// Redirector stores the redirect options passed to it and writes the Code
// to the ResponseWriter.
type Redirector struct {
Options authboss.RedirectOptions
}
// Redirect a request
func (r *Redirector) Redirect(w http.ResponseWriter, req *http.Request, ro authboss.RedirectOptions) error {
r.Options = ro
if len(ro.RedirectPath) == 0 {
panic("no redirect path on redirect call")
}
2018-02-20 18:58:59 +02:00
http.Redirect(w, req, ro.RedirectPath, ro.Code)
return nil
}
// Emailer is a mock mailer that keeps the e-mail it was last asked to send.
type Emailer struct {
	Email authboss.Email
}

// Send records email on the Emailer; nothing is delivered and the call
// always succeeds.
func (e *Emailer) Send(ctx context.Context, email authboss.Email) error {
	*e = Emailer{Email: email}
	return nil
}
2018-02-20 18:58:59 +02:00
// BodyReader reads the body of a request and returns some values
type BodyReader struct {
Return authboss.Validator
2018-02-20 18:58:59 +02:00
}
// Read the return values
func (b BodyReader) Read(page string, r *http.Request) (authboss.Validator, error) {
return b.Return, nil
}
// Values is returned from the BodyReader. Every getter returns the
// corresponding field unchanged.
type Values struct {
	PID         string
	Password    string
	Token       string
	Code        string
	Recovery    string
	PhoneNumber string
	Remember    bool

	// Errors is what Validate reports.
	Errors []error
}

// GetPID from values.
func (v Values) GetPID() string {
	return v.PID
}

// GetPassword from values.
func (v Values) GetPassword() string {
	return v.Password
}

// GetToken from values.
func (v Values) GetToken() string {
	return v.Token
}

// GetCode from values.
func (v Values) GetCode() string {
	return v.Code
}

// GetPhoneNumber from values.
func (v Values) GetPhoneNumber() string {
	return v.PhoneNumber
}

// GetRecoveryCode from values; it reads the Recovery field.
func (v Values) GetRecoveryCode() string {
	return v.Recovery
}

// GetShouldRemember gets the value that tells
// the remember module if it should remember the user.
func (v Values) GetShouldRemember() bool {
	return v.Remember
}

// Validate the values by returning the preconfigured Errors slice.
func (v Values) Validate() []error {
	return v.Errors
}
// ArbValues is arbitrary value storage.
type ArbValues struct {
	Values map[string]string
	// Errors is what Validate reports.
	Errors []error
}

// GetPID gets the pid, which is stored under the "email" key.
func (a ArbValues) GetPID() string {
	return a.Values["email"]
}

// GetPassword gets the password, which is stored under the "password" key.
func (a ArbValues) GetPassword() string {
	return a.Values["password"]
}

// GetValues returns all values; the underlying map is returned, not a copy.
func (a ArbValues) GetValues() map[string]string {
	return a.Values
}

// Validate performs no checks and returns the preconfigured Errors slice.
func (a ArbValues) Validate() []error {
	return a.Errors
}
// Logger is a mock logger that discards everything it is given.
type Logger struct{}

// Info discards the message.
func (Logger) Info(_ string) {
}

// Error discards the message.
func (Logger) Error(_ string) {
}
// Router records the routes that were registered.
type Router struct {
	Gets    []string
	Posts   []string
	Deletes []string
}

// ServeHTTP does nothing.
func (Router) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}

// Get records the path in the router; the handler is discarded.
func (r *Router) Get(path string, _ http.Handler) {
	r.Gets = append(r.Gets, path)
}

// Post records the path in the router; the handler is discarded.
func (r *Router) Post(path string, _ http.Handler) {
	r.Posts = append(r.Posts, path)
}

// Delete records the path in the router; the handler is discarded.
func (r *Router) Delete(path string, _ http.Handler) {
	r.Deletes = append(r.Deletes, path)
}

// HasGets ensures the recorded GET routes equal gets exactly, in order.
func (r *Router) HasGets(gets ...string) error {
	return r.hasRoutes(gets, r.Gets)
}

// HasPosts ensures the recorded POST routes equal posts exactly, in order.
func (r *Router) HasPosts(posts ...string) error {
	return r.hasRoutes(posts, r.Posts)
}

// HasDeletes ensures the recorded DELETE routes equal deletes exactly, in
// order.
func (r *Router) HasDeletes(deletes ...string) error {
	return r.hasRoutes(deletes, r.Deletes)
}

// hasRoutes compares the wanted routes against the recorded ones: first by
// count, then element by element, returning an error for the first
// difference. It serves every verb, so its messages name no specific method.
func (r *Router) hasRoutes(want []string, got []string) error {
	if len(got) != len(want) {
		return errors.Errorf("want: %d routes, got: %d", len(want), len(got))
	}

	for i, w := range want {
		g := got[i]
		if w != g {
			return errors.Errorf("wanted route: %s [%d], but got: %s", w, i, g)
		}
	}

	return nil
}
// ErrorHandler is a mock error handler that remembers the most recent
// non-nil error produced by a wrapped handler.
type ErrorHandler struct {
	Error error
}

// Wrap adapts an error-returning handler into an http.Handler. A non-nil
// result is stored in e.Error; a nil result leaves the stored error as it was.
func (e *ErrorHandler) Wrap(handler func(w http.ResponseWriter, r *http.Request) error) http.Handler {
	record := func(w http.ResponseWriter, r *http.Request) {
		err := handler(w, r)
		if err == nil {
			return
		}
		e.Error = err
	}

	return http.HandlerFunc(record)
}