1
0
mirror of https://github.com/raseels-repos/golang-saas-starter-kit.git synced 2025-08-08 22:36:41 +02:00

fixed redirects to handle saving the session when set

This commit is contained in:
Lee Brown
2019-08-05 18:47:42 -08:00
parent 7909dc4ca4
commit 900cfcf713
14 changed files with 191 additions and 156 deletions

View File

@ -209,13 +209,8 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req
webcontext.SessionFlashSuccess(ctx, webcontext.SessionFlashSuccess(ctx,
"Account Updated", "Account Updated",
"Account profile successfully updated.") "Account profile successfully updated.")
err = webcontext.ContextSession(ctx).Save(r, w)
if err != nil {
return false, err
}
http.Redirect(w, r, "/account", http.StatusFound) return true, web.Redirect(ctx, w, r, "/account", http.StatusFound)
return true, nil
} }
acc, err := account.ReadByID(ctx, claims, h.MasterDB, claims.Audience) acc, err := account.ReadByID(ctx, claims, h.MasterDB, claims.Audience)

View File

@ -203,13 +203,8 @@ func (h *Projects) Create(ctx context.Context, w http.ResponseWriter, r *http.Re
webcontext.SessionFlashSuccess(ctx, webcontext.SessionFlashSuccess(ctx,
"Project Created", "Project Created",
"Project successfully created.") "Project successfully created.")
err = webcontext.ContextSession(ctx).Save(r, w)
if err != nil {
return false, err
}
http.Redirect(w, r, urlProjectsView(usr.ID), http.StatusFound) return true, web.Redirect(ctx, w, r, urlProjectsView(usr.ID), http.StatusFound)
return true, nil
} }
return false, nil return false, nil
@ -266,13 +261,8 @@ func (h *Projects) View(ctx context.Context, w http.ResponseWriter, r *http.Requ
webcontext.SessionFlashSuccess(ctx, webcontext.SessionFlashSuccess(ctx,
"Project Archive", "Project Archive",
"Project successfully archive.") "Project successfully archive.")
err = webcontext.ContextSession(ctx).Save(r, w)
if err != nil {
return false, err
}
http.Redirect(w, r, urlProjectsIndex(), http.StatusFound) return true, web.Redirect(ctx, w, r, urlProjectsIndex(), http.StatusFound)
return true, nil
} }
} }
@ -347,13 +337,8 @@ func (h *Projects) Update(ctx context.Context, w http.ResponseWriter, r *http.Re
webcontext.SessionFlashSuccess(ctx, webcontext.SessionFlashSuccess(ctx,
"Project Updated", "Project Updated",
"Project successfully updated.") "Project successfully updated.")
err = webcontext.ContextSession(ctx).Save(r, w)
if err != nil {
return false, err
}
http.Redirect(w, r, urlProjectsView(req.ID), http.StatusFound) return true, web.Redirect(ctx, w, r, urlProjectsView(req.ID), http.StatusFound)
return true, nil
} }
return false, nil return false, nil

View File

@ -40,12 +40,10 @@ func (h *Root) indexDashboard(ctx context.Context, w http.ResponseWriter, r *htt
// indexDefault loads the root index page when a user has no authentication. // indexDefault loads the root index page when a user has no authentication.
func (u *Root) indexDefault(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { func (u *Root) indexDefault(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
return u.Renderer.Render(ctx, w, r, tmplLayoutSite, "site-index.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, nil) return u.Renderer.Render(ctx, w, r, tmplLayoutSite, "site-index.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, nil)
} }
// indexDefault loads the root index page when a user has no authentication. // SitePage loads the page with the layout for site instead of the app base.
func (u *Root) SitePage(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { func (u *Root) SitePage(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
var tmpName string var tmpName string
@ -63,18 +61,15 @@ func (u *Root) SitePage(ctx context.Context, w http.ResponseWriter, r *http.Requ
case "/legal/terms": case "/legal/terms":
tmpName = "legal-terms.gohtml" tmpName = "legal-terms.gohtml"
default: default:
http.Redirect(w, r, "/", http.StatusFound) return web.Redirect(ctx, w, r, "/", http.StatusFound)
return nil
} }
return u.Renderer.Render(ctx, w, r, tmplLayoutSite, tmpName, web.MIMETextHTMLCharsetUTF8, http.StatusOK, nil) return u.Renderer.Render(ctx, w, r, tmplLayoutSite, tmpName, web.MIMETextHTMLCharsetUTF8, http.StatusOK, nil)
} }
// IndexHtml redirects /index.html to the website root page. // IndexHtml redirects /index.html to the website root page.
func (u *Root) IndexHtml(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { func (u *Root) IndexHtml(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
http.Redirect(w, r, "/", http.StatusMovedPermanently) return web.Redirect(ctx, w, r, "/", http.StatusMovedPermanently)
return nil
} }
// RobotHandler returns a robots.txt response. // RobotHandler returns a robots.txt response.

View File

@ -86,14 +86,9 @@ func (h *Signup) Step1(ctx context.Context, w http.ResponseWriter, r *http.Reque
webcontext.SessionFlashSuccess(ctx, webcontext.SessionFlashSuccess(ctx,
"Thank you for Joining", "Thank you for Joining",
"You workflow will be a breeze starting today.") "You workflow will be a breeze starting today.")
err = webcontext.ContextSession(ctx).Save(r, w)
if err != nil {
return false, err
}
// Redirect the user to the dashboard. // Redirect the user to the dashboard.
http.Redirect(w, r, "/", http.StatusFound) return true, web.Redirect(ctx, w, r, "/", http.StatusFound)
return true, nil
} }
return false, nil return false, nil
@ -103,6 +98,10 @@ func (h *Signup) Step1(ctx context.Context, w http.ResponseWriter, r *http.Reque
if err != nil { if err != nil {
return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8) return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
} else if end { } else if end {
err = webcontext.ContextSession(ctx).Save(r, w)
if err != nil {
return err
}
return nil return nil
} }

View File

@ -112,8 +112,7 @@ func (h *User) Login(ctx context.Context, w http.ResponseWriter, r *http.Request
} }
// Redirect the user to the dashboard. // Redirect the user to the dashboard.
http.Redirect(w, r, redirectUri, http.StatusFound) return true, web.Redirect(ctx, w, r, redirectUri, http.StatusFound)
return true, nil
} }
return false, nil return false, nil
@ -148,9 +147,7 @@ func (h *User) Logout(ctx context.Context, w http.ResponseWriter, r *http.Reques
} }
// Redirect the user to the root page. // Redirect the user to the root page.
http.Redirect(w, r, "/", http.StatusFound) return web.Redirect(ctx, w, r, "/", http.StatusFound)
return nil
} }
// ResetPassword allows a user to perform forgot password. // ResetPassword allows a user to perform forgot password.
@ -281,8 +278,7 @@ func (h *User) ResetConfirm(ctx context.Context, w http.ResponseWriter, r *http.
} }
// Redirect the user to the dashboard. // Redirect the user to the dashboard.
http.Redirect(w, r, "/", http.StatusFound) return true, web.Redirect(ctx, w, r, "/", http.StatusFound)
return true, nil
} }
_, err = user.ParseResetHash(ctx, h.SecretKey, resetHash, ctxValues.Now) _, err = user.ParseResetHash(ctx, h.SecretKey, resetHash, ctxValues.Now)
@ -432,13 +428,8 @@ func (h *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Reques
webcontext.SessionFlashSuccess(ctx, webcontext.SessionFlashSuccess(ctx,
"Profile Updated", "Profile Updated",
"User profile successfully updated.") "User profile successfully updated.")
err = webcontext.ContextSession(ctx).Save(r, w)
if err != nil {
return false, err
}
http.Redirect(w, r, "/user", http.StatusFound) return true, web.Redirect(ctx, w, r, "/user", http.StatusFound)
return true, nil
} }
return false, nil return false, nil
@ -584,16 +575,8 @@ func (h *User) VirtualLogin(ctx context.Context, w http.ResponseWriter, r *http.
fmt.Sprintf("You are now virtually logged into user %s.", fmt.Sprintf("You are now virtually logged into user %s.",
usr.Response(ctx).Name)) usr.Response(ctx).Name))
// Write the session to the client.
err = webcontext.ContextSession(ctx).Save(r, w)
if err != nil {
return false, err
}
// Redirect the user to the dashboard with the new credentials. // Redirect the user to the dashboard with the new credentials.
http.Redirect(w, r, "/", http.StatusFound) return true, web.Redirect(ctx, w, r, "/", http.StatusFound)
return true, nil
} }
return false, nil return false, nil
@ -724,9 +707,7 @@ func (h *User) VirtualLogout(ctx context.Context, w http.ResponseWriter, r *http
} }
// Redirect the user to the dashboard with the new credentials. // Redirect the user to the dashboard with the new credentials.
http.Redirect(w, r, "/", http.StatusFound) return web.Redirect(ctx, w, r, "/", http.StatusFound)
return nil
} }
// VirtualLogin handles switching the scope of the context to another user. // VirtualLogin handles switching the scope of the context to another user.
@ -800,16 +781,8 @@ func (h *User) SwitchAccount(ctx context.Context, w http.ResponseWriter, r *http
fmt.Sprintf("You are now logged into account %s.", fmt.Sprintf("You are now logged into account %s.",
acc.Response(ctx).Name)) acc.Response(ctx).Name))
// Write the session to the client.
err = webcontext.ContextSession(ctx).Save(r, w)
if err != nil {
return false, err
}
// Redirect the user to the dashboard with the new credentials. // Redirect the user to the dashboard with the new credentials.
http.Redirect(w, r, "/", http.StatusFound) return true, web.Redirect(ctx, w, r, "/", http.StatusFound)
return true, nil
} }
return false, nil return false, nil

View File

@ -259,13 +259,8 @@ func (h *Users) Create(ctx context.Context, w http.ResponseWriter, r *http.Reque
webcontext.SessionFlashSuccess(ctx, webcontext.SessionFlashSuccess(ctx,
"User Created", "User Created",
"User successfully created.") "User successfully created.")
err = webcontext.ContextSession(ctx).Save(r, w)
if err != nil {
return false, err
}
http.Redirect(w, r, urlUsersView(usr.ID), http.StatusFound) return true, web.Redirect(ctx, w, r, urlUsersView(usr.ID), http.StatusFound)
return true, nil
} }
return false, nil return false, nil
@ -333,13 +328,8 @@ func (h *Users) View(ctx context.Context, w http.ResponseWriter, r *http.Request
webcontext.SessionFlashSuccess(ctx, webcontext.SessionFlashSuccess(ctx,
"User Archive", "User Archive",
"User successfully archive.") "User successfully archive.")
err = webcontext.ContextSession(ctx).Save(r, w)
if err != nil {
return false, err
}
http.Redirect(w, r, urlUsersIndex(), http.StatusFound) return true, web.Redirect(ctx, w, r, urlUsersIndex(), http.StatusFound)
return true, nil
} }
} }
@ -483,13 +473,8 @@ func (h *Users) Update(ctx context.Context, w http.ResponseWriter, r *http.Reque
webcontext.SessionFlashSuccess(ctx, webcontext.SessionFlashSuccess(ctx,
"User Updated", "User Updated",
"User successfully updated.") "User successfully updated.")
err = webcontext.ContextSession(ctx).Save(r, w)
if err != nil {
return false, err
}
http.Redirect(w, r, urlUsersView(req.ID), http.StatusFound) return true, web.Redirect(ctx, w, r, urlUsersView(req.ID), http.StatusFound)
return true, nil
} }
return false, nil return false, nil
@ -607,13 +592,7 @@ func (h *Users) Invite(ctx context.Context, w http.ResponseWriter, r *http.Reque
"No users were invited.") "No users were invited.")
} }
err = webcontext.ContextSession(ctx).Save(r, w) return true, web.Redirect(ctx, w, r, urlUsersIndex(), http.StatusFound)
if err != nil {
return false, err
}
http.Redirect(w, r, urlUsersIndex(), http.StatusFound)
return true, nil
} }
return false, nil return false, nil
@ -652,7 +631,7 @@ func (h *Users) InviteAccept(ctx context.Context, w http.ResponseWriter, r *http
} }
// //
req := new(invite.AcceptInviteRequest) req := new(invite.AcceptInviteUserRequest)
data := make(map[string]interface{}) data := make(map[string]interface{})
f := func() (bool, error) { f := func() (bool, error) {
@ -670,30 +649,33 @@ func (h *Users) InviteAccept(ctx context.Context, w http.ResponseWriter, r *http
// Append the query param value to the request. // Append the query param value to the request.
req.InviteHash = inviteHash req.InviteHash = inviteHash
hash, err := invite.AcceptInvite(ctx, h.MasterDB, *req, h.SecretKey, ctxValues.Now) hash, err := invite.AcceptInviteUser(ctx, h.MasterDB, *req, h.SecretKey, ctxValues.Now)
if err != nil { if err != nil {
switch errors.Cause(err) { switch errors.Cause(err) {
case invite.ErrInviteExpired: case invite.ErrInviteExpired:
webcontext.SessionFlashError(ctx, webcontext.SessionFlashError(ctx,
"Invite Expired", "Invite Expired",
"The invite has expired.") "The invite has expired.")
return false, nil return false, nil
case invite.ErrUserAccountActive: case invite.ErrUserAccountActive:
webcontext.SessionFlashError(ctx, webcontext.SessionFlashError(ctx,
"User already Active", "User already Active",
"The user already is already active for the account. Try to login or use forgot password.") "The user is already is already active for the account. Try to login or use forgot password.")
http.Redirect(w, r, "/user/login", http.StatusFound)
return true, nil return true, web.Redirect(ctx, w, r, "/user/login", http.StatusFound)
case invite.ErrInviteUserPasswordSet:
case invite.ErrNoPendingInvite:
webcontext.SessionFlashError(ctx, webcontext.SessionFlashError(ctx,
"Invite already Accepted", "Invite Accepted",
"The invite has already been accepted. Try to login or use forgot password.") "The invite has already been accepted. Try to login or use forgot password.")
http.Redirect(w, r, "/user/login", http.StatusFound)
return true, nil return true, web.Redirect(ctx, w, r, "/user/login", http.StatusFound)
case user_account.ErrNotFound: case user_account.ErrNotFound:
return false, err return false, err
case invite.ErrNoPendingInvite:
return false, err
default: default:
if verr, ok := weberror.NewValidationError(ctx, err); ok { if verr, ok := weberror.NewValidationError(ctx, err); ok {
data["validationErrors"] = verr.(*weberror.Error) data["validationErrors"] = verr.(*weberror.Error)
@ -732,36 +714,57 @@ func (h *Users) InviteAccept(ctx context.Context, w http.ResponseWriter, r *http
} }
// Redirect the user to the dashboard. // Redirect the user to the dashboard.
http.Redirect(w, r, "/", http.StatusFound) return true, web.Redirect(ctx, w, r, "/", http.StatusFound)
return true, nil
} }
hash, err := invite.ParseInviteHash(ctx, h.SecretKey, inviteHash, ctxValues.Now) usrAcc, err := invite.AcceptInvite(ctx, h.MasterDB, invite.AcceptInviteRequest{
InviteHash: inviteHash,
}, h.SecretKey, ctxValues.Now)
if err != nil { if err != nil {
switch errors.Cause(err) { switch errors.Cause(err) {
case invite.ErrInviteExpired: case invite.ErrInviteExpired:
webcontext.SessionFlashError(ctx, webcontext.SessionFlashError(ctx,
"Invite Expired", "Invite Expired",
"The invite has expired.") "The invite has expired.")
return false, nil
case invite.ErrInviteUserPasswordSet: return true, web.Redirect(ctx, w, r, "/user/login", http.StatusFound)
case invite.ErrUserAccountActive:
webcontext.SessionFlashError(ctx, webcontext.SessionFlashError(ctx,
"Invite already Accepted", "User already Active",
"The user is already is already active for the account. Try to login or use forgot password.")
return true, web.Redirect(ctx, w, r, "/user/login", http.StatusFound)
case invite.ErrNoPendingInvite:
webcontext.SessionFlashError(ctx,
"Invite Accepted",
"The invite has already been accepted. Try to login or use forgot password.") "The invite has already been accepted. Try to login or use forgot password.")
http.Redirect(w, r, "/user/login", http.StatusFound)
return true, nil return true, web.Redirect(ctx, w, r, "/user/login", http.StatusFound)
case user_account.ErrNotFound:
return false, err
default: default:
if verr, ok := weberror.NewValidationError(ctx, err); ok { if verr, ok := weberror.NewValidationError(ctx, err); ok {
data["validationErrors"] = verr.(*weberror.Error) data["validationErrors"] = verr.(*weberror.Error)
return false, nil return false, nil
} else { } else {
return false, err return false, err
} }
} }
} else if usrAcc.Status == user_account.UserAccountStatus_Active {
webcontext.SessionFlashError(ctx,
"Invite Accepted",
"The invite has been accepted. Login to continue.")
return true, web.Redirect(ctx, w, r, "/user/login", http.StatusFound)
} }
// Read user by ID with no claims. // Read user by ID with no claims.
usr, err := user.ReadByID(ctx, auth.Claims{}, h.MasterDB, hash.UserID) usr, err := user.ReadByID(ctx, auth.Claims{}, h.MasterDB, usrAcc.UserID)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -791,7 +794,7 @@ func (h *Users) InviteAccept(ctx context.Context, w http.ResponseWriter, r *http
data["form"] = req data["form"] = req
if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(invite.AcceptInviteRequest{})); ok { if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(invite.AcceptInviteUserRequest{})); ok {
data["validationDefaults"] = verr.(*weberror.Error) data["validationDefaults"] = verr.(*weberror.Error)
} }

View File

@ -55,13 +55,13 @@
<small>Role</small><br/> <small>Role</small><br/>
{{ if .userAccount }} {{ if .userAccount }}
<b> <b>
{{ range $r := .userAccount.Roles }} {{ range $r := .userAccount.Roles.Options }}{{ if $r.Selected }}
{{ if eq $r "admin" }} {{ if eq $r.Value "admin" }}
<span class="text-pink"><i class="far fa-kiss-wink-heart mr-1"></i>{{ $r }}</span> <span class="text-pink"><i class="far fa-kiss-wink-heart mr-1"></i>{{ $r.Title }}</span>
{{else}} {{else}}
<span class="text-purple"><i class="far fa-user-circle mr-1"></i>{{ $r }}</span> <span class="text-purple"><i class="far fa-user-circle mr-1"></i>{{ $r.Title }}</span>
{{end}} {{end}}
{{ end }} {{ end }}{{ end }}
</b> </b>
{{ end }} {{ end }}
</p> </p>

View File

@ -21,7 +21,7 @@
<label for="selectRoles">Roles</label> <label for="selectRoles">Roles</label>
<select class="form-control {{ ValidationFieldClass $.validationErrors "Roles" }}" <select class="form-control {{ ValidationFieldClass $.validationErrors "Roles" }}"
id="selectRoles" name="Roles" multiple="multiple"> id="selectRoles" name="Roles" multiple="multiple">
{{ range $t := .roles }} {{ range $t := .roles.Options }}
<option value="{{ $t.Value }}" {{ if $t.Selected }}selected="selected"{{ end }}>{{ $t.Title }}</option> <option value="{{ $t.Value }}" {{ if $t.Selected }}selected="selected"{{ end }}>{{ $t.Title }}</option>
{{ end }} {{ end }}
</select> </select>

View File

@ -67,7 +67,7 @@
<small>Role</small><br/> <small>Role</small><br/>
{{ if .userAccount }} {{ if .userAccount }}
<b> <b>
{{ range $r := .userAccount.Roles }}{{ if $r.Selected }} {{ range $r := .userAccount.Roles.Options }}{{ if $r.Selected }}
{{ if eq $r.Value "admin" }} {{ if eq $r.Value "admin" }}
<span class="text-pink"><i class="far fa-kiss-wink-heart mr-1"></i>{{ $r.Title }}</span> <span class="text-pink"><i class="far fa-kiss-wink-heart mr-1"></i>{{ $r.Title }}</span>
{{else}} {{else}}

View File

@ -117,12 +117,20 @@ func NewEnumResponse(ctx context.Context, value interface{}, options ...interfac
} }
// EnumResponse is a response friendly format for displaying a multi select enum. // EnumResponse is a response friendly format for displaying a multi select enum.
type EnumMultiResponse []EnumOption type EnumMultiResponse struct {
Values []string `json:"values" example:"active_etc"`
Options []EnumOption `json:"options,omitempty"`
}
// NewEnumMultiResponse returns a display friendly format for a multi enum field. // NewEnumMultiResponse returns a display friendly format for a multi enum field.
func NewEnumMultiResponse(ctx context.Context, selected []interface{}, options ...interface{}) EnumMultiResponse { func NewEnumMultiResponse(ctx context.Context, selected []interface{}, options ...interface{}) EnumMultiResponse {
var er EnumMultiResponse var er EnumMultiResponse
for _, s := range selected {
selStr := fmt.Sprintf("%s", s)
er.Values = append(er.Values, selStr)
}
for _, opt := range options { for _, opt := range options {
optStr := fmt.Sprintf("%s", opt) optStr := fmt.Sprintf("%s", opt)
opt := EnumOption{ opt := EnumOption{
@ -137,7 +145,7 @@ func NewEnumMultiResponse(ctx context.Context, selected []interface{}, options .
} }
} }
er = append(er, opt) er.Options = append(er.Options, opt)
} }
return er return er

View File

@ -256,3 +256,16 @@ func StaticHandler(ctx context.Context, w http.ResponseWriter, r *http.Request,
return nil return nil
} }
// Redirect ensures the session is flushed to the browser before the redirect is issued.
func Redirect(ctx context.Context, w http.ResponseWriter, r *http.Request, url string, code int) error {
if sess := webcontext.ContextSession(ctx); sess != nil {
if err := sess.Save(r, w); err != nil {
return err
}
}
http.Redirect(w, r, url, code)
return nil
}

View File

@ -26,9 +26,6 @@ var (
// ErrUserAccountActive occurs when the user already has an active user_account entry. // ErrUserAccountActive occurs when the user already has an active user_account entry.
ErrUserAccountActive = errors.New("User already active.") ErrUserAccountActive = errors.New("User already active.")
// ErrInviteUserPasswordSet occurs when the the reset hash exceeds the expiration.
ErrInviteUserPasswordSet = errors.New("User password set")
) )
// SendUserInvites sends emails to the users inviting them to join an account. // SendUserInvites sends emails to the users inviting them to join an account.
@ -181,7 +178,7 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r
} }
// AcceptInvite updates the user using the provided invite hash. // AcceptInvite updates the user using the provided invite hash.
func AcceptInvite(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteRequest, secretKey string, now time.Time) (*InviteHash, error) { func AcceptInvite(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteRequest, secretKey string, now time.Time) (*user_account.UserAccount, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.invite.AcceptInvite") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.invite.AcceptInvite")
defer span.Finish() defer span.Finish()
@ -193,7 +190,7 @@ func AcceptInvite(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteRequest,
return nil, err return nil, err
} }
hash, err := ParseInviteHash(ctx, secretKey, req.InviteHash, now) hash, err := ParseInviteHash(ctx, req.InviteHash, secretKey, now)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -216,24 +213,86 @@ func AcceptInvite(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteRequest,
AccountID: hash.AccountID, AccountID: hash.AccountID,
}) })
if err != nil { if err != nil {
return nil, nil return nil, err
} }
// Ensure the entry has the status of invited. // Ensure the entry has the status of invited.
if usrAcc.Status != user_account.UserAccountStatus_Invited { if usrAcc.Status != user_account.UserAccountStatus_Invited {
// If the entry is already active // If the entry is already active
if usrAcc.Status == user_account.UserAccountStatus_Active { if usrAcc.Status == user_account.UserAccountStatus_Active {
return hash, errors.WithStack(ErrUserAccountActive) return usrAcc, errors.WithStack(ErrUserAccountActive)
}
return usrAcc, errors.WithStack(ErrNoPendingInvite)
}
// If the user already has a password set, then just update the user_account entry to status of active.
// The user will need to login and should not be auto-authenticated.
if len(u.PasswordHash) > 0 {
usrAcc.Status = user_account.UserAccountStatus_Active
err = user_account.Update(ctx, auth.Claims{}, dbConn, user_account.UserAccountUpdateRequest{
UserID: usrAcc.UserID,
AccountID: usrAcc.AccountID,
Status: &usrAcc.Status,
}, now)
if err != nil {
return nil, err
}
}
return usrAcc, nil
}
// AcceptInviteUser updates the user using the provided invite hash.
func AcceptInviteUser(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteUserRequest, secretKey string, now time.Time) (*user_account.UserAccount, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.invite.AcceptInviteUser")
defer span.Finish()
v := webcontext.Validator()
// Validate the request.
err := v.StructCtx(ctx, req)
if err != nil {
return nil, err
}
hash, err := ParseInviteHash(ctx, req.InviteHash, secretKey, now)
if err != nil {
return nil, err
}
u, err := user.Read(ctx, auth.Claims{}, dbConn,
user.UserReadRequest{ID: hash.UserID, IncludeArchived: true})
if err != nil {
return nil, err
}
if u.ArchivedAt != nil && !u.ArchivedAt.Time.IsZero() {
err = user.Restore(ctx, auth.Claims{}, dbConn, user.UserRestoreRequest{ID: hash.UserID}, now)
if err != nil {
return nil, err
}
}
usrAcc, err := user_account.Read(ctx, auth.Claims{}, dbConn, user_account.UserAccountReadRequest{
UserID: hash.UserID,
AccountID: hash.AccountID,
})
if err != nil {
return nil, err
}
// Ensure the entry has the status of invited.
if usrAcc.Status != user_account.UserAccountStatus_Invited {
// If the entry is already active
if usrAcc.Status == user_account.UserAccountStatus_Active {
return usrAcc, errors.WithStack(ErrUserAccountActive)
} }
return nil, errors.WithStack(ErrNoPendingInvite) return nil, errors.WithStack(ErrNoPendingInvite)
} }
if len(u.PasswordHash) > 0 { // These three calls, user.Update, user.UpdatePassword, and user_account.Update
// Do not update the password for a user that already has a password set. // should probably be in a transaction!
return nil, errors.WithStack(ErrInviteUserPasswordSet)
}
// These two calls, user.Update and user.UpdatePassword should probably be in a transaction!
err = user.Update(ctx, auth.Claims{}, dbConn, user.UserUpdateRequest{ err = user.Update(ctx, auth.Claims{}, dbConn, user.UserUpdateRequest{
ID: hash.UserID, ID: hash.UserID,
Email: &req.Email, Email: &req.Email,
@ -254,15 +313,15 @@ func AcceptInvite(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteRequest,
return nil, err return nil, err
} }
activeStatus := user_account.UserAccountStatus_Active usrAcc.Status = user_account.UserAccountStatus_Active
err = user_account.Update(ctx, auth.Claims{}, dbConn, user_account.UserAccountUpdateRequest{ err = user_account.Update(ctx, auth.Claims{}, dbConn, user_account.UserAccountUpdateRequest{
UserID: usrAcc.UserID, UserID: usrAcc.UserID,
AccountID: usrAcc.AccountID, AccountID: usrAcc.AccountID,
Status: &activeStatus, Status: &usrAcc.Status,
}, now) }, now)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return hash, nil return usrAcc, nil
} }

View File

@ -148,13 +148,13 @@ func TestSendUserInvites(t *testing.T) {
// Ensure validation is working by trying ResetConfirm with an empty request. // Ensure validation is working by trying ResetConfirm with an empty request.
{ {
expectedErr := errors.New("Key: 'AcceptInviteRequest.invite_hash' Error:Field validation for 'invite_hash' failed on the 'required' tag\n" + expectedErr := errors.New("Key: 'AcceptInviteUserRequest.invite_hash' Error:Field validation for 'invite_hash' failed on the 'required' tag\n" +
"Key: 'AcceptInviteRequest.email' Error:Field validation for 'email' failed on the 'required' tag\n" + "Key: 'AcceptInviteUserRequest.email' Error:Field validation for 'email' failed on the 'required' tag\n" +
"Key: 'AcceptInviteRequest.first_name' Error:Field validation for 'first_name' failed on the 'required' tag\n" + "Key: 'AcceptInviteUserRequest.first_name' Error:Field validation for 'first_name' failed on the 'required' tag\n" +
"Key: 'AcceptInviteRequest.last_name' Error:Field validation for 'last_name' failed on the 'required' tag\n" + "Key: 'AcceptInviteUserRequest.last_name' Error:Field validation for 'last_name' failed on the 'required' tag\n" +
"Key: 'AcceptInviteRequest.password' Error:Field validation for 'password' failed on the 'required' tag\n" + "Key: 'AcceptInviteUserRequest.password' Error:Field validation for 'password' failed on the 'required' tag\n" +
"Key: 'AcceptInviteRequest.password_confirm' Error:Field validation for 'password_confirm' failed on the 'required' tag") "Key: 'AcceptInviteUserRequest.password_confirm' Error:Field validation for 'password_confirm' failed on the 'required' tag")
_, err = AcceptInvite(ctx, test.MasterDB, AcceptInviteRequest{}, secretKey, now) _, err = AcceptInviteUser(ctx, test.MasterDB, AcceptInviteUserRequest{}, secretKey, now)
if err == nil { if err == nil {
t.Logf("\t\tWant: %+v", expectedErr) t.Logf("\t\tWant: %+v", expectedErr)
t.Fatalf("\t%s\tResetConfirm failed.", tests.Failed) t.Fatalf("\t%s\tResetConfirm failed.", tests.Failed)
@ -174,7 +174,7 @@ func TestSendUserInvites(t *testing.T) {
// Ensure the TTL is enforced. // Ensure the TTL is enforced.
{ {
newPass := uuid.NewRandom().String() newPass := uuid.NewRandom().String()
_, err = AcceptInvite(ctx, test.MasterDB, AcceptInviteRequest{ _, err = AcceptInviteUser(ctx, test.MasterDB, AcceptInviteUserRequest{
InviteHash: inviteHashes[0], InviteHash: inviteHashes[0],
Email: inviteEmails[0], Email: inviteEmails[0],
FirstName: "Foo", FirstName: "Foo",
@ -194,7 +194,7 @@ func TestSendUserInvites(t *testing.T) {
for idx, inviteHash := range inviteHashes { for idx, inviteHash := range inviteHashes {
newPass := uuid.NewRandom().String() newPass := uuid.NewRandom().String()
hash, err := AcceptInvite(ctx, test.MasterDB, AcceptInviteRequest{ hash, err := AcceptInviteUser(ctx, test.MasterDB, AcceptInviteUserRequest{
InviteHash: inviteHash, InviteHash: inviteHash,
Email: inviteEmails[idx], Email: inviteEmails[idx],
FirstName: "Foo", FirstName: "Foo",
@ -227,7 +227,7 @@ func TestSendUserInvites(t *testing.T) {
// Ensure the reset hash does not work after its used. // Ensure the reset hash does not work after its used.
{ {
newPass := uuid.NewRandom().String() newPass := uuid.NewRandom().String()
_, err = AcceptInvite(ctx, test.MasterDB, AcceptInviteRequest{ _, err = AcceptInviteUser(ctx, test.MasterDB, AcceptInviteUserRequest{
InviteHash: inviteHashes[0], InviteHash: inviteHashes[0],
Email: inviteEmails[0], Email: inviteEmails[0],
FirstName: "Foo", FirstName: "Foo",
@ -237,7 +237,7 @@ func TestSendUserInvites(t *testing.T) {
}, secretKey, now) }, secretKey, now)
if errors.Cause(err) != ErrUserAccountActive { if errors.Cause(err) != ErrUserAccountActive {
t.Logf("\t\tGot : %+v", errors.Cause(err)) t.Logf("\t\tGot : %+v", errors.Cause(err))
t.Logf("\t\tWant: %+v", ErrInviteUserPasswordSet) t.Logf("\t\tWant: %+v", ErrUserAccountActive)
t.Fatalf("\t%s\tInviteAccept verify reuse failed.", tests.Failed) t.Fatalf("\t%s\tInviteAccept verify reuse failed.", tests.Failed)
} }
t.Logf("\t%s\tInviteAccept verify reuse disabled ok.", tests.Success) t.Logf("\t%s\tInviteAccept verify reuse disabled ok.", tests.Success)

View File

@ -32,6 +32,11 @@ type InviteHash struct {
// AcceptInviteRequest defines the fields need to complete an invite request. // AcceptInviteRequest defines the fields need to complete an invite request.
type AcceptInviteRequest struct { type AcceptInviteRequest struct {
InviteHash string `json:"invite_hash" validate:"required" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"`
}
// AcceptInviteUserRequest defines the fields need to complete an invite request.
type AcceptInviteUserRequest struct {
InviteHash string `json:"invite_hash" validate:"required" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"` InviteHash string `json:"invite_hash" validate:"required" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"`
Email string `json:"email" validate:"required,email" example:"gabi@geeksinthewoods.com"` Email string `json:"email" validate:"required,email" example:"gabi@geeksinthewoods.com"`
FirstName string `json:"first_name" validate:"required" example:"Gabi"` FirstName string `json:"first_name" validate:"required" example:"Gabi"`
@ -67,12 +72,12 @@ func NewInviteHash(ctx context.Context, secretKey, userID, accountID, requestIp
} }
// ParseInviteHash extracts the details encrypted in the hash string. // ParseInviteHash extracts the details encrypted in the hash string.
func ParseInviteHash(ctx context.Context, secretKey string, str string, now time.Time) (*InviteHash, error) { func ParseInviteHash(ctx context.Context, encrypted, secretKey string, now time.Time) (*InviteHash, error) {
crypto, err := symcrypto.New(secretKey) crypto, err := symcrypto.New(secretKey)
if err != nil { if err != nil {
return nil, errors.WithStack(err) return nil, errors.WithStack(err)
} }
hashStr, err := crypto.Decrypt(str) hashStr, err := crypto.Decrypt(encrypted)
if err != nil { if err != nil {
return nil, errors.WithStack(err) return nil, errors.WithStack(err)
} }