From 900cfcf7135d8eb93f980be94bf686ef942b5704 Mon Sep 17 00:00:00 2001
From: Lee Brown
Date: Mon, 5 Aug 2019 18:47:42 -0800
Subject: [PATCH] fixed redirects to handle saving the session when set
---
cmd/web-app/handlers/account.go | 7 +-
cmd/web-app/handlers/projects.go | 21 +---
cmd/web-app/handlers/root.go | 11 +--
cmd/web-app/handlers/signup.go | 11 +--
cmd/web-app/handlers/user.go | 41 ++------
cmd/web-app/handlers/users.go | 95 ++++++++++---------
.../templates/content/user-view.gohtml | 10 +-
.../templates/content/users-invite.gohtml | 2 +-
.../templates/content/users-view.gohtml | 2 +-
internal/platform/web/models.go | 12 ++-
internal/platform/web/response.go | 13 +++
internal/user_account/invite/invite.go | 91 ++++++++++++++----
internal/user_account/invite/invite_test.go | 22 ++---
internal/user_account/invite/models.go | 9 +-
14 files changed, 191 insertions(+), 156 deletions(-)
diff --git a/cmd/web-app/handlers/account.go b/cmd/web-app/handlers/account.go
index f43a11e..03a32d7 100644
--- a/cmd/web-app/handlers/account.go
+++ b/cmd/web-app/handlers/account.go
@@ -209,13 +209,8 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req
webcontext.SessionFlashSuccess(ctx,
"Account Updated",
"Account profile successfully updated.")
- err = webcontext.ContextSession(ctx).Save(r, w)
- if err != nil {
- return false, err
- }
- http.Redirect(w, r, "/account", http.StatusFound)
- return true, nil
+ return true, web.Redirect(ctx, w, r, "/account", http.StatusFound)
}
acc, err := account.ReadByID(ctx, claims, h.MasterDB, claims.Audience)
diff --git a/cmd/web-app/handlers/projects.go b/cmd/web-app/handlers/projects.go
index 8aac2ba..1025e60 100644
--- a/cmd/web-app/handlers/projects.go
+++ b/cmd/web-app/handlers/projects.go
@@ -203,13 +203,8 @@ func (h *Projects) Create(ctx context.Context, w http.ResponseWriter, r *http.Re
webcontext.SessionFlashSuccess(ctx,
"Project Created",
"Project successfully created.")
- err = webcontext.ContextSession(ctx).Save(r, w)
- if err != nil {
- return false, err
- }
- http.Redirect(w, r, urlProjectsView(usr.ID), http.StatusFound)
- return true, nil
+ return true, web.Redirect(ctx, w, r, urlProjectsView(usr.ID), http.StatusFound)
}
return false, nil
@@ -266,13 +261,8 @@ func (h *Projects) View(ctx context.Context, w http.ResponseWriter, r *http.Requ
webcontext.SessionFlashSuccess(ctx,
"Project Archive",
"Project successfully archive.")
- err = webcontext.ContextSession(ctx).Save(r, w)
- if err != nil {
- return false, err
- }
- http.Redirect(w, r, urlProjectsIndex(), http.StatusFound)
- return true, nil
+ return true, web.Redirect(ctx, w, r, urlProjectsIndex(), http.StatusFound)
}
}
@@ -347,13 +337,8 @@ func (h *Projects) Update(ctx context.Context, w http.ResponseWriter, r *http.Re
webcontext.SessionFlashSuccess(ctx,
"Project Updated",
"Project successfully updated.")
- err = webcontext.ContextSession(ctx).Save(r, w)
- if err != nil {
- return false, err
- }
- http.Redirect(w, r, urlProjectsView(req.ID), http.StatusFound)
- return true, nil
+ return true, web.Redirect(ctx, w, r, urlProjectsView(req.ID), http.StatusFound)
}
return false, nil
diff --git a/cmd/web-app/handlers/root.go b/cmd/web-app/handlers/root.go
index 93c1726..a327e42 100644
--- a/cmd/web-app/handlers/root.go
+++ b/cmd/web-app/handlers/root.go
@@ -40,12 +40,10 @@ func (h *Root) indexDashboard(ctx context.Context, w http.ResponseWriter, r *htt
// indexDefault loads the root index page when a user has no authentication.
func (u *Root) indexDefault(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
-
return u.Renderer.Render(ctx, w, r, tmplLayoutSite, "site-index.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, nil)
-
}
-// indexDefault loads the root index page when a user has no authentication.
+// SitePage loads the page with the layout for site instead of the app base.
func (u *Root) SitePage(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
var tmpName string
@@ -63,18 +61,15 @@ func (u *Root) SitePage(ctx context.Context, w http.ResponseWriter, r *http.Requ
case "/legal/terms":
tmpName = "legal-terms.gohtml"
default:
- http.Redirect(w, r, "/", http.StatusFound)
- return nil
+ return web.Redirect(ctx, w, r, "/", http.StatusFound)
}
return u.Renderer.Render(ctx, w, r, tmplLayoutSite, tmpName, web.MIMETextHTMLCharsetUTF8, http.StatusOK, nil)
-
}
// IndexHtml redirects /index.html to the website root page.
func (u *Root) IndexHtml(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
- http.Redirect(w, r, "/", http.StatusMovedPermanently)
- return nil
+ return web.Redirect(ctx, w, r, "/", http.StatusMovedPermanently)
}
// RobotHandler returns a robots.txt response.
diff --git a/cmd/web-app/handlers/signup.go b/cmd/web-app/handlers/signup.go
index a25bd90..4f3ece4 100644
--- a/cmd/web-app/handlers/signup.go
+++ b/cmd/web-app/handlers/signup.go
@@ -86,14 +86,9 @@ func (h *Signup) Step1(ctx context.Context, w http.ResponseWriter, r *http.Reque
webcontext.SessionFlashSuccess(ctx,
"Thank you for Joining",
"You workflow will be a breeze starting today.")
- err = webcontext.ContextSession(ctx).Save(r, w)
- if err != nil {
- return false, err
- }
// Redirect the user to the dashboard.
- http.Redirect(w, r, "/", http.StatusFound)
- return true, nil
+ return true, web.Redirect(ctx, w, r, "/", http.StatusFound)
}
return false, nil
@@ -103,6 +98,10 @@ func (h *Signup) Step1(ctx context.Context, w http.ResponseWriter, r *http.Reque
if err != nil {
return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
} else if end {
+ err = webcontext.ContextSession(ctx).Save(r, w)
+ if err != nil {
+ return err
+ }
return nil
}
diff --git a/cmd/web-app/handlers/user.go b/cmd/web-app/handlers/user.go
index 119135a..9817ac1 100644
--- a/cmd/web-app/handlers/user.go
+++ b/cmd/web-app/handlers/user.go
@@ -112,8 +112,7 @@ func (h *User) Login(ctx context.Context, w http.ResponseWriter, r *http.Request
}
// Redirect the user to the dashboard.
- http.Redirect(w, r, redirectUri, http.StatusFound)
- return true, nil
+ return true, web.Redirect(ctx, w, r, redirectUri, http.StatusFound)
}
return false, nil
@@ -148,9 +147,7 @@ func (h *User) Logout(ctx context.Context, w http.ResponseWriter, r *http.Reques
}
// Redirect the user to the root page.
- http.Redirect(w, r, "/", http.StatusFound)
-
- return nil
+ return web.Redirect(ctx, w, r, "/", http.StatusFound)
}
// ResetPassword allows a user to perform forgot password.
@@ -281,8 +278,7 @@ func (h *User) ResetConfirm(ctx context.Context, w http.ResponseWriter, r *http.
}
// Redirect the user to the dashboard.
- http.Redirect(w, r, "/", http.StatusFound)
- return true, nil
+ return true, web.Redirect(ctx, w, r, "/", http.StatusFound)
}
_, err = user.ParseResetHash(ctx, h.SecretKey, resetHash, ctxValues.Now)
@@ -432,13 +428,8 @@ func (h *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Reques
webcontext.SessionFlashSuccess(ctx,
"Profile Updated",
"User profile successfully updated.")
- err = webcontext.ContextSession(ctx).Save(r, w)
- if err != nil {
- return false, err
- }
- http.Redirect(w, r, "/user", http.StatusFound)
- return true, nil
+ return true, web.Redirect(ctx, w, r, "/user", http.StatusFound)
}
return false, nil
@@ -584,16 +575,8 @@ func (h *User) VirtualLogin(ctx context.Context, w http.ResponseWriter, r *http.
fmt.Sprintf("You are now virtually logged into user %s.",
usr.Response(ctx).Name))
- // Write the session to the client.
- err = webcontext.ContextSession(ctx).Save(r, w)
- if err != nil {
- return false, err
- }
-
// Redirect the user to the dashboard with the new credentials.
- http.Redirect(w, r, "/", http.StatusFound)
-
- return true, nil
+ return true, web.Redirect(ctx, w, r, "/", http.StatusFound)
}
return false, nil
@@ -724,9 +707,7 @@ func (h *User) VirtualLogout(ctx context.Context, w http.ResponseWriter, r *http
}
// Redirect the user to the dashboard with the new credentials.
- http.Redirect(w, r, "/", http.StatusFound)
-
- return nil
+ return web.Redirect(ctx, w, r, "/", http.StatusFound)
}
// VirtualLogin handles switching the scope of the context to another user.
@@ -800,16 +781,8 @@ func (h *User) SwitchAccount(ctx context.Context, w http.ResponseWriter, r *http
fmt.Sprintf("You are now logged into account %s.",
acc.Response(ctx).Name))
- // Write the session to the client.
- err = webcontext.ContextSession(ctx).Save(r, w)
- if err != nil {
- return false, err
- }
-
// Redirect the user to the dashboard with the new credentials.
- http.Redirect(w, r, "/", http.StatusFound)
-
- return true, nil
+ return true, web.Redirect(ctx, w, r, "/", http.StatusFound)
}
return false, nil
diff --git a/cmd/web-app/handlers/users.go b/cmd/web-app/handlers/users.go
index be36709..329311c 100644
--- a/cmd/web-app/handlers/users.go
+++ b/cmd/web-app/handlers/users.go
@@ -259,13 +259,8 @@ func (h *Users) Create(ctx context.Context, w http.ResponseWriter, r *http.Reque
webcontext.SessionFlashSuccess(ctx,
"User Created",
"User successfully created.")
- err = webcontext.ContextSession(ctx).Save(r, w)
- if err != nil {
- return false, err
- }
- http.Redirect(w, r, urlUsersView(usr.ID), http.StatusFound)
- return true, nil
+ return true, web.Redirect(ctx, w, r, urlUsersView(usr.ID), http.StatusFound)
}
return false, nil
@@ -333,13 +328,8 @@ func (h *Users) View(ctx context.Context, w http.ResponseWriter, r *http.Request
webcontext.SessionFlashSuccess(ctx,
"User Archive",
"User successfully archive.")
- err = webcontext.ContextSession(ctx).Save(r, w)
- if err != nil {
- return false, err
- }
- http.Redirect(w, r, urlUsersIndex(), http.StatusFound)
- return true, nil
+ return true, web.Redirect(ctx, w, r, urlUsersIndex(), http.StatusFound)
}
}
@@ -483,13 +473,8 @@ func (h *Users) Update(ctx context.Context, w http.ResponseWriter, r *http.Reque
webcontext.SessionFlashSuccess(ctx,
"User Updated",
"User successfully updated.")
- err = webcontext.ContextSession(ctx).Save(r, w)
- if err != nil {
- return false, err
- }
- http.Redirect(w, r, urlUsersView(req.ID), http.StatusFound)
- return true, nil
+ return true, web.Redirect(ctx, w, r, urlUsersView(req.ID), http.StatusFound)
}
return false, nil
@@ -607,13 +592,7 @@ func (h *Users) Invite(ctx context.Context, w http.ResponseWriter, r *http.Reque
"No users were invited.")
}
- err = webcontext.ContextSession(ctx).Save(r, w)
- if err != nil {
- return false, err
- }
-
- http.Redirect(w, r, urlUsersIndex(), http.StatusFound)
- return true, nil
+ return true, web.Redirect(ctx, w, r, urlUsersIndex(), http.StatusFound)
}
return false, nil
@@ -652,7 +631,7 @@ func (h *Users) InviteAccept(ctx context.Context, w http.ResponseWriter, r *http
}
//
- req := new(invite.AcceptInviteRequest)
+ req := new(invite.AcceptInviteUserRequest)
data := make(map[string]interface{})
f := func() (bool, error) {
@@ -670,30 +649,33 @@ func (h *Users) InviteAccept(ctx context.Context, w http.ResponseWriter, r *http
// Append the query param value to the request.
req.InviteHash = inviteHash
- hash, err := invite.AcceptInvite(ctx, h.MasterDB, *req, h.SecretKey, ctxValues.Now)
+ hash, err := invite.AcceptInviteUser(ctx, h.MasterDB, *req, h.SecretKey, ctxValues.Now)
if err != nil {
switch errors.Cause(err) {
case invite.ErrInviteExpired:
webcontext.SessionFlashError(ctx,
"Invite Expired",
"The invite has expired.")
+
return false, nil
+
case invite.ErrUserAccountActive:
webcontext.SessionFlashError(ctx,
"User already Active",
- "The user already is already active for the account. Try to login or use forgot password.")
- http.Redirect(w, r, "/user/login", http.StatusFound)
- return true, nil
- case invite.ErrInviteUserPasswordSet:
+ "The user is already is already active for the account. Try to login or use forgot password.")
+
+ return true, web.Redirect(ctx, w, r, "/user/login", http.StatusFound)
+
+ case invite.ErrNoPendingInvite:
webcontext.SessionFlashError(ctx,
- "Invite already Accepted",
+ "Invite Accepted",
"The invite has already been accepted. Try to login or use forgot password.")
- http.Redirect(w, r, "/user/login", http.StatusFound)
- return true, nil
+
+ return true, web.Redirect(ctx, w, r, "/user/login", http.StatusFound)
+
case user_account.ErrNotFound:
return false, err
- case invite.ErrNoPendingInvite:
- return false, err
+
default:
if verr, ok := weberror.NewValidationError(ctx, err); ok {
data["validationErrors"] = verr.(*weberror.Error)
@@ -732,36 +714,57 @@ func (h *Users) InviteAccept(ctx context.Context, w http.ResponseWriter, r *http
}
// Redirect the user to the dashboard.
- http.Redirect(w, r, "/", http.StatusFound)
- return true, nil
+ return true, web.Redirect(ctx, w, r, "/", http.StatusFound)
}
- hash, err := invite.ParseInviteHash(ctx, h.SecretKey, inviteHash, ctxValues.Now)
+ usrAcc, err := invite.AcceptInvite(ctx, h.MasterDB, invite.AcceptInviteRequest{
+ InviteHash: inviteHash,
+ }, h.SecretKey, ctxValues.Now)
if err != nil {
+
switch errors.Cause(err) {
case invite.ErrInviteExpired:
webcontext.SessionFlashError(ctx,
"Invite Expired",
"The invite has expired.")
- return false, nil
- case invite.ErrInviteUserPasswordSet:
+
+ return true, web.Redirect(ctx, w, r, "/user/login", http.StatusFound)
+
+ case invite.ErrUserAccountActive:
webcontext.SessionFlashError(ctx,
- "Invite already Accepted",
+ "User already Active",
+ "The user is already is already active for the account. Try to login or use forgot password.")
+
+ return true, web.Redirect(ctx, w, r, "/user/login", http.StatusFound)
+
+ case invite.ErrNoPendingInvite:
+ webcontext.SessionFlashError(ctx,
+ "Invite Accepted",
"The invite has already been accepted. Try to login or use forgot password.")
- http.Redirect(w, r, "/user/login", http.StatusFound)
- return true, nil
+
+ return true, web.Redirect(ctx, w, r, "/user/login", http.StatusFound)
+
+ case user_account.ErrNotFound:
+ return false, err
default:
if verr, ok := weberror.NewValidationError(ctx, err); ok {
data["validationErrors"] = verr.(*weberror.Error)
+
return false, nil
} else {
return false, err
}
}
+ } else if usrAcc.Status == user_account.UserAccountStatus_Active {
+ webcontext.SessionFlashError(ctx,
+ "Invite Accepted",
+ "The invite has been accepted. Login to continue.")
+
+ return true, web.Redirect(ctx, w, r, "/user/login", http.StatusFound)
}
// Read user by ID with no claims.
- usr, err := user.ReadByID(ctx, auth.Claims{}, h.MasterDB, hash.UserID)
+ usr, err := user.ReadByID(ctx, auth.Claims{}, h.MasterDB, usrAcc.UserID)
if err != nil {
return false, err
}
@@ -791,7 +794,7 @@ func (h *Users) InviteAccept(ctx context.Context, w http.ResponseWriter, r *http
data["form"] = req
- if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(invite.AcceptInviteRequest{})); ok {
+ if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(invite.AcceptInviteUserRequest{})); ok {
data["validationDefaults"] = verr.(*weberror.Error)
}
diff --git a/cmd/web-app/templates/content/user-view.gohtml b/cmd/web-app/templates/content/user-view.gohtml
index 60584fa..389d987 100644
--- a/cmd/web-app/templates/content/user-view.gohtml
+++ b/cmd/web-app/templates/content/user-view.gohtml
@@ -55,13 +55,13 @@
Role
{{ if .userAccount }}
- {{ range $r := .userAccount.Roles }}
- {{ if eq $r "admin" }}
- {{ $r }}
+ {{ range $r := .userAccount.Roles.Options }}{{ if $r.Selected }}
+ {{ if eq $r.Value "admin" }}
+ {{ $r.Title }}
{{else}}
- {{ $r }}
+ {{ $r.Title }}
{{end}}
- {{ end }}
+ {{ end }}{{ end }}
{{ end }}
diff --git a/cmd/web-app/templates/content/users-invite.gohtml b/cmd/web-app/templates/content/users-invite.gohtml
index 7b18d55..d0b9a06 100644
--- a/cmd/web-app/templates/content/users-invite.gohtml
+++ b/cmd/web-app/templates/content/users-invite.gohtml
@@ -21,7 +21,7 @@
diff --git a/cmd/web-app/templates/content/users-view.gohtml b/cmd/web-app/templates/content/users-view.gohtml
index f7afc30..6d5f05a 100644
--- a/cmd/web-app/templates/content/users-view.gohtml
+++ b/cmd/web-app/templates/content/users-view.gohtml
@@ -67,7 +67,7 @@
Role
{{ if .userAccount }}
- {{ range $r := .userAccount.Roles }}{{ if $r.Selected }}
+ {{ range $r := .userAccount.Roles.Options }}{{ if $r.Selected }}
{{ if eq $r.Value "admin" }}
{{ $r.Title }}
{{else}}
diff --git a/internal/platform/web/models.go b/internal/platform/web/models.go
index b2fd123..5eee47a 100644
--- a/internal/platform/web/models.go
+++ b/internal/platform/web/models.go
@@ -117,12 +117,20 @@ func NewEnumResponse(ctx context.Context, value interface{}, options ...interfac
}
// EnumResponse is a response friendly format for displaying a multi select enum.
-type EnumMultiResponse []EnumOption
+type EnumMultiResponse struct {
+ Values []string `json:"values" example:"active_etc"`
+ Options []EnumOption `json:"options,omitempty"`
+}
// NewEnumMultiResponse returns a display friendly format for a multi enum field.
func NewEnumMultiResponse(ctx context.Context, selected []interface{}, options ...interface{}) EnumMultiResponse {
var er EnumMultiResponse
+ for _, s := range selected {
+ selStr := fmt.Sprintf("%s", s)
+ er.Values = append(er.Values, selStr)
+ }
+
for _, opt := range options {
optStr := fmt.Sprintf("%s", opt)
opt := EnumOption{
@@ -137,7 +145,7 @@ func NewEnumMultiResponse(ctx context.Context, selected []interface{}, options .
}
}
- er = append(er, opt)
+ er.Options = append(er.Options, opt)
}
return er
diff --git a/internal/platform/web/response.go b/internal/platform/web/response.go
index 0db5ef2..37dc558 100644
--- a/internal/platform/web/response.go
+++ b/internal/platform/web/response.go
@@ -256,3 +256,16 @@ func StaticHandler(ctx context.Context, w http.ResponseWriter, r *http.Request,
return nil
}
+
+// Redirect ensures the session is flushed to the browser before the redirect is issued.
+func Redirect(ctx context.Context, w http.ResponseWriter, r *http.Request, url string, code int) error {
+ if sess := webcontext.ContextSession(ctx); sess != nil {
+ if err := sess.Save(r, w); err != nil {
+ return err
+ }
+ }
+
+ http.Redirect(w, r, url, code)
+
+ return nil
+}
diff --git a/internal/user_account/invite/invite.go b/internal/user_account/invite/invite.go
index 43f1f14..a1fe4b1 100644
--- a/internal/user_account/invite/invite.go
+++ b/internal/user_account/invite/invite.go
@@ -26,9 +26,6 @@ var (
// ErrUserAccountActive occurs when the user already has an active user_account entry.
ErrUserAccountActive = errors.New("User already active.")
-
- // ErrInviteUserPasswordSet occurs when the the reset hash exceeds the expiration.
- ErrInviteUserPasswordSet = errors.New("User password set")
)
// SendUserInvites sends emails to the users inviting them to join an account.
@@ -181,7 +178,7 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r
}
// AcceptInvite updates the user using the provided invite hash.
-func AcceptInvite(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteRequest, secretKey string, now time.Time) (*InviteHash, error) {
+func AcceptInvite(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteRequest, secretKey string, now time.Time) (*user_account.UserAccount, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.invite.AcceptInvite")
defer span.Finish()
@@ -193,7 +190,7 @@ func AcceptInvite(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteRequest,
return nil, err
}
- hash, err := ParseInviteHash(ctx, secretKey, req.InviteHash, now)
+ hash, err := ParseInviteHash(ctx, req.InviteHash, secretKey, now)
if err != nil {
return nil, err
}
@@ -216,24 +213,86 @@ func AcceptInvite(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteRequest,
AccountID: hash.AccountID,
})
if err != nil {
- return nil, nil
+ return nil, err
}
// Ensure the entry has the status of invited.
if usrAcc.Status != user_account.UserAccountStatus_Invited {
// If the entry is already active
if usrAcc.Status == user_account.UserAccountStatus_Active {
- return hash, errors.WithStack(ErrUserAccountActive)
+ return usrAcc, errors.WithStack(ErrUserAccountActive)
+ }
+ return usrAcc, errors.WithStack(ErrNoPendingInvite)
+ }
+
+ // If the user already has a password set, then just update the user_account entry to status of active.
+ // The user will need to login and should not be auto-authenticated.
+ if len(u.PasswordHash) > 0 {
+ usrAcc.Status = user_account.UserAccountStatus_Active
+
+ err = user_account.Update(ctx, auth.Claims{}, dbConn, user_account.UserAccountUpdateRequest{
+ UserID: usrAcc.UserID,
+ AccountID: usrAcc.AccountID,
+ Status: &usrAcc.Status,
+ }, now)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return usrAcc, nil
+}
+
+// AcceptInviteUser updates the user using the provided invite hash.
+func AcceptInviteUser(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteUserRequest, secretKey string, now time.Time) (*user_account.UserAccount, error) {
+ span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.invite.AcceptInviteUser")
+ defer span.Finish()
+
+ v := webcontext.Validator()
+
+ // Validate the request.
+ err := v.StructCtx(ctx, req)
+ if err != nil {
+ return nil, err
+ }
+
+ hash, err := ParseInviteHash(ctx, req.InviteHash, secretKey, now)
+ if err != nil {
+ return nil, err
+ }
+
+ u, err := user.Read(ctx, auth.Claims{}, dbConn,
+ user.UserReadRequest{ID: hash.UserID, IncludeArchived: true})
+ if err != nil {
+ return nil, err
+ }
+
+ if u.ArchivedAt != nil && !u.ArchivedAt.Time.IsZero() {
+ err = user.Restore(ctx, auth.Claims{}, dbConn, user.UserRestoreRequest{ID: hash.UserID}, now)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ usrAcc, err := user_account.Read(ctx, auth.Claims{}, dbConn, user_account.UserAccountReadRequest{
+ UserID: hash.UserID,
+ AccountID: hash.AccountID,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ // Ensure the entry has the status of invited.
+ if usrAcc.Status != user_account.UserAccountStatus_Invited {
+ // If the entry is already active
+ if usrAcc.Status == user_account.UserAccountStatus_Active {
+ return usrAcc, errors.WithStack(ErrUserAccountActive)
}
return nil, errors.WithStack(ErrNoPendingInvite)
}
- if len(u.PasswordHash) > 0 {
- // Do not update the password for a user that already has a password set.
- return nil, errors.WithStack(ErrInviteUserPasswordSet)
- }
-
- // These two calls, user.Update and user.UpdatePassword should probably be in a transaction!
+ // These three calls, user.Update, user.UpdatePassword, and user_account.Update
+ // should probably be in a transaction!
err = user.Update(ctx, auth.Claims{}, dbConn, user.UserUpdateRequest{
ID: hash.UserID,
Email: &req.Email,
@@ -254,15 +313,15 @@ func AcceptInvite(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteRequest,
return nil, err
}
- activeStatus := user_account.UserAccountStatus_Active
+ usrAcc.Status = user_account.UserAccountStatus_Active
err = user_account.Update(ctx, auth.Claims{}, dbConn, user_account.UserAccountUpdateRequest{
UserID: usrAcc.UserID,
AccountID: usrAcc.AccountID,
- Status: &activeStatus,
+ Status: &usrAcc.Status,
}, now)
if err != nil {
return nil, err
}
- return hash, nil
+ return usrAcc, nil
}
diff --git a/internal/user_account/invite/invite_test.go b/internal/user_account/invite/invite_test.go
index 049190c..032c6e7 100644
--- a/internal/user_account/invite/invite_test.go
+++ b/internal/user_account/invite/invite_test.go
@@ -148,13 +148,13 @@ func TestSendUserInvites(t *testing.T) {
// Ensure validation is working by trying ResetConfirm with an empty request.
{
- expectedErr := errors.New("Key: 'AcceptInviteRequest.invite_hash' Error:Field validation for 'invite_hash' failed on the 'required' tag\n" +
- "Key: 'AcceptInviteRequest.email' Error:Field validation for 'email' failed on the 'required' tag\n" +
- "Key: 'AcceptInviteRequest.first_name' Error:Field validation for 'first_name' failed on the 'required' tag\n" +
- "Key: 'AcceptInviteRequest.last_name' Error:Field validation for 'last_name' failed on the 'required' tag\n" +
- "Key: 'AcceptInviteRequest.password' Error:Field validation for 'password' failed on the 'required' tag\n" +
- "Key: 'AcceptInviteRequest.password_confirm' Error:Field validation for 'password_confirm' failed on the 'required' tag")
- _, err = AcceptInvite(ctx, test.MasterDB, AcceptInviteRequest{}, secretKey, now)
+ expectedErr := errors.New("Key: 'AcceptInviteUserRequest.invite_hash' Error:Field validation for 'invite_hash' failed on the 'required' tag\n" +
+ "Key: 'AcceptInviteUserRequest.email' Error:Field validation for 'email' failed on the 'required' tag\n" +
+ "Key: 'AcceptInviteUserRequest.first_name' Error:Field validation for 'first_name' failed on the 'required' tag\n" +
+ "Key: 'AcceptInviteUserRequest.last_name' Error:Field validation for 'last_name' failed on the 'required' tag\n" +
+ "Key: 'AcceptInviteUserRequest.password' Error:Field validation for 'password' failed on the 'required' tag\n" +
+ "Key: 'AcceptInviteUserRequest.password_confirm' Error:Field validation for 'password_confirm' failed on the 'required' tag")
+ _, err = AcceptInviteUser(ctx, test.MasterDB, AcceptInviteUserRequest{}, secretKey, now)
if err == nil {
t.Logf("\t\tWant: %+v", expectedErr)
t.Fatalf("\t%s\tResetConfirm failed.", tests.Failed)
@@ -174,7 +174,7 @@ func TestSendUserInvites(t *testing.T) {
// Ensure the TTL is enforced.
{
newPass := uuid.NewRandom().String()
- _, err = AcceptInvite(ctx, test.MasterDB, AcceptInviteRequest{
+ _, err = AcceptInviteUser(ctx, test.MasterDB, AcceptInviteUserRequest{
InviteHash: inviteHashes[0],
Email: inviteEmails[0],
FirstName: "Foo",
@@ -194,7 +194,7 @@ func TestSendUserInvites(t *testing.T) {
for idx, inviteHash := range inviteHashes {
newPass := uuid.NewRandom().String()
- hash, err := AcceptInvite(ctx, test.MasterDB, AcceptInviteRequest{
+ hash, err := AcceptInviteUser(ctx, test.MasterDB, AcceptInviteUserRequest{
InviteHash: inviteHash,
Email: inviteEmails[idx],
FirstName: "Foo",
@@ -227,7 +227,7 @@ func TestSendUserInvites(t *testing.T) {
// Ensure the reset hash does not work after its used.
{
newPass := uuid.NewRandom().String()
- _, err = AcceptInvite(ctx, test.MasterDB, AcceptInviteRequest{
+ _, err = AcceptInviteUser(ctx, test.MasterDB, AcceptInviteUserRequest{
InviteHash: inviteHashes[0],
Email: inviteEmails[0],
FirstName: "Foo",
@@ -237,7 +237,7 @@ func TestSendUserInvites(t *testing.T) {
}, secretKey, now)
if errors.Cause(err) != ErrUserAccountActive {
t.Logf("\t\tGot : %+v", errors.Cause(err))
- t.Logf("\t\tWant: %+v", ErrInviteUserPasswordSet)
+ t.Logf("\t\tWant: %+v", ErrUserAccountActive)
t.Fatalf("\t%s\tInviteAccept verify reuse failed.", tests.Failed)
}
t.Logf("\t%s\tInviteAccept verify reuse disabled ok.", tests.Success)
diff --git a/internal/user_account/invite/models.go b/internal/user_account/invite/models.go
index 759fc57..ca87007 100644
--- a/internal/user_account/invite/models.go
+++ b/internal/user_account/invite/models.go
@@ -32,6 +32,11 @@ type InviteHash struct {
// AcceptInviteRequest defines the fields need to complete an invite request.
type AcceptInviteRequest struct {
+ InviteHash string `json:"invite_hash" validate:"required" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"`
+}
+
+// AcceptInviteUserRequest defines the fields need to complete an invite request.
+type AcceptInviteUserRequest struct {
InviteHash string `json:"invite_hash" validate:"required" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"`
Email string `json:"email" validate:"required,email" example:"gabi@geeksinthewoods.com"`
FirstName string `json:"first_name" validate:"required" example:"Gabi"`
@@ -67,12 +72,12 @@ func NewInviteHash(ctx context.Context, secretKey, userID, accountID, requestIp
}
// ParseInviteHash extracts the details encrypted in the hash string.
-func ParseInviteHash(ctx context.Context, secretKey string, str string, now time.Time) (*InviteHash, error) {
+func ParseInviteHash(ctx context.Context, encrypted, secretKey string, now time.Time) (*InviteHash, error) {
crypto, err := symcrypto.New(secretKey)
if err != nil {
return nil, errors.WithStack(err)
}
- hashStr, err := crypto.Decrypt(str)
+ hashStr, err := crypto.Decrypt(encrypted)
if err != nil {
return nil, errors.WithStack(err)
}