package handlers

import (
	"context"
	"fmt"
	"net/http"
	"net/url"
	"strings"
	"time"

	"geeks-accelerator/oss/saas-starter-kit/internal/account"
	"geeks-accelerator/oss/saas-starter-kit/internal/geonames"
	"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
	"geeks-accelerator/oss/saas-starter-kit/internal/platform/notify"
	"geeks-accelerator/oss/saas-starter-kit/internal/platform/web"
	"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
	"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror"
	project_routes "geeks-accelerator/oss/saas-starter-kit/internal/project-routes"
	"geeks-accelerator/oss/saas-starter-kit/internal/user"
	"geeks-accelerator/oss/saas-starter-kit/internal/user_account"
	"geeks-accelerator/oss/saas-starter-kit/internal/user_auth"
	"github.com/gorilla/schema"
	"github.com/gorilla/sessions"
	"github.com/jmoiron/sqlx"
	"github.com/pborman/uuid"
	"github.com/pkg/errors"
)

// User represents the User API method handler set: login/logout, password
// reset, profile viewing/editing, virtual login and account switching.
type User struct {
	// MasterDB is the database handle passed to the data-access calls.
	MasterDB *sqlx.DB
	// Renderer renders the page templates and the generic error page.
	Renderer web.Renderer
	// Authenticator is handed to user_auth when access tokens are issued.
	Authenticator *auth.Authenticator
	// ProjectRoutes supplies UserResetPassword to user.ResetPassword
	// (presumably to build the link in the reset email — confirm there).
	ProjectRoutes project_routes.ProjectRoutes
	// NotifyEmail is the email sender passed to user.ResetPassword.
	NotifyEmail notify.Email
	// SecretKey is passed to the password-reset helpers in package user
	// (presumably to sign/verify reset hashes — confirm there).
	SecretKey string
}

// urlUserVirtualLogin returns the path of the virtual-login page for the
// user with the given ID.
func urlUserVirtualLogin(userID string) string {
	return fmt.Sprintf("/user/virtual-login/%s", userID)
}

// UserLoginRequest extends the AuthenticateRequest with the RememberMe flag.
type UserLoginRequest struct {
	user_auth.AuthenticateRequest
	// RememberMe asks Login for a longer-lived session.
	RememberMe bool
}

// Login handles authenticating a user into the system. Non-POST requests
// render the login form; a POST authenticates the submitted credentials,
// stores the access token in the session and redirects (to the "redirect"
// query parameter when given, otherwise to "/").
func (h *User) Login(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { ctxValues, err := webcontext.ContextValues(ctx) if err != nil { return err } // req := new(UserLoginRequest) data := make(map[string]interface{}) f := func() (bool, error) { if r.Method == http.MethodPost { err := r.ParseForm() if err != nil { return false, err } decoder := schema.NewDecoder() if err := decoder.Decode(req, r.PostForm); err != nil { return false, err } sessionTTL := time.Hour if req.RememberMe { sessionTTL = time.Hour * 36 } // Authenticated the user. token, err := user_auth.Authenticate(ctx, h.MasterDB, h.Authenticator, user_auth.AuthenticateRequest{ Email: req.Email, Password: req.Password, }, sessionTTL, ctxValues.Now) if err != nil { switch errors.Cause(err) { case user.ErrForbidden: return false, web.RespondError(ctx, w, weberror.NewError(ctx, err, http.StatusForbidden)) case user_auth.ErrAuthenticationFailure: data["error"] = weberror.NewErrorMessage(ctx, err, http.StatusUnauthorized, "Authentication failure. Try again.") return false, nil default: if verr, ok := weberror.NewValidationError(ctx, err); ok { data["validationErrors"] = verr.(*weberror.Error) return false, nil } else { return false, err } } } // Add the token to the users session. err = handleSessionToken(ctx, h.MasterDB, w, r, token) if err != nil { return false, err } redirectUri := "/" if qv := r.URL.Query().Get("redirect"); qv != "" { redirectUri, err = url.QueryUnescape(qv) if err != nil { return false, err } } // Redirect the user to the dashboard. 
return true, web.Redirect(ctx, w, r, redirectUri, http.StatusFound) } return false, nil } end, err := f() if err != nil { return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8) } else if end { return nil } data["form"] = req if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(UserLoginRequest{})); ok { data["validationDefaults"] = verr.(*weberror.Error) } return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "user-login.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data) } // Logout handles removing authentication for the user. func (h *User) Logout(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { sess := webcontext.ContextSession(ctx) // Set the access token to empty to logout the user. sess = webcontext.SessionDestroy(sess) if err := sess.Save(r, w); err != nil { return err } // Redirect the user to the root page. return web.Redirect(ctx, w, r, "/", http.StatusFound) } // ResetPassword allows a user to perform forgot password. func (h *User) ResetPassword(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { ctxValues, err := webcontext.ContextValues(ctx) if err != nil { return err } // req := new(user.UserResetPasswordRequest) data := make(map[string]interface{}) f := func() error { if r.Method == http.MethodPost { err := r.ParseForm() if err != nil { return err } decoder := schema.NewDecoder() if err := decoder.Decode(req, r.PostForm); err != nil { return err } _, err = user.ResetPassword(ctx, h.MasterDB, h.ProjectRoutes.UserResetPassword, h.NotifyEmail, *req, h.SecretKey, ctxValues.Now) if err != nil { switch errors.Cause(err) { default: if verr, ok := weberror.NewValidationError(ctx, err); ok { data["validationErrors"] = verr.(*weberror.Error) return nil } else { return err } } } // Display a success message to the user to check their email. 
webcontext.SessionFlashSuccess(ctx, "Check your email", fmt.Sprintf("An email was sent to '%s'. Click on the link in the email to finish resetting your password.", req.Email)) } return nil } if err := f(); err != nil { return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8) } data["form"] = req if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(user.UserResetPasswordRequest{})); ok { data["validationDefaults"] = verr.(*weberror.Error) } return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "user-reset-password.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data) } // ResetConfirm handles changing a users password after they have clicked on the link emailed. func (h *User) ResetConfirm(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { resetHash := params["hash"] ctxValues, err := webcontext.ContextValues(ctx) if err != nil { return err } // req := new(user.UserResetConfirmRequest) data := make(map[string]interface{}) f := func() (bool, error) { if r.Method == http.MethodPost { err := r.ParseForm() if err != nil { return false, err } decoder := schema.NewDecoder() if err := decoder.Decode(req, r.PostForm); err != nil { return false, err } // Append the query param value to the request. req.ResetHash = resetHash u, err := user.ResetConfirm(ctx, h.MasterDB, *req, h.SecretKey, ctxValues.Now) if err != nil { switch errors.Cause(err) { case user.ErrResetExpired: webcontext.SessionFlashError(ctx, "Reset Expired", "The reset has expired.") return false, nil default: if verr, ok := weberror.NewValidationError(ctx, err); ok { data["validationErrors"] = verr.(*weberror.Error) return false, nil } else { return false, err } } } // Authenticated the user. Probably should use the default session TTL from UserLogin. 
token, err := user_auth.Authenticate(ctx, h.MasterDB, h.Authenticator, user_auth.AuthenticateRequest{ Email: u.Email, Password: req.Password, }, time.Hour, ctxValues.Now) if err != nil { if verr, ok := weberror.NewValidationError(ctx, err); ok { data["validationErrors"] = verr.(*weberror.Error) return false, nil } else { return false, err } } // Add the token to the users session. err = handleSessionToken(ctx, h.MasterDB, w, r, token) if err != nil { return false, err } // Redirect the user to the dashboard. return true, web.Redirect(ctx, w, r, "/", http.StatusFound) } _, err = user.ParseResetHash(ctx, h.SecretKey, resetHash, ctxValues.Now) if err != nil { switch errors.Cause(err) { case user.ErrResetExpired: webcontext.SessionFlashError(ctx, "Reset Expired", "The reset has expired.") return false, nil default: if verr, ok := weberror.NewValidationError(ctx, err); ok { data["validationErrors"] = verr.(*weberror.Error) return false, nil } else { return false, err } } } return false, nil } end, err := f() if err != nil { return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8) } else if end { return nil } data["form"] = req if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(user.UserResetConfirmRequest{})); ok { data["validationDefaults"] = verr.(*weberror.Error) } return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "user-reset-confirm.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data) } // View handles displaying the current user profile. 
func (h *User) View(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { data := make(map[string]interface{}) f := func() error { claims, err := auth.ClaimsFromContext(ctx) if err != nil { return err } usr, err := user.ReadByID(ctx, claims, h.MasterDB, claims.Subject) if err != nil { return err } data["user"] = usr.Response(ctx) usrAccs, err := user_account.FindByUserID(ctx, claims, h.MasterDB, claims.Subject, false) if err != nil { return err } for _, usrAcc := range usrAccs { if usrAcc.AccountID == claims.Audience { data["userAccount"] = usrAcc.Response(ctx) break } } return nil } if err := f(); err != nil { return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8) } return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "user-view.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data) } // Update handles allowing the current user to update their profile. func (h *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { ctxValues, err := webcontext.ContextValues(ctx) if err != nil { return err } claims, err := auth.ClaimsFromContext(ctx) if err != nil { return err } // req := new(user.UserUpdateRequest) data := make(map[string]interface{}) f := func() (bool, error) { if r.Method == http.MethodPost { err := r.ParseForm() if err != nil { return false, err } decoder := schema.NewDecoder() decoder.IgnoreUnknownKeys(true) if err := decoder.Decode(req, r.PostForm); err != nil { return false, err } req.ID = claims.Subject err = user.Update(ctx, claims, h.MasterDB, *req, ctxValues.Now) if err != nil { switch errors.Cause(err) { default: if verr, ok := weberror.NewValidationError(ctx, err); ok { data["validationErrors"] = verr.(*weberror.Error) return false, nil } else { return false, err } } } if r.PostForm.Get("Password") != "" { pwdReq := new(user.UserUpdatePasswordRequest) if err := decoder.Decode(pwdReq, 
r.PostForm); err != nil { return false, err } pwdReq.ID = claims.Subject err = user.UpdatePassword(ctx, claims, h.MasterDB, *pwdReq, ctxValues.Now) if err != nil { switch errors.Cause(err) { default: if verr, ok := weberror.NewValidationError(ctx, err); ok { data["validationErrors"] = verr.(*weberror.Error) return false, nil } else { return false, err } } } } // Display a success message to the user. webcontext.SessionFlashSuccess(ctx, "Profile Updated", "User profile successfully updated.") return true, web.Redirect(ctx, w, r, "/user", http.StatusFound) } return false, nil } end, err := f() if err != nil { return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8) } else if end { return nil } usr, err := user.ReadByID(ctx, claims, h.MasterDB, claims.Subject) if err != nil { return err } if req.ID == "" { req.FirstName = &usr.FirstName req.LastName = &usr.LastName req.Email = &usr.Email req.Timezone = usr.Timezone } data["user"] = usr.Response(ctx) data["timezones"], err = geonames.ListTimezones(ctx, h.MasterDB) if err != nil { return err } data["form"] = req if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(user.UserUpdateRequest{})); ok { data["userValidationDefaults"] = verr.(*weberror.Error) } if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(user.UserUpdatePasswordRequest{})); ok { data["passwordValidationDefaults"] = verr.(*weberror.Error) } return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "user-update.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data) } // Account handles displaying the Account for the current user. 
func (h *User) Account(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { data := make(map[string]interface{}) f := func() error { claims, err := auth.ClaimsFromContext(ctx) if err != nil { return err } acc, err := account.ReadByID(ctx, claims, h.MasterDB, claims.Audience) if err != nil { return err } data["account"] = acc.Response(ctx) return nil } if err := f(); err != nil { return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8) } return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "user-account.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data) } // VirtualLogin handles switching the scope of the context to another user. func (h *User) VirtualLogin(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { ctxValues, err := webcontext.ContextValues(ctx) if err != nil { return err } claims, err := auth.ClaimsFromContext(ctx) if err != nil { return err } // req := new(user_auth.VirtualLoginRequest) data := make(map[string]interface{}) f := func() (bool, error) { if r.Method == http.MethodPost { err := r.ParseForm() if err != nil { return false, err } decoder := schema.NewDecoder() decoder.IgnoreUnknownKeys(true) if err := decoder.Decode(req, r.PostForm); err != nil { return false, err } } else { if pv, ok := params["user_id"]; ok && pv != "" { req.UserID = pv } } if qv := r.URL.Query().Get("account_id"); qv != "" { req.AccountID = qv } else { req.AccountID = claims.Audience } if req.UserID != "" { sess := webcontext.ContextSession(ctx) var expires time.Duration if sess != nil && sess.Options != nil { expires = time.Second * time.Duration(sess.Options.MaxAge) } else { expires = time.Hour } // Perform the account switch. 
tkn, err := user_auth.VirtualLogin(ctx, h.MasterDB, h.Authenticator, claims, *req, expires, ctxValues.Now) if err != nil { if verr, ok := weberror.NewValidationError(ctx, err); ok { data["validationErrors"] = verr.(*weberror.Error) return false, nil } else { return false, err } } // Update the access token in the session. sess = webcontext.SessionUpdateAccessToken(sess, tkn.AccessToken) // Read the account for a flash message. usr, err := user.ReadByID(ctx, claims, h.MasterDB, tkn.UserID) if err != nil { return false, err } webcontext.SessionFlashSuccess(ctx, "User Switched", fmt.Sprintf("You are now virtually logged into user %s.", usr.Response(ctx).Name)) // Redirect the user to the dashboard with the new credentials. return true, web.Redirect(ctx, w, r, "/", http.StatusFound) } return false, nil } end, err := f() if err != nil { return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8) } else if end { return nil } usrAccs, err := user_account.Find(ctx, claims, h.MasterDB, user_account.UserAccountFindRequest{ Where: "account_id = ?", Args: []interface{}{claims.Audience}, }) if err != nil { return err } var userIDs []interface{} var userPhs []string for _, usrAcc := range usrAccs { if usrAcc.UserID == claims.Subject { // Skip the current authenticated user. 
continue } userIDs = append(userIDs, usrAcc.UserID) userPhs = append(userPhs, "?") } if len(userIDs) == 0 { userIDs = append(userIDs, "") userPhs = append(userPhs, "?") } users, err := user.Find(ctx, claims, h.MasterDB, user.UserFindRequest{ Where: fmt.Sprintf("id IN (%s)", strings.Join(userPhs, ", ")), Args: userIDs, }) if err != nil { return err } data["users"] = users.Response(ctx) if req.AccountID == "" { req.AccountID = claims.Audience } data["form"] = req if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(user_auth.VirtualLoginRequest{})); ok { data["validationDefaults"] = verr.(*weberror.Error) } return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "user-virtual-login.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data) } // VirtualLogout handles switching the scope back to the user who initiated the virtual login. func (h *User) VirtualLogout(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { ctxValues, err := webcontext.ContextValues(ctx) if err != nil { return err } claims, err := auth.ClaimsFromContext(ctx) if err != nil { return err } sess := webcontext.ContextSession(ctx) var expires time.Duration if sess != nil && sess.Options != nil { expires = time.Second * time.Duration(sess.Options.MaxAge) } else { expires = time.Hour } tkn, err := user_auth.VirtualLogout(ctx, h.MasterDB, h.Authenticator, claims, expires, ctxValues.Now) if err != nil { return err } // Update the access token in the session. sess = webcontext.SessionUpdateAccessToken(sess, tkn.AccessToken) // Display a success message to verify the user has switched contexts. 
if claims.Subject != tkn.UserID && claims.Audience != tkn.AccountID { usr, err := user.ReadByID(ctx, claims, h.MasterDB, tkn.UserID) if err != nil { return err } acc, err := account.ReadByID(ctx, claims, h.MasterDB, tkn.AccountID) if err != nil { return err } webcontext.SessionFlashSuccess(ctx, "Context Switched", fmt.Sprintf("You are now virtually logged back into account %s user %s.", acc.Response(ctx).Name, usr.Response(ctx).Name)) } else if claims.Audience != tkn.AccountID { acc, err := account.ReadByID(ctx, claims, h.MasterDB, tkn.AccountID) if err != nil { return err } webcontext.SessionFlashSuccess(ctx, "Context Switched", fmt.Sprintf("You are now virtually logged back into account %s.", acc.Response(ctx).Name)) } else { usr, err := user.ReadByID(ctx, claims, h.MasterDB, tkn.UserID) if err != nil { return err } webcontext.SessionFlashSuccess(ctx, "Context Switched", fmt.Sprintf("You are now virtually logged back into user %s.", usr.Response(ctx).Name)) } // Write the session to the client. err = webcontext.ContextSession(ctx).Save(r, w) if err != nil { return err } // Redirect the user to the dashboard with the new credentials. return web.Redirect(ctx, w, r, "/", http.StatusFound) } // VirtualLogin handles switching the scope of the context to another user. 
func (h *User) SwitchAccount(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { ctxValues, err := webcontext.ContextValues(ctx) if err != nil { return err } claims, err := auth.ClaimsFromContext(ctx) if err != nil { return err } // req := new(user_auth.SwitchAccountRequest) data := make(map[string]interface{}) f := func() (bool, error) { if r.Method == http.MethodPost { err := r.ParseForm() if err != nil { return false, err } decoder := schema.NewDecoder() decoder.IgnoreUnknownKeys(true) if err := decoder.Decode(req, r.PostForm); err != nil { return false, err } } else { if pv, ok := params["account_id"]; ok && pv != "" { req.AccountID = pv } else if qv := r.URL.Query().Get("account_id"); qv != "" { req.AccountID = qv } } if req.AccountID != "" { sess := webcontext.ContextSession(ctx) var expires time.Duration if sess != nil && sess.Options != nil { expires = time.Second * time.Duration(sess.Options.MaxAge) } else { expires = time.Hour } // Perform the account switch. tkn, err := user_auth.SwitchAccount(ctx, h.MasterDB, h.Authenticator, claims, *req, expires, ctxValues.Now) if err != nil { if verr, ok := weberror.NewValidationError(ctx, err); ok { data["validationErrors"] = verr.(*weberror.Error) return false, nil } else { return false, err } } // Update the access token in the session. sess = webcontext.SessionUpdateAccessToken(sess, tkn.AccessToken) // Read the account for a flash message. acc, err := account.ReadByID(ctx, claims, h.MasterDB, tkn.AccountID) if err != nil { return false, err } webcontext.SessionFlashSuccess(ctx, "Account Switched", fmt.Sprintf("You are now logged into account %s.", acc.Response(ctx).Name)) // Redirect the user to the dashboard with the new credentials. 
return true, web.Redirect(ctx, w, r, "/", http.StatusFound) } return false, nil } end, err := f() if err != nil { return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8) } else if end { return nil } accounts, err := account.Find(ctx, claims, h.MasterDB, account.AccountFindRequest{ Order: []string{"name"}, }) if err != nil { return err } data["accounts"] = accounts.Response(ctx) if req.AccountID == "" { req.AccountID = claims.Audience } data["form"] = req if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(user_auth.SwitchAccountRequest{})); ok { data["validationDefaults"] = verr.(*weberror.Error) } return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "user-switch-account.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data) } // handleSessionToken persists the access token to the session for request authentication. func handleSessionToken(ctx context.Context, db *sqlx.DB, w http.ResponseWriter, r *http.Request, token user_auth.Token) error { if token.AccessToken == "" { return errors.New("accessToken is required.") } sess := webcontext.ContextSession(ctx) if sess.IsNew { sess.ID = uuid.NewRandom().String() } sess.Options = &sessions.Options{ Path: "/", MaxAge: int(token.TTL.Seconds()), HttpOnly: false, } sess = webcontext.SessionInit(sess, token.AccessToken) if err := sess.Save(r, w); err != nil { return err } return nil } // updateContextClaims updates the claims in the context. func updateContextClaims(ctx context.Context, authenticator *auth.Authenticator, claims auth.Claims) (context.Context, error) { tkn, err := authenticator.GenerateToken(claims) if err != nil { return ctx, err } sess := webcontext.ContextSession(ctx) sess = webcontext.SessionUpdateAccessToken(sess, tkn) ctx = context.WithValue(ctx, auth.Key, claims) return ctx, nil }