mirror of
https://github.com/volatiletech/authboss.git
synced 2024-11-24 08:42:17 +02:00
Fully re-implement recover
- Add back the feature to log in after password recovery - Add new storer functionality to mocks - Add RecoveringServerStorer - Add RecoverableUser - Add RecoverStartValuer, RecoverMiddleValuer, RecoverEndValuer - Change storers to differentiate between tokens (recover vs confirm) - Change BCryptCost to be a generic module configuration (doesn't belong to register)
This commit is contained in:
parent
9ef2a06dcb
commit
0eff53792f
15
config.go
15
config.go
@ -37,6 +37,9 @@ type Config struct {
|
||||
}
|
||||
|
||||
Modules struct {
|
||||
// BCryptCost is the cost of the bcrypt password hashing function.
|
||||
BCryptCost int
|
||||
|
||||
// AuthLogoutMethod is the method the logout route should use (default should be DELETE)
|
||||
AuthLogoutMethod string
|
||||
|
||||
@ -51,8 +54,6 @@ type Config struct {
|
||||
// LockDuration is how long an account is locked for.
|
||||
LockDuration time.Duration
|
||||
|
||||
// RegBCryptCost is the cost of the bcrypt password hashing function.
|
||||
RegisterBCryptCost int
|
||||
// RegisterPreserveFields are fields used with registration that are to be rendered when
|
||||
// post fails in a normal way (for example validation errors), they will be passed
|
||||
// back in the data of the response under the key DataPreserve which will be a map[string]string.
|
||||
@ -67,6 +68,10 @@ type Config struct {
|
||||
// RecoverTokenDuration controls how long a token sent via email for password
|
||||
// recovery is valid for.
|
||||
RecoverTokenDuration time.Duration
|
||||
// RecoverLoginAfterRecovery says for the recovery module after a user has successfully
|
||||
// recovered the password, are they simply logged in, or are they redirected to
|
||||
// the login page with an "updated password" message.
|
||||
RecoverLoginAfterRecovery bool
|
||||
|
||||
// OAuth2Providers lists all providers that can be used. See
|
||||
// OAuthProvider documentation for more details.
|
||||
@ -134,18 +139,20 @@ type Config struct {
|
||||
|
||||
// Defaults sets the configuration's default values.
|
||||
func (c *Config) Defaults() {
|
||||
c.Paths.Mount = "/"
|
||||
c.Paths.Mount = "/auth"
|
||||
c.Paths.RootURL = "http://localhost:8080"
|
||||
c.Paths.AuthLoginOK = "/"
|
||||
c.Paths.AuthLogoutOK = "/"
|
||||
c.Paths.ConfirmOK = "/"
|
||||
c.Paths.ConfirmNotOK = "/"
|
||||
c.Paths.RecoverOK = "/"
|
||||
c.Paths.RegisterOK = "/"
|
||||
|
||||
c.Modules.BCryptCost = bcrypt.DefaultCost
|
||||
c.Modules.AuthLogoutMethod = "DELETE"
|
||||
c.Modules.ExpireAfter = 60 * time.Minute
|
||||
c.Modules.LockAfter = 3
|
||||
c.Modules.LockWindow = 5 * time.Minute
|
||||
c.Modules.LockDuration = 5 * time.Hour
|
||||
c.Modules.RecoverTokenDuration = time.Duration(24) * time.Hour
|
||||
c.Modules.RegisterBCryptCost = bcrypt.DefaultCost
|
||||
}
|
||||
|
@ -194,7 +194,7 @@ func (c *Confirm) Get(w http.ResponseWriter, r *http.Request) error {
|
||||
token := base64.StdEncoding.EncodeToString(sum[:])
|
||||
|
||||
storer := authboss.EnsureCanConfirm(c.Authboss.Config.Storage.Server)
|
||||
user, err := storer.LoadByToken(r.Context(), token)
|
||||
user, err := storer.LoadByConfirmToken(r.Context(), token)
|
||||
if err == authboss.ErrUserNotFound {
|
||||
logger.Infof("confirm token was not found in database: %s", token)
|
||||
ro := authboss.RedirectOptions{
|
||||
|
@ -139,13 +139,13 @@ func TestPreventDisallow(t *testing.T) {
|
||||
func TestStartConfirmationWeb(t *testing.T) {
|
||||
// no t.Parallel(), global var mangling
|
||||
|
||||
oldConfirm := goConfirmEmail
|
||||
oldConfirmEmail := goConfirmEmail
|
||||
goConfirmEmail = func(c *Confirm, ctx context.Context, to, token string) {
|
||||
c.SendConfirmEmail(ctx, to, token)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
goConfirmEmail = oldConfirm
|
||||
goConfirmEmail = oldConfirmEmail
|
||||
}()
|
||||
|
||||
harness := testSetup()
|
||||
|
@ -32,28 +32,28 @@ type User struct {
|
||||
Arbitrary map[string]string
|
||||
}
|
||||
|
||||
func (m User) GetPID() string { return m.Email }
|
||||
func (m User) GetEmail() string { return m.Email }
|
||||
func (m User) GetUsername() string { return m.Username }
|
||||
func (m User) GetPassword() string { return m.Password }
|
||||
func (m User) GetRecoverToken() string { return m.RecoverToken }
|
||||
func (m User) GetRecoverTokenExpiry() time.Time { return m.RecoverTokenExpiry }
|
||||
func (m User) GetConfirmToken() string { return m.ConfirmToken }
|
||||
func (m User) GetConfirmed() bool { return m.Confirmed }
|
||||
func (m User) GetAttemptCount() int { return m.AttemptCount }
|
||||
func (m User) GetLastAttempt() time.Time { return m.LastAttempt }
|
||||
func (m User) GetLocked() time.Time { return m.Locked }
|
||||
func (m User) GetOAuthToken() string { return m.OAuthToken }
|
||||
func (m User) GetOAuthRefresh() string { return m.OAuthRefresh }
|
||||
func (m User) GetOAuthExpiry() time.Time { return m.OAuthExpiry }
|
||||
func (m User) GetArbitrary() map[string]string { return m.Arbitrary }
|
||||
func (m User) GetPID() string { return m.Email }
|
||||
func (m User) GetEmail() string { return m.Email }
|
||||
func (m User) GetUsername() string { return m.Username }
|
||||
func (m User) GetPassword() string { return m.Password }
|
||||
func (m User) GetRecoverToken() string { return m.RecoverToken }
|
||||
func (m User) GetRecoverExpiry() time.Time { return m.RecoverTokenExpiry }
|
||||
func (m User) GetConfirmToken() string { return m.ConfirmToken }
|
||||
func (m User) GetConfirmed() bool { return m.Confirmed }
|
||||
func (m User) GetAttemptCount() int { return m.AttemptCount }
|
||||
func (m User) GetLastAttempt() time.Time { return m.LastAttempt }
|
||||
func (m User) GetLocked() time.Time { return m.Locked }
|
||||
func (m User) GetOAuthToken() string { return m.OAuthToken }
|
||||
func (m User) GetOAuthRefresh() string { return m.OAuthRefresh }
|
||||
func (m User) GetOAuthExpiry() time.Time { return m.OAuthExpiry }
|
||||
func (m User) GetArbitrary() map[string]string { return m.Arbitrary }
|
||||
|
||||
func (m *User) PutPID(email string) { m.Email = email }
|
||||
func (m *User) PutUsername(username string) { m.Username = username }
|
||||
func (m *User) PutEmail(email string) { m.Email = email }
|
||||
func (m *User) PutPassword(password string) { m.Password = password }
|
||||
func (m *User) PutRecoverToken(recoverToken string) { m.RecoverToken = recoverToken }
|
||||
func (m *User) PutRecoverTokenExpiry(recoverTokenExpiry time.Time) {
|
||||
func (m *User) PutRecoverExpiry(recoverTokenExpiry time.Time) {
|
||||
m.RecoverTokenExpiry = recoverTokenExpiry
|
||||
}
|
||||
func (m *User) PutConfirmToken(confirmToken string) { m.ConfirmToken = confirmToken }
|
||||
@ -115,8 +115,8 @@ func (s *ServerStorer) Save(ctx context.Context, user authboss.User) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadByToken finds a user by his token
|
||||
func (s *ServerStorer) LoadByToken(ctx context.Context, token string) (authboss.ConfirmableUser, error) {
|
||||
// LoadByConfirmToken finds a user by his confirm token
|
||||
func (s *ServerStorer) LoadByConfirmToken(ctx context.Context, token string) (authboss.ConfirmableUser, error) {
|
||||
for _, v := range s.Users {
|
||||
if v.ConfirmToken == token {
|
||||
return v, nil
|
||||
@ -126,6 +126,17 @@ func (s *ServerStorer) LoadByToken(ctx context.Context, token string) (authboss.
|
||||
return nil, authboss.ErrUserNotFound
|
||||
}
|
||||
|
||||
// LoadByRecoverToken finds a user by his recover token
|
||||
func (s *ServerStorer) LoadByRecoverToken(ctx context.Context, token string) (authboss.RecoverableUser, error) {
|
||||
for _, v := range s.Users {
|
||||
if v.RecoverToken == token {
|
||||
return v, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, authboss.ErrUserNotFound
|
||||
}
|
||||
|
||||
/*
|
||||
// TODO(aarondl): What is this?
|
||||
// AddToken for remember me
|
||||
|
@ -2,59 +2,41 @@
|
||||
package recover
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"github.com/volatiletech/authboss"
|
||||
"github.com/volatiletech/authboss/internal/response"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// Storage constants
|
||||
// Constants for templates etc.
|
||||
const (
|
||||
StoreRecoverToken = "recover_token"
|
||||
StoreRecoverTokenExpiry = "recover_token_expiry"
|
||||
)
|
||||
DataRecoverToken = "recover_token"
|
||||
DataRecoverURL = "recover_url"
|
||||
|
||||
const (
|
||||
formValueToken = "token"
|
||||
)
|
||||
FormValueToken = "token"
|
||||
|
||||
const (
|
||||
methodGET = "GET"
|
||||
methodPOST = "POST"
|
||||
EmailRecoverHTML = "recover_html"
|
||||
EmailRecoverTxt = "recover_txt"
|
||||
|
||||
tplLogin = "login.html.tpl"
|
||||
tplRecover = "recover.html.tpl"
|
||||
tplRecoverComplete = "recover_complete.html.tpl"
|
||||
tplInitHTMLEmail = "recover_email.html.tpl"
|
||||
tplInitTextEmail = "recover_email.txt.tpl"
|
||||
PageRecoverStart = "recover_start"
|
||||
PageRecoverMiddle = "recover_middle"
|
||||
PageRecoverEnd = "recover_end"
|
||||
|
||||
recoverInitiateSuccessFlash = "An email has been sent with further instructions on how to reset your password"
|
||||
recoverInitiateSuccessFlash = "An email has been sent to you with further instructions on how to reset your password."
|
||||
recoverTokenExpiredFlash = "Account recovery request has expired. Please try again."
|
||||
recoverFailedErrorFlash = "Account recovery has failed. Please contact tech support."
|
||||
)
|
||||
|
||||
var errRecoveryTokenExpired = errors.New("recovery token expired")
|
||||
|
||||
// RecoverStorer must be implemented in order to satisfy the recover module's
|
||||
// storage requirements.
|
||||
type RecoverStorer interface {
|
||||
authboss.Storer
|
||||
// RecoverUser looks a user up by a recover token. See recover module for
|
||||
// attribute names. If the key is not found in the data store,
|
||||
// simply return nil, ErrUserNotFound.
|
||||
RecoverUser(recoverToken string) (interface{}, error)
|
||||
}
|
||||
|
||||
func init() {
|
||||
m := &Recover{}
|
||||
authboss.RegisterModule("recover", m)
|
||||
@ -63,255 +45,208 @@ func init() {
|
||||
// Recover module
|
||||
type Recover struct {
|
||||
*authboss.Authboss
|
||||
templates response.Templates
|
||||
emailHTMLTemplates response.Templates
|
||||
emailTextTemplates response.Templates
|
||||
}
|
||||
|
||||
// Initialize module
|
||||
func (r *Recover) Initialize(ab *authboss.Authboss) (err error) {
|
||||
// Init module
|
||||
func (r *Recover) Init(ab *authboss.Authboss) (err error) {
|
||||
r.Authboss = ab
|
||||
|
||||
if r.Storer != nil {
|
||||
if _, ok := r.Storer.(RecoverStorer); !ok {
|
||||
return errors.New("recoverStorer required for recover functionality")
|
||||
}
|
||||
} else if r.StoreMaker == nil {
|
||||
return errors.New("need a RecoverStorer")
|
||||
}
|
||||
|
||||
if len(r.XSRFName) == 0 {
|
||||
return errors.New("xsrfName must be set")
|
||||
}
|
||||
|
||||
if r.XSRFMaker == nil {
|
||||
return errors.New("xsrfMaker must be defined")
|
||||
}
|
||||
|
||||
r.templates, err = response.LoadTemplates(r.Authboss, r.Layout, r.ViewsPath, tplRecover, tplRecoverComplete)
|
||||
if err != nil {
|
||||
if err := r.Authboss.Config.Core.ViewRenderer.Load(PageRecoverStart, PageRecoverEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.emailHTMLTemplates, err = response.LoadTemplates(r.Authboss, r.LayoutHTMLEmail, r.ViewsPath, tplInitHTMLEmail)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.emailTextTemplates, err = response.LoadTemplates(r.Authboss, r.LayoutTextEmail, r.ViewsPath, tplInitTextEmail)
|
||||
if err != nil {
|
||||
if err := r.Authboss.Config.Core.MailRenderer.Load(EmailRecoverHTML, EmailRecoverTxt); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.Authboss.Config.Core.Router.Get("/recover", r.Core.ErrorHandler.Wrap(r.StartGet))
|
||||
r.Authboss.Config.Core.Router.Post("/recover", r.Core.ErrorHandler.Wrap(r.StartPost))
|
||||
r.Authboss.Config.Core.Router.Get("/recover/end", r.Core.ErrorHandler.Wrap(r.EndGet))
|
||||
r.Authboss.Config.Core.Router.Post("/recover/end", r.Core.ErrorHandler.Wrap(r.EndPost))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Routes for module
|
||||
func (r *Recover) Routes() authboss.RouteTable {
|
||||
return authboss.RouteTable{
|
||||
"/recover": r.startHandlerFunc,
|
||||
"/recover/complete": r.completeHandlerFunc,
|
||||
}
|
||||
// StartGet starts the recover procedure by rendering a form for the user.
|
||||
func (r *Recover) StartGet(w http.ResponseWriter, req *http.Request) error {
|
||||
return r.Authboss.Config.Core.Responder.Respond(w, req, http.StatusOK, PageRecoverStart, nil)
|
||||
}
|
||||
|
||||
// Storage requirements
|
||||
func (r *Recover) Storage() authboss.StorageOptions {
|
||||
return authboss.StorageOptions{
|
||||
r.PrimaryID: authboss.String,
|
||||
authboss.StoreEmail: authboss.String,
|
||||
authboss.StorePassword: authboss.String,
|
||||
StoreRecoverToken: authboss.String,
|
||||
StoreRecoverTokenExpiry: authboss.String,
|
||||
}
|
||||
}
|
||||
// StartPost starts the recover procedure using values provided from the user
|
||||
// usually from the StartGet's form.
|
||||
func (r *Recover) StartPost(w http.ResponseWriter, req *http.Request) error {
|
||||
logger := r.RequestLogger(req)
|
||||
|
||||
func (rec *Recover) startHandlerFunc(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error {
|
||||
switch r.Method {
|
||||
case methodGET:
|
||||
data := authboss.NewHTMLData(
|
||||
"primaryID", rec.PrimaryID,
|
||||
"primaryIDValue", "",
|
||||
"confirmPrimaryIDValue", "",
|
||||
)
|
||||
|
||||
return rec.templates.Render(ctx, w, r, tplRecover, data)
|
||||
case methodPOST:
|
||||
primaryID := r.FormValue(rec.PrimaryID)
|
||||
confirmPrimaryID := r.FormValue(fmt.Sprintf("confirm_%s", rec.PrimaryID))
|
||||
|
||||
errData := authboss.NewHTMLData(
|
||||
"primaryID", rec.PrimaryID,
|
||||
"primaryIDValue", primaryID,
|
||||
"confirmPrimaryIDValue", confirmPrimaryID,
|
||||
)
|
||||
|
||||
policies := authboss.FilterValidators(rec.Policies, rec.PrimaryID)
|
||||
if validationErrs := authboss.Validate(r, policies, rec.PrimaryID, authboss.ConfirmPrefix+rec.PrimaryID).Map(); len(validationErrs) > 0 {
|
||||
errData.MergeKV("errs", validationErrs)
|
||||
return rec.templates.Render(ctx, w, r, tplRecover, errData)
|
||||
}
|
||||
|
||||
// redirect to login when user not found to prevent username sniffing
|
||||
if err := ctx.LoadUser(primaryID); err == authboss.ErrUserNotFound {
|
||||
return authboss.ErrAndRedirect{Err: err, Location: rec.RecoverOKPath, FlashSuccess: recoverInitiateSuccessFlash}
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
email, err := ctx.User.StringErr(authboss.StoreEmail)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
encodedToken, encodedChecksum, err := newToken()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx.User[StoreRecoverToken] = encodedChecksum
|
||||
ctx.User[StoreRecoverTokenExpiry] = time.Now().Add(rec.RecoverTokenDuration)
|
||||
|
||||
if err := ctx.SaveUser(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
goRecoverEmail(rec, ctx, email, encodedToken)
|
||||
|
||||
ctx.SessionStorer.Put(authboss.FlashSuccessKey, recoverInitiateSuccessFlash)
|
||||
response.Redirect(ctx, w, r, rec.RecoverOKPath, "", "", true)
|
||||
default:
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
validatable, err := r.Authboss.Core.BodyReader.Read(PageRecoverStart, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
recoverVals := authboss.MustHaveRecoverStartValues(validatable)
|
||||
|
||||
func newToken() (encodedToken, encodedChecksum string, err error) {
|
||||
token := make([]byte, 32)
|
||||
if _, err = rand.Read(token); err != nil {
|
||||
return "", "", err
|
||||
user, err := r.Authboss.Storage.Server.Load(req.Context(), recoverVals.GetPID())
|
||||
if err == authboss.ErrUserNotFound {
|
||||
logger.Infof("user %s was attempted to be recovered, user does not exist, faking successful response", recoverVals.GetPID())
|
||||
ro := authboss.RedirectOptions{
|
||||
Code: http.StatusTemporaryRedirect,
|
||||
RedirectPath: r.Authboss.Config.Paths.RecoverOK,
|
||||
Success: recoverInitiateSuccessFlash,
|
||||
}
|
||||
return r.Authboss.Core.Redirector.Redirect(w, req, ro)
|
||||
}
|
||||
sum := md5.Sum(token)
|
||||
|
||||
return base64.URLEncoding.EncodeToString(token), base64.StdEncoding.EncodeToString(sum[:]), nil
|
||||
}
|
||||
ru := authboss.MustBeRecoverable(user)
|
||||
|
||||
var goRecoverEmail = func(r *Recover, ctx *authboss.Context, to, encodedToken string) {
|
||||
if ctx.MailMaker != nil {
|
||||
r.sendRecoverEmail(ctx, to, encodedToken)
|
||||
} else {
|
||||
go r.sendRecoverEmail(ctx, to, encodedToken)
|
||||
hash, token, err := GenerateToken()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ru.PutRecoverToken(hash)
|
||||
ru.PutRecoverExpiry(time.Now().UTC().Add(r.Config.Modules.RecoverTokenDuration))
|
||||
|
||||
if err := r.Authboss.Storage.Server.Save(req.Context(), ru); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
goRecoverEmail(r, req.Context(), ru.GetEmail(), token)
|
||||
|
||||
logger.Infof("user %s password recovery initiated", ru.GetPID())
|
||||
ro := authboss.RedirectOptions{
|
||||
Code: http.StatusTemporaryRedirect,
|
||||
RedirectPath: r.Authboss.Config.Paths.RecoverOK,
|
||||
Success: recoverInitiateSuccessFlash,
|
||||
}
|
||||
return r.Authboss.Core.Redirector.Redirect(w, req, ro)
|
||||
}
|
||||
|
||||
func (r *Recover) sendRecoverEmail(ctx *authboss.Context, to, encodedToken string) {
|
||||
p := path.Join(r.MountPath, "recover/complete")
|
||||
query := url.Values{formValueToken: []string{encodedToken}}
|
||||
url := fmt.Sprintf("%s%s?%s", r.RootURL, p, query.Encode())
|
||||
var goRecoverEmail = func(r *Recover, ctx context.Context, to, encodedToken string) {
|
||||
r.SendRecoverEmail(ctx, to, encodedToken)
|
||||
}
|
||||
|
||||
// SendRecoverEmail to a specific e-mail address passing along the encodedToken
|
||||
// in an escaped URL to the templates.
|
||||
func (r *Recover) SendRecoverEmail(ctx context.Context, to, encodedToken string) {
|
||||
logger := r.Authboss.Logger(ctx)
|
||||
p := path.Join(r.Authboss.Config.Paths.Mount, "recover/end")
|
||||
query := url.Values{FormValueToken: []string{encodedToken}}
|
||||
url := fmt.Sprintf("%s%s?%s", r.Authboss.Config.Paths.RootURL, p, query.Encode())
|
||||
|
||||
email := authboss.Email{
|
||||
To: []string{to},
|
||||
From: r.EmailFrom,
|
||||
Subject: r.EmailSubjectPrefix + "Password Reset",
|
||||
From: r.Authboss.Config.Mail.From,
|
||||
Subject: r.Authboss.Config.Mail.SubjectPrefix + "Password Reset",
|
||||
}
|
||||
|
||||
if err := response.Email(ctx.Mailer, email, r.emailHTMLTemplates, tplInitHTMLEmail, r.emailTextTemplates, tplInitTextEmail, url); err != nil {
|
||||
fmt.Fprintln(ctx.LogWriter, "recover: failed to send recover email:", err)
|
||||
ro := authboss.EmailResponseOptions{
|
||||
HTMLTemplate: EmailRecoverHTML,
|
||||
TextTemplate: EmailRecoverTxt,
|
||||
Data: authboss.HTMLData{
|
||||
DataRecoverURL: url,
|
||||
},
|
||||
}
|
||||
|
||||
logger.Infof("sending recover e-mail to: %s", to)
|
||||
if err := r.Authboss.Email(ctx, email, ro); err != nil {
|
||||
logger.Errorf("failed to recover send e-mail to %s: %+v", to, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Recover) completeHandlerFunc(ctx *authboss.Context, w http.ResponseWriter, req *http.Request) (err error) {
|
||||
switch req.Method {
|
||||
case methodGET:
|
||||
_, err = verifyToken(ctx, req)
|
||||
if err == errRecoveryTokenExpired {
|
||||
return authboss.ErrAndRedirect{Err: err, Location: "/recover", FlashError: recoverTokenExpiredFlash}
|
||||
} else if err != nil {
|
||||
return authboss.ErrAndRedirect{Err: err, Location: "/"}
|
||||
}
|
||||
|
||||
token := req.FormValue(formValueToken)
|
||||
data := authboss.NewHTMLData(formValueToken, token)
|
||||
return r.templates.Render(ctx, w, req, tplRecoverComplete, data)
|
||||
case methodPOST:
|
||||
token := req.FormValue(formValueToken)
|
||||
if len(token) == 0 {
|
||||
return authboss.ClientDataErr{Name: formValueToken}
|
||||
}
|
||||
|
||||
password := req.FormValue(authboss.StorePassword)
|
||||
//confirmPassword, _ := ctx.FirstPostFormValue("confirmPassword")
|
||||
|
||||
policies := authboss.FilterValidators(r.Policies, authboss.StorePassword)
|
||||
if validationErrs := authboss.Validate(req, policies, authboss.StorePassword, authboss.ConfirmPrefix+authboss.StorePassword).Map(); len(validationErrs) > 0 {
|
||||
data := authboss.NewHTMLData(
|
||||
formValueToken, token,
|
||||
"errs", validationErrs,
|
||||
)
|
||||
return r.templates.Render(ctx, w, req, tplRecoverComplete, data)
|
||||
}
|
||||
|
||||
if ctx.User, err = verifyToken(ctx, req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
encryptedPassword, err := bcrypt.GenerateFromPassword([]byte(password), r.BCryptCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx.User[authboss.StorePassword] = string(encryptedPassword)
|
||||
ctx.User[StoreRecoverToken] = ""
|
||||
var nullTime time.Time
|
||||
ctx.User[StoreRecoverTokenExpiry] = nullTime
|
||||
|
||||
primaryID, err := ctx.User.StringErr(r.PrimaryID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := ctx.SaveUser(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := r.Events.FireAfter(authboss.EventPasswordReset, ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx.SessionStorer.Put(authboss.SessionKey, primaryID)
|
||||
response.Redirect(ctx, w, req, r.AuthLoginOKPath, "", "", true)
|
||||
default:
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyToken expects a base64.URLEncoded token.
|
||||
func verifyToken(ctx *authboss.Context, r *http.Request) (attrs authboss.Attributes, err error) {
|
||||
token := r.FormValue(formValueToken)
|
||||
if len(token) == 0 {
|
||||
return nil, authboss.ClientDataErr{Name: token}
|
||||
}
|
||||
|
||||
decoded, err := base64.URLEncoding.DecodeString(token)
|
||||
// EndGet shows a password recovery form, and it should have the token that the user
|
||||
// brought in the query parameters in it on submission.
|
||||
func (r *Recover) EndGet(w http.ResponseWriter, req *http.Request) error {
|
||||
validatable, err := r.Authboss.Core.BodyReader.Read(PageRecoverMiddle, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
sum := md5.Sum(decoded)
|
||||
storer := ctx.Storer.(RecoverStorer)
|
||||
values := authboss.MustHaveRecoverMiddleValues(validatable)
|
||||
token := values.GetToken()
|
||||
|
||||
userInter, err := storer.RecoverUser(base64.StdEncoding.EncodeToString(sum[:]))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
data := authboss.HTMLData{
|
||||
DataRecoverToken: token,
|
||||
}
|
||||
|
||||
attrs = authboss.Unbind(userInter)
|
||||
|
||||
expiry, ok := attrs.DateTime(StoreRecoverTokenExpiry)
|
||||
if !ok || time.Now().After(expiry) {
|
||||
return nil, errRecoveryTokenExpired
|
||||
}
|
||||
|
||||
return attrs, nil
|
||||
return r.Authboss.Config.Core.Responder.Respond(w, req, http.StatusOK, PageRecoverEnd, data)
|
||||
}
|
||||
|
||||
// EndPost retrieves the token
|
||||
func (r *Recover) EndPost(w http.ResponseWriter, req *http.Request) error {
|
||||
logger := r.RequestLogger(req)
|
||||
|
||||
validatable, err := r.Authboss.Core.BodyReader.Read(PageRecoverEnd, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
values := authboss.MustHaveRecoverEndValues(validatable)
|
||||
password := values.GetPassword()
|
||||
token := values.GetToken()
|
||||
|
||||
rawToken, err := base64.URLEncoding.DecodeString(token)
|
||||
if err != nil {
|
||||
logger.Infof("invalid recover token submitted, base64 decode failed: %+v", err)
|
||||
return r.invalidToken(PageRecoverEnd, w, req)
|
||||
}
|
||||
|
||||
hash := sha512.Sum512(rawToken)
|
||||
dbToken := base64.StdEncoding.EncodeToString(hash[:])
|
||||
|
||||
storer := authboss.EnsureCanRecover(r.Authboss.Config.Storage.Server)
|
||||
user, err := storer.LoadByRecoverToken(req.Context(), dbToken)
|
||||
if err == authboss.ErrUserNotFound {
|
||||
logger.Info("invalid recover token submitted, user not found")
|
||||
return r.invalidToken(PageRecoverEnd, w, req)
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if time.Now().UTC().After(user.GetRecoverExpiry()) {
|
||||
logger.Infof("invalid recover token submitted, already expired: %+v", err)
|
||||
return r.invalidToken(PageRecoverEnd, w, req)
|
||||
}
|
||||
|
||||
pass, err := bcrypt.GenerateFromPassword([]byte(password), r.Authboss.Config.Modules.BCryptCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user.PutPassword(string(pass))
|
||||
user.PutRecoverToken("") // Don't allow another recovery
|
||||
user.PutRecoverExpiry(time.Now().UTC()) // Put current time for those DBs that can't handle 0 time
|
||||
|
||||
if err := storer.Save(req.Context(), user); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
successMsg := "Successfully recovered password"
|
||||
if r.Authboss.Config.Modules.RecoverLoginAfterRecovery {
|
||||
authboss.PutSession(w, authboss.SessionKey, user.GetPID())
|
||||
successMsg += " and logged in"
|
||||
}
|
||||
|
||||
ro := authboss.RedirectOptions{
|
||||
Code: http.StatusTemporaryRedirect,
|
||||
RedirectPath: r.Authboss.Config.Paths.RecoverOK,
|
||||
Success: successMsg,
|
||||
}
|
||||
return r.Authboss.Config.Core.Redirector.Redirect(w, req, ro)
|
||||
}
|
||||
|
||||
func (r *Recover) invalidToken(page string, w http.ResponseWriter, req *http.Request) error {
|
||||
data := authboss.HTMLData{authboss.DataValidation: authboss.ErrorList{
|
||||
errors.New("recovery token is invalid"),
|
||||
}}
|
||||
return r.Authboss.Core.Responder.Respond(w, req, http.StatusOK, PageRecoverEnd, data)
|
||||
}
|
||||
|
||||
// GenerateToken appropriate for user recovery
|
||||
func GenerateToken() (hash, token string, err error) {
|
||||
rawToken := make([]byte, 32)
|
||||
if _, err = rand.Read(rawToken); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
sum := sha512.Sum512(rawToken)
|
||||
|
||||
return base64.StdEncoding.EncodeToString(sum[:]), base64.URLEncoding.EncodeToString(rawToken), nil
|
||||
}
|
||||
|
@ -2,12 +2,11 @@ package recover
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"log"
|
||||
"context"
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@ -17,516 +16,399 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
testURLBase64Token = "MTIzNA=="
|
||||
testStdBase64Token = "gdyb21LQTcIANtvYMT7QVQ=="
|
||||
testURLBase64Token = "glL8qvO1YKmLxoyEQwVQPpUMM13f6_e4R-2hUQDzP2g="
|
||||
testStdBase64Token = "cn0uhfu5Ar2A2JsSs/zdj93zhC1lHJDyIhUYdSgyp71XL/nRb3be/I6AeMz4DACwTRqRAJ6loJedJyOcOtU1Jg=="
|
||||
)
|
||||
|
||||
func testSetup() (r *Recover, s *mocks.MockStorer, l *bytes.Buffer) {
|
||||
s = mocks.NewMockStorer()
|
||||
l = &bytes.Buffer{}
|
||||
func TestInit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := authboss.New()
|
||||
ab.Layout = template.Must(template.New("").Parse(`{{template "authboss" .}}`))
|
||||
ab.LayoutHTMLEmail = template.Must(template.New("").Parse(`<strong>{{template "authboss" .}}</strong>`))
|
||||
ab.LayoutTextEmail = template.Must(template.New("").Parse(`{{template "authboss" .}}`))
|
||||
ab.Storage.Server = s
|
||||
ab.XSRFName = "xsrf"
|
||||
ab.XSRFMaker = func(_ http.ResponseWriter, _ *http.Request) string {
|
||||
return "xsrfvalue"
|
||||
}
|
||||
ab.PrimaryID = authboss.StoreUsername
|
||||
ab.LogWriter = l
|
||||
|
||||
ab.Policies = []authboss.Validator{
|
||||
authboss.Rules{
|
||||
FieldName: "username",
|
||||
Required: true,
|
||||
MinLength: 2,
|
||||
MaxLength: 4,
|
||||
AllowWhitespace: false,
|
||||
},
|
||||
authboss.Rules{
|
||||
FieldName: "password",
|
||||
Required: true,
|
||||
MinLength: 4,
|
||||
MaxLength: 8,
|
||||
AllowWhitespace: false,
|
||||
},
|
||||
router := &mocks.Router{}
|
||||
renderer := &mocks.Renderer{}
|
||||
mailRenderer := &mocks.Renderer{}
|
||||
errHandler := &mocks.ErrorHandler{}
|
||||
ab.Config.Core.Router = router
|
||||
ab.Config.Core.ViewRenderer = renderer
|
||||
ab.Config.Core.MailRenderer = mailRenderer
|
||||
ab.Config.Core.ErrorHandler = errHandler
|
||||
|
||||
r := &Recover{}
|
||||
if err := r.Init(ab); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
r = &Recover{}
|
||||
if err := r.Initialize(ab); err != nil {
|
||||
panic(err)
|
||||
if err := renderer.HasLoadedViews(PageRecoverStart, PageRecoverEnd); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := mailRenderer.HasLoadedViews(EmailRecoverHTML, EmailRecoverTxt); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
return r, s, l
|
||||
if err := router.HasGets("/recover", "/recover/end"); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := router.HasPosts("/recover", "/recover/end"); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func testRequest(ab *authboss.Authboss, method string, postFormValues ...string) (*authboss.Context, *httptest.ResponseRecorder, *http.Request, authboss.ClientStorerErr) {
|
||||
sessionStorer := mocks.NewMockClientStorer()
|
||||
ctx := ab.NewContext()
|
||||
r := mocks.MockRequest(method, postFormValues...)
|
||||
ctx.SessionStorer = sessionStorer
|
||||
type testHarness struct {
|
||||
recover *Recover
|
||||
ab *authboss.Authboss
|
||||
|
||||
return ctx, httptest.NewRecorder(), r, sessionStorer
|
||||
bodyReader *mocks.BodyReader
|
||||
mailer *mocks.Emailer
|
||||
redirector *mocks.Redirector
|
||||
renderer *mocks.Renderer
|
||||
responder *mocks.Responder
|
||||
session *mocks.ClientStateRW
|
||||
storer *mocks.ServerStorer
|
||||
}
|
||||
|
||||
func TestRecover(t *testing.T) {
|
||||
func testSetup() *testHarness {
|
||||
harness := &testHarness{}
|
||||
|
||||
harness.ab = authboss.New()
|
||||
harness.bodyReader = &mocks.BodyReader{}
|
||||
harness.mailer = &mocks.Emailer{}
|
||||
harness.redirector = &mocks.Redirector{}
|
||||
harness.renderer = &mocks.Renderer{}
|
||||
harness.responder = &mocks.Responder{}
|
||||
harness.session = mocks.NewClientRW()
|
||||
harness.storer = mocks.NewServerStorer()
|
||||
|
||||
harness.ab.Paths.RecoverOK = "/recover/ok"
|
||||
|
||||
harness.ab.Config.Core.BodyReader = harness.bodyReader
|
||||
harness.ab.Config.Core.Logger = mocks.Logger{}
|
||||
harness.ab.Config.Core.Mailer = harness.mailer
|
||||
harness.ab.Config.Core.Redirector = harness.redirector
|
||||
harness.ab.Config.Core.MailRenderer = harness.renderer
|
||||
harness.ab.Config.Core.Responder = harness.responder
|
||||
harness.ab.Config.Storage.SessionState = harness.session
|
||||
harness.ab.Config.Storage.Server = harness.storer
|
||||
|
||||
harness.recover = &Recover{harness.ab}
|
||||
|
||||
return harness
|
||||
}
|
||||
|
||||
func TestStartGet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r, _, _ := testSetup()
|
||||
h := testSetup()
|
||||
|
||||
storage := r.Storage()
|
||||
if storage[r.PrimaryID] != authboss.String {
|
||||
t.Error("Expected storage KV:", r.PrimaryID, authboss.String)
|
||||
}
|
||||
if storage[authboss.StoreEmail] != authboss.String {
|
||||
t.Error("Expected storage KV:", authboss.StoreEmail, authboss.String)
|
||||
}
|
||||
if storage[authboss.StorePassword] != authboss.String {
|
||||
t.Error("Expected storage KV:", authboss.StorePassword, authboss.String)
|
||||
}
|
||||
if storage[StoreRecoverToken] != authboss.String {
|
||||
t.Error("Expected storage KV:", StoreRecoverToken, authboss.String)
|
||||
}
|
||||
if storage[StoreRecoverTokenExpiry] != authboss.String {
|
||||
t.Error("Expected storage KV:", StoreRecoverTokenExpiry, authboss.String)
|
||||
}
|
||||
r := mocks.Request("GET")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
routes := r.Routes()
|
||||
if routes["/recover"] == nil {
|
||||
t.Error("Expected route '/recover' with handleFunc")
|
||||
}
|
||||
if routes["/recover/complete"] == nil {
|
||||
t.Error("Expected route '/recover/complete' with handleFunc")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecover_startHandlerFunc_GET(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec, _, _ := testSetup()
|
||||
ctx, w, r, _ := testRequest(rec.Authboss, "GET")
|
||||
|
||||
if err := rec.startHandlerFunc(ctx, w, r); err != nil {
|
||||
t.Error("Unexpected error:", err)
|
||||
if err := h.recover.StartGet(w, r); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Error("Unexpected status:", w.Code)
|
||||
t.Error("code was wrong:", w.Code)
|
||||
}
|
||||
|
||||
body := w.Body.String()
|
||||
if !strings.Contains(body, `<form action="recover"`) {
|
||||
t.Error("Should have rendered a form")
|
||||
if h.responder.Page != PageRecoverStart {
|
||||
t.Error("page was wrong:", h.responder.Page)
|
||||
}
|
||||
if !strings.Contains(body, `name="`+rec.PrimaryID) {
|
||||
t.Error("Form should contain the primary ID field")
|
||||
}
|
||||
if !strings.Contains(body, `name="confirm_`+rec.PrimaryID) {
|
||||
t.Error("Form should contain the confirm primary ID field")
|
||||
if h.responder.Data != nil {
|
||||
t.Error("expected no data:", h.responder.Data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecover_startHandlerFunc_POST_ValidationFails(t *testing.T) {
|
||||
func TestStartPostSuccess(t *testing.T) {
|
||||
// no t.Parallel(), global var mangling
|
||||
|
||||
oldRecoverEmail := goRecoverEmail
|
||||
goRecoverEmail = func(r *Recover, ctx context.Context, to, token string) {
|
||||
r.SendRecoverEmail(ctx, to, token)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
goRecoverEmail = oldRecoverEmail
|
||||
}()
|
||||
|
||||
h := testSetup()
|
||||
|
||||
h.bodyReader.Return = &mocks.Values{
|
||||
PID: "test@test.com",
|
||||
}
|
||||
h.storer.Users["test@test.com"] = &mocks.User{
|
||||
Email: "test@test.com",
|
||||
Password: "i can't recall, doesn't seem like something bcrypted though",
|
||||
}
|
||||
|
||||
r := mocks.Request("GET")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
if err := h.recover.StartPost(w, r); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if w.Code != http.StatusTemporaryRedirect {
|
||||
t.Error("code was wrong:", w.Code)
|
||||
}
|
||||
if h.redirector.Options.RedirectPath != h.ab.Config.Paths.RecoverOK {
|
||||
t.Error("page was wrong:", h.responder.Page)
|
||||
}
|
||||
if len(h.redirector.Options.Success) == 0 {
|
||||
t.Error("expected a nice success message")
|
||||
}
|
||||
|
||||
if h.mailer.Email.To[0] != "test@test.com" {
|
||||
t.Error("e-mail to address is wrong:", h.mailer.Email.To)
|
||||
}
|
||||
if !strings.HasSuffix(h.mailer.Email.Subject, "Password Reset") {
|
||||
t.Error("e-mail subject line is wrong:", h.mailer.Email.Subject)
|
||||
}
|
||||
if len(h.renderer.Data[DataRecoverURL].(string)) == 0 {
|
||||
t.Errorf("the renderer's url in data was missing: %#v", h.renderer.Data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartPostFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec, _, _ := testSetup()
|
||||
ctx, w, r, _ := testRequest(rec.Authboss, "POST")
|
||||
h := testSetup()
|
||||
|
||||
if err := rec.startHandlerFunc(ctx, w, r); err != nil {
|
||||
t.Error("Unexpected error:", err)
|
||||
h.bodyReader.Return = &mocks.Values{
|
||||
PID: "test@test.com",
|
||||
}
|
||||
|
||||
r := mocks.Request("GET")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
if err := h.recover.StartPost(w, r); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if w.Code != http.StatusTemporaryRedirect {
|
||||
t.Error("code was wrong:", w.Code)
|
||||
}
|
||||
if h.redirector.Options.RedirectPath != h.ab.Config.Paths.RecoverOK {
|
||||
t.Error("page was wrong:", h.responder.Page)
|
||||
}
|
||||
if len(h.redirector.Options.Success) == 0 {
|
||||
t.Error("expected a nice success message")
|
||||
}
|
||||
|
||||
if len(h.mailer.Email.To) != 0 {
|
||||
t.Error("should not have sent an e-mail out!")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEndGet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
h := testSetup()
|
||||
|
||||
h.bodyReader.Return = &mocks.Values{
|
||||
Token: "abcd",
|
||||
}
|
||||
|
||||
r := mocks.Request("GET")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
if err := h.recover.EndGet(w, r); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Error("Unexpected status:", w.Code)
|
||||
t.Error("code was wrong:", w.Code)
|
||||
}
|
||||
|
||||
if !strings.Contains(w.Body.String(), "Cannot be blank") {
|
||||
t.Error("Expected error about email being blank")
|
||||
if h.responder.Page != PageRecoverEnd {
|
||||
t.Error("page was wrong:", h.responder.Page)
|
||||
}
|
||||
if h.responder.Data[DataRecoverToken].(string) != "abcd" {
|
||||
t.Errorf("recovery token is wrong: %#v", h.responder.Data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecover_startHandlerFunc_POST_UserNotFound(t *testing.T) {
|
||||
func TestEndPostSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec, _, _ := testSetup()
|
||||
ctx, w, r, _ := testRequest(rec.Authboss, "POST", "username", "john", "confirm_username", "john")
|
||||
h := testSetup()
|
||||
|
||||
err := rec.startHandlerFunc(ctx, w, r)
|
||||
if err == nil {
|
||||
t.Error("Expected error:", err)
|
||||
h.bodyReader.Return = &mocks.Values{
|
||||
Token: testURLBase64Token,
|
||||
}
|
||||
rerr, ok := err.(authboss.ErrAndRedirect)
|
||||
if !ok {
|
||||
t.Error("Expected ErrAndRedirect error")
|
||||
h.storer.Users["test@test.com"] = &mocks.User{
|
||||
Email: "test@test.com",
|
||||
Password: "to-overwrite",
|
||||
RecoverToken: testStdBase64Token,
|
||||
RecoverTokenExpiry: time.Now().UTC().AddDate(0, 0, 1),
|
||||
}
|
||||
|
||||
if rerr.Location != rec.RecoverOKPath {
|
||||
t.Error("Unexpected location:", rerr.Location)
|
||||
r := mocks.Request("GET")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
if err := h.recover.EndPost(w, r); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if rerr.FlashSuccess != recoverInitiateSuccessFlash {
|
||||
t.Error("Unexpected success flash", rerr.FlashSuccess)
|
||||
if w.Code != http.StatusTemporaryRedirect {
|
||||
t.Error("code was wrong:", w.Code)
|
||||
}
|
||||
if p := h.redirector.Options.RedirectPath; p != h.ab.Paths.RecoverOK {
|
||||
t.Error("path was wrong:", p)
|
||||
}
|
||||
if len(h.session.ClientValues[authboss.SessionKey]) != 0 {
|
||||
t.Error("should not have logged in the user")
|
||||
}
|
||||
if !strings.Contains(h.redirector.Options.Success, "recovered password") {
|
||||
t.Error("should not talk about logging in")
|
||||
}
|
||||
if strings.Contains(h.redirector.Options.Success, "logged in") {
|
||||
t.Error("should not talk about logging in")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecover_startHandlerFunc_POST(t *testing.T) {
|
||||
func TestEndPostSuccessLogin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec, storer, _ := testSetup()
|
||||
h := testSetup()
|
||||
|
||||
storer.Users["john"] = authboss.Attributes{authboss.StoreUsername: "john", authboss.StoreEmail: "a@b.c"}
|
||||
|
||||
sentEmail := false
|
||||
goRecoverEmail = func(_ *Recover, _ *authboss.Context, _, _ string) {
|
||||
sentEmail = true
|
||||
h.ab.Config.Modules.RecoverLoginAfterRecovery = true
|
||||
h.bodyReader.Return = &mocks.Values{
|
||||
Token: testURLBase64Token,
|
||||
}
|
||||
h.storer.Users["test@test.com"] = &mocks.User{
|
||||
Email: "test@test.com",
|
||||
Password: "to-overwrite",
|
||||
RecoverToken: testStdBase64Token,
|
||||
RecoverTokenExpiry: time.Now().UTC().AddDate(0, 0, 1),
|
||||
}
|
||||
|
||||
ctx, w, r, sessionStorer := testRequest(rec.Authboss, "POST", "username", "john", "confirm_username", "john")
|
||||
r := mocks.Request("GET")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
if err := rec.startHandlerFunc(ctx, w, r); err != nil {
|
||||
t.Error("Unexpected error:", err)
|
||||
if err := h.recover.EndPost(h.ab.NewResponse(w, r), r); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if !sentEmail {
|
||||
t.Error("Expected email to have been sent")
|
||||
if w.Code != http.StatusTemporaryRedirect {
|
||||
t.Error("code was wrong:", w.Code)
|
||||
}
|
||||
|
||||
if val, err := storer.Users["john"].StringErr(StoreRecoverToken); err != nil {
|
||||
t.Error("Unexpected error:", err)
|
||||
} else if len(val) <= 0 {
|
||||
t.Error("Unexpected Recover Token to be set")
|
||||
if p := h.redirector.Options.RedirectPath; p != h.ab.Paths.RecoverOK {
|
||||
t.Error("path was wrong:", p)
|
||||
}
|
||||
|
||||
if val, err := storer.Users["john"].DateTimeErr(StoreRecoverTokenExpiry); err != nil {
|
||||
t.Error("Unexpected error:", err)
|
||||
} else if !val.After(time.Now()) {
|
||||
t.Error("Expected recovery token expiry to be greater than now")
|
||||
if len(h.session.ClientValues[authboss.SessionKey]) == 0 {
|
||||
t.Error("it should have logged in the user")
|
||||
}
|
||||
|
||||
if w.Code != http.StatusFound {
|
||||
t.Error("Unexpected status:", w.Code)
|
||||
}
|
||||
|
||||
loc := w.Header().Get("Location")
|
||||
if loc != rec.RecoverOKPath {
|
||||
t.Error("Unexpected location:", loc)
|
||||
}
|
||||
|
||||
if value, ok := sessionStorer.Get(authboss.FlashSuccessKey); !ok {
|
||||
t.Error("Expected success flash message")
|
||||
} else if value != recoverInitiateSuccessFlash {
|
||||
t.Error("Unexpected success flash message")
|
||||
if !strings.Contains(h.redirector.Options.Success, "logged in") {
|
||||
t.Error("should talk about logging in")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecover_startHandlerFunc_OtherMethods(t *testing.T) {
|
||||
func TestEndPostInvalidBase64(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec, _, _ := testSetup()
|
||||
h := testSetup()
|
||||
|
||||
methods := []string{"HEAD", "PUT", "DELETE", "TRACE", "CONNECT"}
|
||||
|
||||
for i, method := range methods {
|
||||
_, w, r, _ := testRequest(rec.Authboss, method)
|
||||
|
||||
if err := rec.startHandlerFunc(nil, w, r); err != nil {
|
||||
t.Errorf("%d> Unexpected error: %s", i, err)
|
||||
}
|
||||
|
||||
if http.StatusMethodNotAllowed != w.Code {
|
||||
t.Errorf("%d> Expected status code %d, got %d", i, http.StatusMethodNotAllowed, w.Code)
|
||||
continue
|
||||
}
|
||||
h.bodyReader.Return = &mocks.Values{
|
||||
Token: "a",
|
||||
}
|
||||
|
||||
r := mocks.Request("GET")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
if err := h.recover.EndPost(w, r); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
invalidCheck(t, h, w)
|
||||
}
|
||||
|
||||
func TestRecover_newToken(t *testing.T) {
|
||||
func TestEndPostExpiredToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
regexURL := regexp.MustCompile(`^(?:[A-Za-z0-9-_]{4})*(?:[A-Za-z0-9-_]{2}==|[A-Za-z0-9-_]{3}=)?$`)
|
||||
regexSTD := regexp.MustCompile(`^(?:[A-Za-z0-9+/]{4})*(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=)?$`)
|
||||
h := testSetup()
|
||||
|
||||
encodedToken, encodedSum, _ := newToken()
|
||||
|
||||
if !regexURL.MatchString(encodedToken) {
|
||||
t.Error("Expected encodedToken to be base64 encoded")
|
||||
h.bodyReader.Return = &mocks.Values{
|
||||
Token: testURLBase64Token,
|
||||
}
|
||||
h.storer.Users["test@test.com"] = &mocks.User{
|
||||
Email: "test@test.com",
|
||||
Password: "to-overwrite",
|
||||
RecoverToken: testStdBase64Token,
|
||||
RecoverTokenExpiry: time.Now().UTC().AddDate(0, 0, -1),
|
||||
}
|
||||
|
||||
if !regexSTD.MatchString(encodedSum) {
|
||||
t.Error("Expected encodedSum to be base64 encoded")
|
||||
r := mocks.Request("GET")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
if err := h.recover.EndPost(w, r); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
invalidCheck(t, h, w)
|
||||
}
|
||||
|
||||
func TestRecover_sendRecoverMail_FailToSend(t *testing.T) {
|
||||
func TestEndPostUserNotExist(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r, _, logger := testSetup()
|
||||
h := testSetup()
|
||||
|
||||
mailer := mocks.NewMockMailer()
|
||||
mailer.SendErr = "failed to send"
|
||||
r.Mailer = mailer
|
||||
|
||||
r.sendRecoverEmail(r.NewContext(), "", "")
|
||||
|
||||
if !strings.Contains(logger.String(), "failed to send") {
|
||||
t.Error("Expected logged to have msg:", "failed to send")
|
||||
h.bodyReader.Return = &mocks.Values{
|
||||
Token: testURLBase64Token,
|
||||
}
|
||||
|
||||
r := mocks.Request("GET")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
if err := h.recover.EndPost(w, r); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
invalidCheck(t, h, w)
|
||||
}
|
||||
|
||||
func TestRecover_sendRecoverEmail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r, _, _ := testSetup()
|
||||
|
||||
mailer := mocks.NewMockMailer()
|
||||
r.EmailSubjectPrefix = "foo "
|
||||
r.RootURL = "bar"
|
||||
r.Mailer = mailer
|
||||
|
||||
r.sendRecoverEmail(r.NewContext(), "a@b.c", "abc=")
|
||||
if len(mailer.Last.To) != 1 {
|
||||
t.Error("Expected 1 to email")
|
||||
}
|
||||
if mailer.Last.To[0] != "a@b.c" {
|
||||
t.Error("Unexpected to email:", mailer.Last.To[0])
|
||||
}
|
||||
if mailer.Last.Subject != "foo Password Reset" {
|
||||
t.Error("Unexpected subject:", mailer.Last.Subject)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/recover/complete?token=abc%%3D", r.RootURL)
|
||||
if !strings.Contains(mailer.Last.HTMLBody, url) {
|
||||
t.Error("Expected HTMLBody to contain url:", url)
|
||||
}
|
||||
if !strings.Contains(mailer.Last.TextBody, url) {
|
||||
t.Error("Expected TextBody to contain url:", url)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecover_completeHandlerFunc_GET_VerifyFails(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec, storer, _ := testSetup()
|
||||
|
||||
ctx, w, r, _ := testRequest(rec.Authboss, "GET", "token", testURLBase64Token)
|
||||
|
||||
err := rec.completeHandlerFunc(ctx, w, r)
|
||||
rerr, ok := err.(authboss.ErrAndRedirect)
|
||||
if !ok {
|
||||
t.Error("Expected ErrAndRedirect:", err)
|
||||
}
|
||||
if rerr.Location != "/" {
|
||||
t.Error("Unexpected location:", rerr.Location)
|
||||
}
|
||||
|
||||
var zeroTime time.Time
|
||||
storer.Users["john"] = authboss.Attributes{StoreRecoverToken: testStdBase64Token, StoreRecoverTokenExpiry: zeroTime}
|
||||
|
||||
ctx, w, r, _ = testRequest(rec.Authboss, "GET", "token", testURLBase64Token)
|
||||
|
||||
err = rec.completeHandlerFunc(ctx, w, r)
|
||||
rerr, ok = err.(authboss.ErrAndRedirect)
|
||||
if !ok {
|
||||
t.Error("Expected ErrAndRedirect")
|
||||
}
|
||||
if rerr.Location != "/recover" {
|
||||
t.Error("Unexpected location:", rerr.Location)
|
||||
}
|
||||
if rerr.FlashError != recoverTokenExpiredFlash {
|
||||
t.Error("Unexpcted flash error:", rerr.FlashError)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecover_completeHandlerFunc_GET(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec, storer, _ := testSetup()
|
||||
|
||||
storer.Users["john"] = authboss.Attributes{StoreRecoverToken: testStdBase64Token, StoreRecoverTokenExpiry: time.Now().Add(1 * time.Hour)}
|
||||
|
||||
ctx, w, r, _ := testRequest(rec.Authboss, "GET", "token", testURLBase64Token)
|
||||
|
||||
if err := rec.completeHandlerFunc(ctx, w, r); err != nil {
|
||||
t.Error("Unexpected error:", err)
|
||||
}
|
||||
func invalidCheck(t *testing.T, h *testHarness, w *httptest.ResponseRecorder) {
|
||||
t.Helper()
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Error("Unexpected status:", w.Code)
|
||||
t.Error("code was wrong:", w.Code)
|
||||
}
|
||||
|
||||
body := w.Body.String()
|
||||
if !strings.Contains(body, `<form action="recover/complete"`) {
|
||||
t.Error("Should have rendered a form")
|
||||
if h.responder.Page != PageRecoverEnd {
|
||||
t.Error("page was wrong:", h.responder.Page)
|
||||
}
|
||||
if !strings.Contains(body, `name="password"`) {
|
||||
t.Error("Form should contain the password field")
|
||||
}
|
||||
if !strings.Contains(body, `name="confirm_password"`) {
|
||||
t.Error("Form should contain the confirm password field")
|
||||
}
|
||||
if !strings.Contains(body, `name="token"`) {
|
||||
t.Error("Form should contain the token field")
|
||||
if h.responder.Data[authboss.DataValidation].(authboss.ErrorList)[0].Error() != "recovery token is invalid" {
|
||||
t.Error("expected a vague error to mislead")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecover_completeHandlerFunc_POST_TokenMissing(t *testing.T) {
|
||||
func TestGenerateToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec, _, _ := testSetup()
|
||||
ctx, w, r, _ := testRequest(rec.Authboss, "POST")
|
||||
|
||||
err := rec.completeHandlerFunc(ctx, w, r)
|
||||
if err == nil || err.Error() != "Failed to retrieve client attribute: token" {
|
||||
t.Error("Unexpected error:", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestRecover_completeHandlerFunc_POST_ValidationFails(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec, _, _ := testSetup()
|
||||
ctx, w, r, _ := testRequest(rec.Authboss, "POST", "token", testURLBase64Token)
|
||||
|
||||
if err := rec.completeHandlerFunc(ctx, w, r); err != nil {
|
||||
t.Error("Unexpected error:", err)
|
||||
}
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Error("Unexpected status:", w.Code)
|
||||
}
|
||||
|
||||
if !strings.Contains(w.Body.String(), "Cannot be blank") {
|
||||
t.Error("Expected error about password being blank")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecover_completeHandlerFunc_POST_VerificationFails(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec, _, _ := testSetup()
|
||||
ctx, w, r, _ := testRequest(rec.Authboss, "POST", "token", testURLBase64Token, authboss.StorePassword, "abcd", "confirm_"+authboss.StorePassword, "abcd")
|
||||
|
||||
if err := rec.completeHandlerFunc(ctx, w, r); err == nil {
|
||||
log.Println(w.Body.String())
|
||||
t.Error("Expected error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecover_completeHandlerFunc_POST(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec, storer, _ := testSetup()
|
||||
|
||||
storer.Users["john"] = authboss.Attributes{rec.PrimaryID: "john", StoreRecoverToken: testStdBase64Token, StoreRecoverTokenExpiry: time.Now().Add(1 * time.Hour), authboss.StorePassword: "asdf"}
|
||||
|
||||
cbCalled := false
|
||||
|
||||
rec.Events = authboss.NewCallbacks()
|
||||
rec.Events.After(authboss.EventPasswordReset, func(_ *authboss.Context) error {
|
||||
cbCalled = true
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx, w, r, sessionStorer := testRequest(rec.Authboss, "POST", "token", testURLBase64Token, authboss.StorePassword, "abcd", "confirm_"+authboss.StorePassword, "abcd")
|
||||
|
||||
if err := rec.completeHandlerFunc(ctx, w, r); err != nil {
|
||||
t.Error("Unexpected error:", err)
|
||||
}
|
||||
|
||||
var zeroTime time.Time
|
||||
|
||||
u := storer.Users["john"]
|
||||
if password, ok := u.String(authboss.StorePassword); !ok || password == "asdf" {
|
||||
t.Error("Expected password to have been reset")
|
||||
}
|
||||
|
||||
if recToken, ok := u.String(StoreRecoverToken); !ok || recToken != "" {
|
||||
t.Error("Expected recovery token to have been zeroed")
|
||||
}
|
||||
|
||||
if reCExpiry, ok := u.DateTime(StoreRecoverTokenExpiry); !ok || !reCExpiry.Equal(zeroTime) {
|
||||
t.Error("Expected recovery token expiry to have been zeroed")
|
||||
}
|
||||
|
||||
if !cbCalled {
|
||||
t.Error("Expected EventPasswordReset callback to have been fired")
|
||||
}
|
||||
|
||||
if val, ok := sessionStorer.Get(authboss.SessionKey); !ok || val != "john" {
|
||||
t.Error("Expected SessionKey to be:", "john")
|
||||
}
|
||||
|
||||
if w.Code != http.StatusFound {
|
||||
t.Error("Unexpected status:", w.Code)
|
||||
}
|
||||
|
||||
loc := w.Header().Get("Location")
|
||||
if loc != rec.AuthLogoutOKPath {
|
||||
t.Error("Unexpected location:", loc)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_verifyToken_MissingToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testSetup()
|
||||
r := mocks.MockRequest("GET")
|
||||
|
||||
if _, err := verifyToken(nil, r); err == nil {
|
||||
t.Error("Expected error about missing token")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_verifyToken_InvalidToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec, storer, _ := testSetup()
|
||||
storer.Users["a"] = authboss.Attributes{
|
||||
StoreRecoverToken: testStdBase64Token,
|
||||
}
|
||||
|
||||
ctx := rec.Authboss.NewContext()
|
||||
req, _ := http.NewRequest("GET", "/?token=asdf", nil)
|
||||
if _, err := verifyToken(ctx, req); err != authboss.ErrUserNotFound {
|
||||
t.Error("Unexpected error:", err)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_verifyToken_ExpiredToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec, storer, _ := testSetup()
|
||||
storer.Users["a"] = authboss.Attributes{
|
||||
StoreRecoverToken: testStdBase64Token,
|
||||
StoreRecoverTokenExpiry: time.Now().Add(time.Duration(-24) * time.Hour),
|
||||
}
|
||||
|
||||
ctx := rec.Authboss.NewContext()
|
||||
req, _ := http.NewRequest("GET", "/?token="+testURLBase64Token, nil)
|
||||
if _, err := verifyToken(ctx, req); err != errRecoveryTokenExpired {
|
||||
t.Error("Unexpected error:", err)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_verifyToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec, storer, _ := testSetup()
|
||||
storer.Users["a"] = authboss.Attributes{
|
||||
StoreRecoverToken: testStdBase64Token,
|
||||
StoreRecoverTokenExpiry: time.Now().Add(time.Duration(24) * time.Hour),
|
||||
}
|
||||
|
||||
ctx := rec.Authboss.NewContext()
|
||||
req, _ := http.NewRequest("GET", "/?token="+testURLBase64Token, nil)
|
||||
attrs, err := verifyToken(ctx, req)
|
||||
hash, token, err := GenerateToken()
|
||||
if err != nil {
|
||||
t.Error("Unexpected error:", err)
|
||||
t.Error(err)
|
||||
}
|
||||
if attrs == nil {
|
||||
t.Error("Unexpected nil attrs")
|
||||
|
||||
// base64 length: n = 64; 4*(64/3) = 85.3; round to nearest 4: 88
|
||||
if len(hash) != 88 {
|
||||
t.Errorf("string length was wrong (%d): %s", len(hash), hash)
|
||||
}
|
||||
|
||||
// base64 length: n = 32; 4*(32/3) = 42.6; round to nearest 4: 44
|
||||
if len(token) != 44 {
|
||||
t.Errorf("string length was wrong (%d): %s", len(token), token)
|
||||
}
|
||||
|
||||
rawToken, err := base64.URLEncoding.DecodeString(token)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
rawHash, err := base64.StdEncoding.DecodeString(hash)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
checkHash := sha512.Sum512(rawToken)
|
||||
if 0 != bytes.Compare(checkHash[:], rawHash) {
|
||||
t.Error("expected hashes to match")
|
||||
}
|
||||
}
|
||||
|
@ -56,7 +56,7 @@ type EmailResponseOptions struct {
|
||||
TextTemplate string
|
||||
}
|
||||
|
||||
// Email renders the e-mail templates and sends it using the mailer.
|
||||
// Email renders the e-mail templates for the given email and sends it using the mailer.
|
||||
func (a *Authboss) Email(ctx context.Context, email Email, ro EmailResponseOptions) error {
|
||||
if len(ro.HTMLTemplate) != 0 {
|
||||
htmlBody, _, err := a.Core.MailRenderer.Render(ctx, ro.HTMLTemplate, ro.Data)
|
||||
|
27
storage.go
27
storage.go
@ -54,6 +54,8 @@ type ServerStorer interface {
|
||||
// CreatingServerStorer is used for creating new users
|
||||
// like when Registration is being done.
|
||||
type CreatingServerStorer interface {
|
||||
ServerStorer
|
||||
|
||||
// New creates a blank user, it is not yet persisted in the database
|
||||
// but is just for storing data
|
||||
New(ctx context.Context) User
|
||||
@ -64,7 +66,20 @@ type CreatingServerStorer interface {
|
||||
|
||||
// ConfirmingServerStorer can find a user by a confirm token
|
||||
type ConfirmingServerStorer interface {
|
||||
LoadByToken(ctx context.Context, token string) (ConfirmableUser, error)
|
||||
ServerStorer
|
||||
|
||||
// LoadByConfirmToken finds a user by his confirm token field
|
||||
// and should return ErrUserNotFound if that user cannot be found.
|
||||
LoadByConfirmToken(ctx context.Context, token string) (ConfirmableUser, error)
|
||||
}
|
||||
|
||||
// RecoveringServerStorer allows users to be recovered by a token
|
||||
type RecoveringServerStorer interface {
|
||||
ServerStorer
|
||||
|
||||
// LoadByRecoverToken finds a user by his recover token field
|
||||
// and should return ErrUserNotFound if that user cannot be found.
|
||||
LoadByRecoverToken(ctx context.Context, token string) (RecoverableUser, error)
|
||||
}
|
||||
|
||||
// EnsureCanCreate makes sure the server storer supports create operations
|
||||
@ -86,3 +101,13 @@ func EnsureCanConfirm(storer ServerStorer) ConfirmingServerStorer {
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// EnsureCanRecover makes sure the server storer supports confirm-lookup operations
|
||||
func EnsureCanRecover(storer ServerStorer) RecoveringServerStorer {
|
||||
s, ok := storer.(RecoveringServerStorer)
|
||||
if !ok {
|
||||
panic("could not upgrade serverstorer to recoveringserverstorer, check your struct")
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
38
user.go
38
user.go
@ -1,6 +1,9 @@
|
||||
package authboss
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// User has functions for each piece of data it requires.
|
||||
// Data should not be persisted on each function call.
|
||||
@ -50,6 +53,19 @@ type LockableUser interface {
|
||||
PutLocked(locked time.Time)
|
||||
}
|
||||
|
||||
// RecoverableUser is a user that can be recovered via e-mail
|
||||
type RecoverableUser interface {
|
||||
AuthableUser
|
||||
|
||||
GetEmail() (email string)
|
||||
GetRecoverToken() (token string)
|
||||
GetRecoverExpiry() (expiry time.Time)
|
||||
|
||||
PutEmail(email string)
|
||||
PutRecoverToken(token string)
|
||||
PutRecoverExpiry(expiry time.Time)
|
||||
}
|
||||
|
||||
// ArbitraryUser allows arbitrary data from the web form through. You should
|
||||
// definitely only pull the keys you want from the map, since this is unfiltered
|
||||
// input from a web request and is an attack vector.
|
||||
@ -85,26 +101,34 @@ type OAuth2User interface {
|
||||
PutExpiry(expiry time.Duration)
|
||||
}
|
||||
|
||||
// MustBeAuthable forces an upgrade to an Authable user or panic.
|
||||
// MustBeAuthable forces an upgrade to an AuthableUser or panic.
|
||||
func MustBeAuthable(u User) AuthableUser {
|
||||
if au, ok := u.(AuthableUser); ok {
|
||||
return au
|
||||
}
|
||||
panic("could not upgrade user to an authable user, check your user struct")
|
||||
panic(fmt.Sprintf("could not upgrade user to an authable user, type: %T", u))
|
||||
}
|
||||
|
||||
// MustBeConfirmable forces an upgrade to a Confirmable user or panic.
|
||||
// MustBeConfirmable forces an upgrade to a ConfirmableUser or panic.
|
||||
func MustBeConfirmable(u User) ConfirmableUser {
|
||||
if cu, ok := u.(ConfirmableUser); ok {
|
||||
return cu
|
||||
}
|
||||
panic("could not upgrade user to a confirmable user, check your user struct")
|
||||
panic(fmt.Sprintf("could not upgrade user to a confirmable user, type: %T", u))
|
||||
}
|
||||
|
||||
// MustBeLockable forces an upgrade to a Lockable user or panic.
|
||||
// MustBeLockable forces an upgrade to a LockableUser or panic.
|
||||
func MustBeLockable(u User) LockableUser {
|
||||
if lu, ok := u.(LockableUser); ok {
|
||||
return lu
|
||||
}
|
||||
panic("could not upgrade user to a lockable user, check your user struct")
|
||||
panic(fmt.Sprintf("could not upgrade user to a lockable user, given type: %T", u))
|
||||
}
|
||||
|
||||
// MustBeRecoverable forces an upgrade to a RecoverableUser or panic.
|
||||
func MustBeRecoverable(u User) RecoverableUser {
|
||||
if lu, ok := u.(RecoverableUser); ok {
|
||||
return lu
|
||||
}
|
||||
panic(fmt.Sprintf("could not upgrade user to a recoverable user, given type: %T", u))
|
||||
}
|
||||
|
71
values.go
71
values.go
@ -31,6 +31,40 @@ type UserValuer interface {
|
||||
GetPassword() string
|
||||
}
|
||||
|
||||
// ConfirmValuer allows us to pull out the token from the request
|
||||
type ConfirmValuer interface {
|
||||
Validator
|
||||
|
||||
GetToken() string
|
||||
}
|
||||
|
||||
// RecoverStartValuer provides the PID entered by the user.
|
||||
type RecoverStartValuer interface {
|
||||
Validator
|
||||
|
||||
GetPID() string
|
||||
}
|
||||
|
||||
// RecoverMiddleValuer provides the token that the user submitted
|
||||
// via their link.
|
||||
type RecoverMiddleValuer interface {
|
||||
Validator
|
||||
|
||||
GetToken() string
|
||||
}
|
||||
|
||||
// RecoverEndValuer is used to get data back from the final
|
||||
// page of password recovery, the user will provide a password
|
||||
// and it must be accompanied by the token to authorize the changing
|
||||
// of that password. Contrary to the RecoverValuer, this should
|
||||
// have validation errors for bad tokens.
|
||||
type RecoverEndValuer interface {
|
||||
Validator
|
||||
|
||||
GetPassword() string
|
||||
GetToken() string
|
||||
}
|
||||
|
||||
// ArbitraryValuer provides the "rest" of the fields
|
||||
// that aren't strictly needed for anything in particular,
|
||||
// address, secondary e-mail, etc.
|
||||
@ -49,13 +83,6 @@ type ArbitraryValuer interface {
|
||||
GetValues() map[string]string
|
||||
}
|
||||
|
||||
// ConfirmValuer allows us to pull out the token from the request
|
||||
type ConfirmValuer interface {
|
||||
Validator
|
||||
|
||||
GetToken() string
|
||||
}
|
||||
|
||||
// MustHaveUserValues upgrades a validatable set of values
|
||||
// to ones specific to an authenticating user.
|
||||
func MustHaveUserValues(v Validator) UserValuer {
|
||||
@ -75,3 +102,33 @@ func MustHaveConfirmValues(v Validator) ConfirmValuer {
|
||||
|
||||
panic(fmt.Sprintf("bodyreader returned a type that could not be upgraded to ConfirmValuer: %T", v))
|
||||
}
|
||||
|
||||
// MustHaveRecoverStartValues upgrades a validatable set of values
|
||||
// to ones specific to a user that needs to be recovered.
|
||||
func MustHaveRecoverStartValues(v Validator) RecoverStartValuer {
|
||||
if u, ok := v.(RecoverStartValuer); ok {
|
||||
return u
|
||||
}
|
||||
|
||||
panic(fmt.Sprintf("bodyreader returned a type that could not be upgraded to RecoverStartValuer: %T", v))
|
||||
}
|
||||
|
||||
// MustHaveRecoverMiddleValues upgrades a validatable set of values
|
||||
// to ones specific to a user that's attempting to recover.
|
||||
func MustHaveRecoverMiddleValues(v Validator) RecoverMiddleValuer {
|
||||
if u, ok := v.(RecoverMiddleValuer); ok {
|
||||
return u
|
||||
}
|
||||
|
||||
panic(fmt.Sprintf("bodyreader returned a type that could not be upgraded to RecoverMiddleValuer: %T", v))
|
||||
}
|
||||
|
||||
// MustHaveRecoverEndValues upgrades a validatable set of values
|
||||
// to ones specific to a user that needs to be recovered.
|
||||
func MustHaveRecoverEndValues(v Validator) RecoverEndValuer {
|
||||
if u, ok := v.(RecoverEndValuer); ok {
|
||||
return u
|
||||
}
|
||||
|
||||
panic(fmt.Sprintf("bodyreader returned a type that could not be upgraded to RecoverEndValuer: %T", v))
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user