From 900cfcf7135d8eb93f980be94bf686ef942b5704 Mon Sep 17 00:00:00 2001 From: Lee Brown Date: Mon, 5 Aug 2019 18:47:42 -0800 Subject: [PATCH] fixed redirects to handle saving the session when set --- cmd/web-app/handlers/account.go | 7 +- cmd/web-app/handlers/projects.go | 21 +--- cmd/web-app/handlers/root.go | 11 +-- cmd/web-app/handlers/signup.go | 11 +-- cmd/web-app/handlers/user.go | 41 ++------ cmd/web-app/handlers/users.go | 95 ++++++++++--------- .../templates/content/user-view.gohtml | 10 +- .../templates/content/users-invite.gohtml | 2 +- .../templates/content/users-view.gohtml | 2 +- internal/platform/web/models.go | 12 ++- internal/platform/web/response.go | 13 +++ internal/user_account/invite/invite.go | 91 ++++++++++++++---- internal/user_account/invite/invite_test.go | 22 ++--- internal/user_account/invite/models.go | 9 +- 14 files changed, 191 insertions(+), 156 deletions(-) diff --git a/cmd/web-app/handlers/account.go b/cmd/web-app/handlers/account.go index f43a11e..03a32d7 100644 --- a/cmd/web-app/handlers/account.go +++ b/cmd/web-app/handlers/account.go @@ -209,13 +209,8 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req webcontext.SessionFlashSuccess(ctx, "Account Updated", "Account profile successfully updated.") - err = webcontext.ContextSession(ctx).Save(r, w) - if err != nil { - return false, err - } - http.Redirect(w, r, "/account", http.StatusFound) - return true, nil + return true, web.Redirect(ctx, w, r, "/account", http.StatusFound) } acc, err := account.ReadByID(ctx, claims, h.MasterDB, claims.Audience) diff --git a/cmd/web-app/handlers/projects.go b/cmd/web-app/handlers/projects.go index 8aac2ba..1025e60 100644 --- a/cmd/web-app/handlers/projects.go +++ b/cmd/web-app/handlers/projects.go @@ -203,13 +203,8 @@ func (h *Projects) Create(ctx context.Context, w http.ResponseWriter, r *http.Re webcontext.SessionFlashSuccess(ctx, "Project Created", "Project successfully created.") - err = webcontext.ContextSession(ctx).Save(r, w) - if err != nil { - return false, err - } - http.Redirect(w, r, urlProjectsView(usr.ID), http.StatusFound) - return true, nil + return true, web.Redirect(ctx, w, r, urlProjectsView(usr.ID), http.StatusFound) } return false, nil @@ -266,13 +261,8 @@ func (h *Projects) View(ctx context.Context, w http.ResponseWriter, r *http.Requ webcontext.SessionFlashSuccess(ctx, "Project Archive", "Project successfully archive.") - err = webcontext.ContextSession(ctx).Save(r, w) - if err != nil { - return false, err - } - http.Redirect(w, r, urlProjectsIndex(), http.StatusFound) - return true, nil + return true, web.Redirect(ctx, w, r, urlProjectsIndex(), http.StatusFound) } } @@ -347,13 +337,8 @@ func (h *Projects) Update(ctx context.Context, w http.ResponseWriter, r *http.Re webcontext.SessionFlashSuccess(ctx, "Project Updated", "Project successfully updated.") - err = webcontext.ContextSession(ctx).Save(r, w) - if err != nil { - return false, err - } - http.Redirect(w, r, urlProjectsView(req.ID), http.StatusFound) - return true, nil + return true, web.Redirect(ctx, w, r, urlProjectsView(req.ID), http.StatusFound) } return false, nil diff --git a/cmd/web-app/handlers/root.go b/cmd/web-app/handlers/root.go index 93c1726..a327e42 100644 --- a/cmd/web-app/handlers/root.go +++ b/cmd/web-app/handlers/root.go @@ -40,12 +40,10 @@ func (h *Root) indexDashboard(ctx context.Context, w http.ResponseWriter, r *htt // indexDefault loads the root index page when a user has no authentication. func (u *Root) indexDefault(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { - return u.Renderer.Render(ctx, w, r, tmplLayoutSite, "site-index.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, nil) - } -// indexDefault loads the root index page when a user has no authentication. +// SitePage loads the page with the layout for site instead of the app base. func (u *Root) SitePage(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { var tmpName string @@ -63,18 +61,15 @@ func (u *Root) SitePage(ctx context.Context, w http.ResponseWriter, r *http.Requ case "/legal/terms": tmpName = "legal-terms.gohtml" default: - http.Redirect(w, r, "/", http.StatusFound) - return nil + return web.Redirect(ctx, w, r, "/", http.StatusFound) } return u.Renderer.Render(ctx, w, r, tmplLayoutSite, tmpName, web.MIMETextHTMLCharsetUTF8, http.StatusOK, nil) - } // IndexHtml redirects /index.html to the website root page. func (u *Root) IndexHtml(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { - http.Redirect(w, r, "/", http.StatusMovedPermanently) - return nil + return web.Redirect(ctx, w, r, "/", http.StatusMovedPermanently) } // RobotHandler returns a robots.txt response. diff --git a/cmd/web-app/handlers/signup.go b/cmd/web-app/handlers/signup.go index a25bd90..4f3ece4 100644 --- a/cmd/web-app/handlers/signup.go +++ b/cmd/web-app/handlers/signup.go @@ -86,14 +86,9 @@ func (h *Signup) Step1(ctx context.Context, w http.ResponseWriter, r *http.Reque webcontext.SessionFlashSuccess(ctx, "Thank you for Joining", "You workflow will be a breeze starting today.") - err = webcontext.ContextSession(ctx).Save(r, w) - if err != nil { - return false, err - } // Redirect the user to the dashboard. - http.Redirect(w, r, "/", http.StatusFound) - return true, nil + return true, web.Redirect(ctx, w, r, "/", http.StatusFound) } return false, nil @@ -103,6 +98,10 @@ func (h *Signup) Step1(ctx context.Context, w http.ResponseWriter, r *http.Reque if err != nil { return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8) } else if end { + err = webcontext.ContextSession(ctx).Save(r, w) + if err != nil { + return err + } return nil } diff --git a/cmd/web-app/handlers/user.go b/cmd/web-app/handlers/user.go index 119135a..9817ac1 100644 --- a/cmd/web-app/handlers/user.go +++ b/cmd/web-app/handlers/user.go @@ -112,8 +112,7 @@ func (h *User) Login(ctx context.Context, w http.ResponseWriter, r *http.Request } // Redirect the user to the dashboard. - http.Redirect(w, r, redirectUri, http.StatusFound) - return true, nil + return true, web.Redirect(ctx, w, r, redirectUri, http.StatusFound) } return false, nil @@ -148,9 +147,7 @@ func (h *User) Logout(ctx context.Context, w http.ResponseWriter, r *http.Reques } // Redirect the user to the root page. - http.Redirect(w, r, "/", http.StatusFound) - - return nil + return web.Redirect(ctx, w, r, "/", http.StatusFound) } // ResetPassword allows a user to perform forgot password. @@ -281,8 +278,7 @@ func (h *User) ResetConfirm(ctx context.Context, w http.ResponseWriter, r *http. } // Redirect the user to the dashboard. - http.Redirect(w, r, "/", http.StatusFound) - return true, nil + return true, web.Redirect(ctx, w, r, "/", http.StatusFound) } _, err = user.ParseResetHash(ctx, h.SecretKey, resetHash, ctxValues.Now) @@ -432,13 +428,8 @@ func (h *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Reques webcontext.SessionFlashSuccess(ctx, "Profile Updated", "User profile successfully updated.") - err = webcontext.ContextSession(ctx).Save(r, w) - if err != nil { - return false, err - } - http.Redirect(w, r, "/user", http.StatusFound) - return true, nil + return true, web.Redirect(ctx, w, r, "/user", http.StatusFound) } return false, nil @@ -584,16 +575,8 @@ func (h *User) VirtualLogin(ctx context.Context, w http.ResponseWriter, r *http. fmt.Sprintf("You are now virtually logged into user %s.", usr.Response(ctx).Name)) - // Write the session to the client. - err = webcontext.ContextSession(ctx).Save(r, w) - if err != nil { - return false, err - } - // Redirect the user to the dashboard with the new credentials. - http.Redirect(w, r, "/", http.StatusFound) - - return true, nil + return true, web.Redirect(ctx, w, r, "/", http.StatusFound) } return false, nil @@ -724,9 +707,7 @@ func (h *User) VirtualLogout(ctx context.Context, w http.ResponseWriter, r *http } // Redirect the user to the dashboard with the new credentials. - http.Redirect(w, r, "/", http.StatusFound) - - return nil + return web.Redirect(ctx, w, r, "/", http.StatusFound) } // VirtualLogin handles switching the scope of the context to another user. @@ -800,16 +781,8 @@ func (h *User) SwitchAccount(ctx context.Context, w http.ResponseWriter, r *http fmt.Sprintf("You are now logged into account %s.", acc.Response(ctx).Name)) - // Write the session to the client. - err = webcontext.ContextSession(ctx).Save(r, w) - if err != nil { - return false, err - } - // Redirect the user to the dashboard with the new credentials. - http.Redirect(w, r, "/", http.StatusFound) - - return true, nil + return true, web.Redirect(ctx, w, r, "/", http.StatusFound) } return false, nil diff --git a/cmd/web-app/handlers/users.go b/cmd/web-app/handlers/users.go index be36709..329311c 100644 --- a/cmd/web-app/handlers/users.go +++ b/cmd/web-app/handlers/users.go @@ -259,13 +259,8 @@ func (h *Users) Create(ctx context.Context, w http.ResponseWriter, r *http.Reque webcontext.SessionFlashSuccess(ctx, "User Created", "User successfully created.") - err = webcontext.ContextSession(ctx).Save(r, w) - if err != nil { - return false, err - } - http.Redirect(w, r, urlUsersView(usr.ID), http.StatusFound) - return true, nil + return true, web.Redirect(ctx, w, r, urlUsersView(usr.ID), http.StatusFound) } return false, nil @@ -333,13 +328,8 @@ func (h *Users) View(ctx context.Context, w http.ResponseWriter, r *http.Request webcontext.SessionFlashSuccess(ctx, "User Archive", "User successfully archive.") - err = webcontext.ContextSession(ctx).Save(r, w) - if err != nil { - return false, err - } - http.Redirect(w, r, urlUsersIndex(), http.StatusFound) - return true, nil + return true, web.Redirect(ctx, w, r, urlUsersIndex(), http.StatusFound) } } @@ -483,13 +473,8 @@ func (h *Users) Update(ctx context.Context, w http.ResponseWriter, r *http.Reque webcontext.SessionFlashSuccess(ctx, "User Updated", "User successfully updated.") - err = webcontext.ContextSession(ctx).Save(r, w) - if err != nil { - return false, err - } - http.Redirect(w, r, urlUsersView(req.ID), http.StatusFound) - return true, nil + return true, web.Redirect(ctx, w, r, urlUsersView(req.ID), http.StatusFound) } return false, nil @@ -607,13 +592,7 @@ func (h *Users) Invite(ctx context.Context, w http.ResponseWriter, r *http.Reque "No users were invited.") } - err = webcontext.ContextSession(ctx).Save(r, w) - if err != nil { - return false, err - } - - http.Redirect(w, r, urlUsersIndex(), http.StatusFound) - return true, nil + return true, web.Redirect(ctx, w, r, urlUsersIndex(), http.StatusFound) } return false, nil @@ -652,7 +631,7 @@ func (h *Users) InviteAccept(ctx context.Context, w http.ResponseWriter, r *http } // - req := new(invite.AcceptInviteRequest) + req := new(invite.AcceptInviteUserRequest) data := make(map[string]interface{}) f := func() (bool, error) { @@ -670,30 +649,33 @@ func (h *Users) InviteAccept(ctx context.Context, w http.ResponseWriter, r *http // Append the query param value to the request. req.InviteHash = inviteHash - hash, err := invite.AcceptInvite(ctx, h.MasterDB, *req, h.SecretKey, ctxValues.Now) + hash, err := invite.AcceptInviteUser(ctx, h.MasterDB, *req, h.SecretKey, ctxValues.Now) if err != nil { switch errors.Cause(err) { case invite.ErrInviteExpired: webcontext.SessionFlashError(ctx, "Invite Expired", "The invite has expired.") + return false, nil + case invite.ErrUserAccountActive: webcontext.SessionFlashError(ctx, "User already Active", - "The user already is already active for the account. Try to login or use forgot password.") - http.Redirect(w, r, "/user/login", http.StatusFound) - return true, nil - case invite.ErrInviteUserPasswordSet: + "The user is already is already active for the account. Try to login or use forgot password.") + + return true, web.Redirect(ctx, w, r, "/user/login", http.StatusFound) + + case invite.ErrNoPendingInvite: webcontext.SessionFlashError(ctx, - "Invite already Accepted", + "Invite Accepted", "The invite has already been accepted. Try to login or use forgot password.") - http.Redirect(w, r, "/user/login", http.StatusFound) - return true, nil + + return true, web.Redirect(ctx, w, r, "/user/login", http.StatusFound) + case user_account.ErrNotFound: return false, err - case invite.ErrNoPendingInvite: - return false, err + default: if verr, ok := weberror.NewValidationError(ctx, err); ok { data["validationErrors"] = verr.(*weberror.Error) @@ -732,36 +714,57 @@ func (h *Users) InviteAccept(ctx context.Context, w http.ResponseWriter, r *http } // Redirect the user to the dashboard. - http.Redirect(w, r, "/", http.StatusFound) - return true, nil + return true, web.Redirect(ctx, w, r, "/", http.StatusFound) } - hash, err := invite.ParseInviteHash(ctx, h.SecretKey, inviteHash, ctxValues.Now) + usrAcc, err := invite.AcceptInvite(ctx, h.MasterDB, invite.AcceptInviteRequest{ + InviteHash: inviteHash, + }, h.SecretKey, ctxValues.Now) if err != nil { + switch errors.Cause(err) { case invite.ErrInviteExpired: webcontext.SessionFlashError(ctx, "Invite Expired", "The invite has expired.") - return false, nil - case invite.ErrInviteUserPasswordSet: + + return true, web.Redirect(ctx, w, r, "/user/login", http.StatusFound) + + case invite.ErrUserAccountActive: webcontext.SessionFlashError(ctx, - "Invite already Accepted", + "User already Active", + "The user is already is already active for the account. Try to login or use forgot password.") + + return true, web.Redirect(ctx, w, r, "/user/login", http.StatusFound) + + case invite.ErrNoPendingInvite: + webcontext.SessionFlashError(ctx, + "Invite Accepted", "The invite has already been accepted. Try to login or use forgot password.") - http.Redirect(w, r, "/user/login", http.StatusFound) - return true, nil + + return true, web.Redirect(ctx, w, r, "/user/login", http.StatusFound) + + case user_account.ErrNotFound: + return false, err default: if verr, ok := weberror.NewValidationError(ctx, err); ok { data["validationErrors"] = verr.(*weberror.Error) + return false, nil } else { return false, err } } + } else if usrAcc.Status == user_account.UserAccountStatus_Active { + webcontext.SessionFlashError(ctx, + "Invite Accepted", + "The invite has been accepted. Login to continue.") + + return true, web.Redirect(ctx, w, r, "/user/login", http.StatusFound) } // Read user by ID with no claims. - usr, err := user.ReadByID(ctx, auth.Claims{}, h.MasterDB, hash.UserID) + usr, err := user.ReadByID(ctx, auth.Claims{}, h.MasterDB, usrAcc.UserID) if err != nil { return false, err } @@ -791,7 +794,7 @@ func (h *Users) InviteAccept(ctx context.Context, w http.ResponseWriter, r *http data["form"] = req - if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(invite.AcceptInviteRequest{})); ok { + if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(invite.AcceptInviteUserRequest{})); ok { data["validationDefaults"] = verr.(*weberror.Error) } diff --git a/cmd/web-app/templates/content/user-view.gohtml b/cmd/web-app/templates/content/user-view.gohtml index 60584fa..389d987 100644 --- a/cmd/web-app/templates/content/user-view.gohtml +++ b/cmd/web-app/templates/content/user-view.gohtml @@ -55,13 +55,13 @@ Role
{{ if .userAccount }} - {{ range $r := .userAccount.Roles }} - {{ if eq $r "admin" }} - {{ $r }} + {{ range $r := .userAccount.Roles.Options }}{{ if $r.Selected }} + {{ if eq $r.Value "admin" }} + {{ $r.Title }} {{else}} - {{ $r }} + {{ $r.Title }} {{end}} - {{ end }} + {{ end }}{{ end }} {{ end }}

diff --git a/cmd/web-app/templates/content/users-invite.gohtml b/cmd/web-app/templates/content/users-invite.gohtml index 7b18d55..d0b9a06 100644 --- a/cmd/web-app/templates/content/users-invite.gohtml +++ b/cmd/web-app/templates/content/users-invite.gohtml @@ -21,7 +21,7 @@ diff --git a/cmd/web-app/templates/content/users-view.gohtml b/cmd/web-app/templates/content/users-view.gohtml index f7afc30..6d5f05a 100644 --- a/cmd/web-app/templates/content/users-view.gohtml +++ b/cmd/web-app/templates/content/users-view.gohtml @@ -67,7 +67,7 @@ Role
{{ if .userAccount }} - {{ range $r := .userAccount.Roles }}{{ if $r.Selected }} + {{ range $r := .userAccount.Roles.Options }}{{ if $r.Selected }} {{ if eq $r.Value "admin" }} {{ $r.Title }} {{else}} diff --git a/internal/platform/web/models.go b/internal/platform/web/models.go index b2fd123..5eee47a 100644 --- a/internal/platform/web/models.go +++ b/internal/platform/web/models.go @@ -117,12 +117,20 @@ func NewEnumResponse(ctx context.Context, value interface{}, options ...interfac } // EnumResponse is a response friendly format for displaying a multi select enum. -type EnumMultiResponse []EnumOption +type EnumMultiResponse struct { + Values []string `json:"values" example:"active_etc"` + Options []EnumOption `json:"options,omitempty"` +} // NewEnumMultiResponse returns a display friendly format for a multi enum field. func NewEnumMultiResponse(ctx context.Context, selected []interface{}, options ...interface{}) EnumMultiResponse { var er EnumMultiResponse + for _, s := range selected { + selStr := fmt.Sprintf("%s", s) + er.Values = append(er.Values, selStr) + } + for _, opt := range options { optStr := fmt.Sprintf("%s", opt) opt := EnumOption{ @@ -137,7 +145,7 @@ func NewEnumMultiResponse(ctx context.Context, selected []interface{}, options . } } - er = append(er, opt) + er.Options = append(er.Options, opt) } return er diff --git a/internal/platform/web/response.go b/internal/platform/web/response.go index 0db5ef2..37dc558 100644 --- a/internal/platform/web/response.go +++ b/internal/platform/web/response.go @@ -256,3 +256,16 @@ func StaticHandler(ctx context.Context, w http.ResponseWriter, r *http.Request, return nil } + +// Redirect ensures the session is flushed to the browser before the redirect is issued. +func Redirect(ctx context.Context, w http.ResponseWriter, r *http.Request, url string, code int) error { + if sess := webcontext.ContextSession(ctx); sess != nil { + if err := sess.Save(r, w); err != nil { + return err + } + } + + http.Redirect(w, r, url, code) + + return nil +} diff --git a/internal/user_account/invite/invite.go b/internal/user_account/invite/invite.go index 43f1f14..a1fe4b1 100644 --- a/internal/user_account/invite/invite.go +++ b/internal/user_account/invite/invite.go @@ -26,9 +26,6 @@ var ( // ErrUserAccountActive occurs when the user already has an active user_account entry. ErrUserAccountActive = errors.New("User already active.") - - // ErrInviteUserPasswordSet occurs when the the reset hash exceeds the expiration. - ErrInviteUserPasswordSet = errors.New("User password set") ) // SendUserInvites sends emails to the users inviting them to join an account. @@ -181,7 +178,7 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r } // AcceptInvite updates the user using the provided invite hash. -func AcceptInvite(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteRequest, secretKey string, now time.Time) (*InviteHash, error) { +func AcceptInvite(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteRequest, secretKey string, now time.Time) (*user_account.UserAccount, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.invite.AcceptInvite") defer span.Finish() @@ -193,7 +190,7 @@ func AcceptInvite(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteRequest, return nil, err } - hash, err := ParseInviteHash(ctx, secretKey, req.InviteHash, now) + hash, err := ParseInviteHash(ctx, req.InviteHash, secretKey, now) if err != nil { return nil, err } @@ -216,24 +213,86 @@ func AcceptInvite(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteRequest, AccountID: hash.AccountID, }) if err != nil { - return nil, nil + return nil, err } // Ensure the entry has the status of invited. if usrAcc.Status != user_account.UserAccountStatus_Invited { // If the entry is already active if usrAcc.Status == user_account.UserAccountStatus_Active { - return hash, errors.WithStack(ErrUserAccountActive) + return usrAcc, errors.WithStack(ErrUserAccountActive) + } + return usrAcc, errors.WithStack(ErrNoPendingInvite) + } + + // If the user already has a password set, then just update the user_account entry to status of active. + // The user will need to login and should not be auto-authenticated. + if len(u.PasswordHash) > 0 { + usrAcc.Status = user_account.UserAccountStatus_Active + + err = user_account.Update(ctx, auth.Claims{}, dbConn, user_account.UserAccountUpdateRequest{ + UserID: usrAcc.UserID, + AccountID: usrAcc.AccountID, + Status: &usrAcc.Status, + }, now) + if err != nil { + return nil, err + } + } + + return usrAcc, nil +} + +// AcceptInviteUser updates the user using the provided invite hash. +func AcceptInviteUser(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteUserRequest, secretKey string, now time.Time) (*user_account.UserAccount, error) { + span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.invite.AcceptInviteUser") + defer span.Finish() + + v := webcontext.Validator() + + // Validate the request. + err := v.StructCtx(ctx, req) + if err != nil { + return nil, err + } + + hash, err := ParseInviteHash(ctx, req.InviteHash, secretKey, now) + if err != nil { + return nil, err + } + + u, err := user.Read(ctx, auth.Claims{}, dbConn, + user.UserReadRequest{ID: hash.UserID, IncludeArchived: true}) + if err != nil { + return nil, err + } + + if u.ArchivedAt != nil && !u.ArchivedAt.Time.IsZero() { + err = user.Restore(ctx, auth.Claims{}, dbConn, user.UserRestoreRequest{ID: hash.UserID}, now) + if err != nil { + return nil, err + } + } + + usrAcc, err := user_account.Read(ctx, auth.Claims{}, dbConn, user_account.UserAccountReadRequest{ + UserID: hash.UserID, + AccountID: hash.AccountID, + }) + if err != nil { + return nil, err + } + + // Ensure the entry has the status of invited. + if usrAcc.Status != user_account.UserAccountStatus_Invited { + // If the entry is already active + if usrAcc.Status == user_account.UserAccountStatus_Active { + return usrAcc, errors.WithStack(ErrUserAccountActive) } return nil, errors.WithStack(ErrNoPendingInvite) } - if len(u.PasswordHash) > 0 { - // Do not update the password for a user that already has a password set. - return nil, errors.WithStack(ErrInviteUserPasswordSet) - } - - // These two calls, user.Update and user.UpdatePassword should probably be in a transaction! + // These three calls, user.Update, user.UpdatePassword, and user_account.Update + // should probably be in a transaction! err = user.Update(ctx, auth.Claims{}, dbConn, user.UserUpdateRequest{ ID: hash.UserID, Email: &req.Email, @@ -254,15 +313,15 @@ func AcceptInvite(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteRequest, return nil, err } - activeStatus := user_account.UserAccountStatus_Active + usrAcc.Status = user_account.UserAccountStatus_Active err = user_account.Update(ctx, auth.Claims{}, dbConn, user_account.UserAccountUpdateRequest{ UserID: usrAcc.UserID, AccountID: usrAcc.AccountID, - Status: &activeStatus, + Status: &usrAcc.Status, }, now) if err != nil { return nil, err } - return hash, nil + return usrAcc, nil } diff --git a/internal/user_account/invite/invite_test.go b/internal/user_account/invite/invite_test.go index 049190c..032c6e7 100644 --- a/internal/user_account/invite/invite_test.go +++ b/internal/user_account/invite/invite_test.go @@ -148,13 +148,13 @@ func TestSendUserInvites(t *testing.T) { // Ensure validation is working by trying ResetConfirm with an empty request. { - expectedErr := errors.New("Key: 'AcceptInviteRequest.invite_hash' Error:Field validation for 'invite_hash' failed on the 'required' tag\n" + - "Key: 'AcceptInviteRequest.email' Error:Field validation for 'email' failed on the 'required' tag\n" + - "Key: 'AcceptInviteRequest.first_name' Error:Field validation for 'first_name' failed on the 'required' tag\n" + - "Key: 'AcceptInviteRequest.last_name' Error:Field validation for 'last_name' failed on the 'required' tag\n" + - "Key: 'AcceptInviteRequest.password' Error:Field validation for 'password' failed on the 'required' tag\n" + - "Key: 'AcceptInviteRequest.password_confirm' Error:Field validation for 'password_confirm' failed on the 'required' tag") - _, err = AcceptInvite(ctx, test.MasterDB, AcceptInviteRequest{}, secretKey, now) + expectedErr := errors.New("Key: 'AcceptInviteUserRequest.invite_hash' Error:Field validation for 'invite_hash' failed on the 'required' tag\n" + + "Key: 'AcceptInviteUserRequest.email' Error:Field validation for 'email' failed on the 'required' tag\n" + + "Key: 'AcceptInviteUserRequest.first_name' Error:Field validation for 'first_name' failed on the 'required' tag\n" + + "Key: 'AcceptInviteUserRequest.last_name' Error:Field validation for 'last_name' failed on the 'required' tag\n" + + "Key: 'AcceptInviteUserRequest.password' Error:Field validation for 'password' failed on the 'required' tag\n" + + "Key: 'AcceptInviteUserRequest.password_confirm' Error:Field validation for 'password_confirm' failed on the 'required' tag") + _, err = AcceptInviteUser(ctx, test.MasterDB, AcceptInviteUserRequest{}, secretKey, now) if err == nil { t.Logf("\t\tWant: %+v", expectedErr) t.Fatalf("\t%s\tResetConfirm failed.", tests.Failed) @@ -174,7 +174,7 @@ func TestSendUserInvites(t *testing.T) { // Ensure the TTL is enforced. { newPass := uuid.NewRandom().String() - _, err = AcceptInvite(ctx, test.MasterDB, AcceptInviteRequest{ + _, err = AcceptInviteUser(ctx, test.MasterDB, AcceptInviteUserRequest{ InviteHash: inviteHashes[0], Email: inviteEmails[0], FirstName: "Foo", @@ -194,7 +194,7 @@ func TestSendUserInvites(t *testing.T) { for idx, inviteHash := range inviteHashes { newPass := uuid.NewRandom().String() - hash, err := AcceptInvite(ctx, test.MasterDB, AcceptInviteRequest{ + hash, err := AcceptInviteUser(ctx, test.MasterDB, AcceptInviteUserRequest{ InviteHash: inviteHash, Email: inviteEmails[idx], FirstName: "Foo", @@ -227,7 +227,7 @@ func TestSendUserInvites(t *testing.T) { // Ensure the reset hash does not work after its used. { newPass := uuid.NewRandom().String() - _, err = AcceptInvite(ctx, test.MasterDB, AcceptInviteRequest{ + _, err = AcceptInviteUser(ctx, test.MasterDB, AcceptInviteUserRequest{ InviteHash: inviteHashes[0], Email: inviteEmails[0], FirstName: "Foo", @@ -237,7 +237,7 @@ func TestSendUserInvites(t *testing.T) { }, secretKey, now) if errors.Cause(err) != ErrUserAccountActive { t.Logf("\t\tGot : %+v", errors.Cause(err)) - t.Logf("\t\tWant: %+v", ErrInviteUserPasswordSet) + t.Logf("\t\tWant: %+v", ErrUserAccountActive) t.Fatalf("\t%s\tInviteAccept verify reuse failed.", tests.Failed) } t.Logf("\t%s\tInviteAccept verify reuse disabled ok.", tests.Success) diff --git a/internal/user_account/invite/models.go b/internal/user_account/invite/models.go index 759fc57..ca87007 100644 --- a/internal/user_account/invite/models.go +++ b/internal/user_account/invite/models.go @@ -32,6 +32,11 @@ type InviteHash struct { // AcceptInviteRequest defines the fields need to complete an invite request. type AcceptInviteRequest struct { + InviteHash string `json:"invite_hash" validate:"required" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"` +} + +// AcceptInviteUserRequest defines the fields need to complete an invite request. +type AcceptInviteUserRequest struct { InviteHash string `json:"invite_hash" validate:"required" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"` Email string `json:"email" validate:"required,email" example:"gabi@geeksinthewoods.com"` FirstName string `json:"first_name" validate:"required" example:"Gabi"` @@ -67,12 +72,12 @@ func NewInviteHash(ctx context.Context, secretKey, userID, accountID, requestIp } // ParseInviteHash extracts the details encrypted in the hash string. -func ParseInviteHash(ctx context.Context, secretKey string, str string, now time.Time) (*InviteHash, error) { +func ParseInviteHash(ctx context.Context, encrypted, secretKey string, now time.Time) (*InviteHash, error) { crypto, err := symcrypto.New(secretKey) if err != nil { return nil, errors.WithStack(err) } - hashStr, err := crypto.Decrypt(str) + hashStr, err := crypto.Decrypt(encrypted) if err != nil { return nil, errors.WithStack(err) }