diff --git a/cmd/web-api/tests/user_test.go b/cmd/web-api/tests/user_test.go index a3f1198..19c3902 100644 --- a/cmd/web-api/tests/user_test.go +++ b/cmd/web-api/tests/user_test.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "geeks-accelerator/oss/saas-starter-kit/internal/user_auth" "net/http" "net/http/httptest" "testing" @@ -17,6 +16,7 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" "geeks-accelerator/oss/saas-starter-kit/internal/user" "geeks-accelerator/oss/saas-starter-kit/internal/user_account" + "geeks-accelerator/oss/saas-starter-kit/internal/user_auth" "github.com/pborman/uuid" ) diff --git a/cmd/web-app/handlers/account.go b/cmd/web-app/handlers/account.go index 0aa935f..f43a11e 100644 --- a/cmd/web-app/handlers/account.go +++ b/cmd/web-app/handlers/account.go @@ -19,8 +19,9 @@ import ( // Account represents the Account API method handler set. type Account struct { - MasterDB *sqlx.DB - Renderer web.Renderer + MasterDB *sqlx.DB + Renderer web.Renderer + Authenticator *auth.Authenticator } // View handles displaying the current account profile. @@ -34,7 +35,7 @@ func (h *Account) View(ctx context.Context, w http.ResponseWriter, r *http.Reque return err } - acc, err := account.Read(ctx, claims, h.MasterDB, claims.Audience, false) + acc, err := account.ReadByID(ctx, claims, h.MasterDB, claims.Audience) if err != nil { return err } @@ -127,7 +128,11 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req } } - sess := webcontext.ContextSession(ctx) + var updateClaims bool + if req.Timezone != nil && claims.Preferences.Timezone != *req.Timezone { + claims.Preferences.Timezone = *req.Timezone + updateClaims = true + } if preferenceDatetimeFormat != req.PreferenceDatetimeFormat { err = account_preference.Set(ctx, claims, h.MasterDB, account_preference.AccountPreferenceSetRequest{ @@ -144,7 +149,10 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req } } - sess.Values[webcontext.SessionKeyPreferenceDatetimeFormat] = req.PreferenceDatetimeFormat + if claims.Preferences.DatetimeFormat != req.PreferenceDatetimeFormat { + claims.Preferences.DatetimeFormat = req.PreferenceDatetimeFormat + updateClaims = true + } } if preferenceDateFormat != req.PreferenceDateFormat { @@ -162,7 +170,10 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req } } - sess.Values[webcontext.SessionKeyPreferenceDateFormat] = req.PreferenceDateFormat + if claims.Preferences.DateFormat != req.PreferenceDateFormat { + claims.Preferences.DateFormat = req.PreferenceDateFormat + updateClaims = true + } } if preferenceTimeFormat != req.PreferenceTimeFormat { @@ -180,7 +191,18 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req } } - sess.Values[webcontext.SessionKeyPreferenceTimeFormat] = req.PreferenceTimeFormat + if claims.Preferences.TimeFormat != req.PreferenceTimeFormat { + claims.Preferences.TimeFormat = req.PreferenceTimeFormat + updateClaims = true + } + } + + // Update the access token to include the updated claims. + if updateClaims { + ctx, err = updateContextClaims(ctx, h.Authenticator, claims) + if err != nil { + return false, err + } } // Display a success message to the user. @@ -196,7 +218,7 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req return true, nil } - acc, err := account.Read(ctx, claims, h.MasterDB, claims.Audience, false) + acc, err := account.ReadByID(ctx, claims, h.MasterDB, claims.Audience) if err != nil { return false, err } diff --git a/cmd/web-app/handlers/routes.go b/cmd/web-app/handlers/routes.go index 0f86f7e..79013a5 100644 --- a/cmd/web-app/handlers/routes.go +++ b/cmd/web-app/handlers/routes.go @@ -66,13 +66,21 @@ func APP(shutdown chan os.Signal, log *log.Logger, env webcontext.Env, staticDir app.Handle("POST", "/user/update", u.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth()) app.Handle("GET", "/user/update", u.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth()) app.Handle("GET", "/user/account", u.Account, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth()) + app.Handle("GET", "/user/virtual-login/:user_id", u.VirtualLogin, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("POST", "/user/virtual-login", u.VirtualLogin, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("GET", "/user/virtual-login", u.VirtualLogin, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("GET", "/user/virtual-logout", u.VirtualLogout, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth()) + app.Handle("GET", "/user/switch-account/:account_id", u.SwitchAccount, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth()) + app.Handle("POST", "/user/switch-account", u.SwitchAccount, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth()) + app.Handle("GET", "/user/switch-account", u.SwitchAccount, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth()) app.Handle("POST", "/user", u.View, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth()) app.Handle("GET", "/user", u.View, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth()) // Register account management endpoints. acc := Account{ - MasterDB: masterDB, - Renderer: renderer, + MasterDB: masterDB, + Renderer: renderer, + Authenticator: authenticator, } app.Handle("POST", "/account/update", acc.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin)) app.Handle("GET", "/account/update", acc.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin)) diff --git a/cmd/web-app/handlers/signup.go b/cmd/web-app/handlers/signup.go index 0790b4f..9ebeabf 100644 --- a/cmd/web-app/handlers/signup.go +++ b/cmd/web-app/handlers/signup.go @@ -12,7 +12,7 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" "geeks-accelerator/oss/saas-starter-kit/internal/signup" - "geeks-accelerator/oss/saas-starter-kit/internal/user" + "geeks-accelerator/oss/saas-starter-kit/internal/user_auth" "github.com/gorilla/schema" "github.com/jmoiron/sqlx" "github.com/pkg/errors" @@ -68,7 +68,7 @@ func (h *Signup) Step1(ctx context.Context, w http.ResponseWriter, r *http.Reque } // Authenticated the new user. - token, err := user.Authenticate(ctx, h.MasterDB, h.Authenticator, req.User.Email, req.User.Password, time.Hour, ctxValues.Now) + token, err := user_auth.Authenticate(ctx, h.MasterDB, h.Authenticator, req.User.Email, req.User.Password, time.Hour, ctxValues.Now) if err != nil { return err } @@ -93,13 +93,6 @@ func (h *Signup) Step1(ctx context.Context, w http.ResponseWriter, r *http.Reque return nil } - data["geonameCountries"] = geonames.ValidGeonameCountries - - data["countries"], err = geonames.FindCountries(ctx, h.MasterDB, "name", "") - if err != nil { - return err - } - return nil } @@ -107,6 +100,13 @@ func (h *Signup) Step1(ctx context.Context, w http.ResponseWriter, r *http.Reque return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8) } + data["geonameCountries"] = geonames.ValidGeonameCountries + + data["countries"], err = geonames.FindCountries(ctx, h.MasterDB, "name", "") + if err != nil { + return err + } + data["form"] = req if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(signup.SignupRequest{})); ok { diff --git a/cmd/web-app/handlers/user.go b/cmd/web-app/handlers/user.go index 8bb03a5..a139866 100644 --- a/cmd/web-app/handlers/user.go +++ b/cmd/web-app/handlers/user.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "net/http" + "net/url" + "strings" "time" "geeks-accelerator/oss/saas-starter-kit/internal/account" @@ -16,6 +18,7 @@ import ( project_routes "geeks-accelerator/oss/saas-starter-kit/internal/project-routes" "geeks-accelerator/oss/saas-starter-kit/internal/user" "geeks-accelerator/oss/saas-starter-kit/internal/user_account" + "geeks-accelerator/oss/saas-starter-kit/internal/user_auth" "github.com/gorilla/schema" "github.com/gorilla/sessions" "github.com/jmoiron/sqlx" @@ -34,7 +37,7 @@ type User struct { } type UserLoginRequest struct { - user.AuthenticateRequest + user_auth.AuthenticateRequest RememberMe bool } @@ -68,7 +71,7 @@ func (h *User) Login(ctx context.Context, w http.ResponseWriter, r *http.Request } // Authenticated the user. - token, err := user.Authenticate(ctx, h.MasterDB, h.Authenticator, req.Email, req.Password, sessionTTL, ctxValues.Now) + token, err := user_auth.Authenticate(ctx, h.MasterDB, h.Authenticator, req.Email, req.Password, sessionTTL, ctxValues.Now) if err != nil { switch errors.Cause(err) { case user.ErrForbidden: @@ -89,8 +92,16 @@ func (h *User) Login(ctx context.Context, w http.ResponseWriter, r *http.Request return err } + redirectUri := "/" + if qv := r.URL.Query().Get("redirect"); qv != "" { + redirectUri, err = url.QueryUnescape(qv) + if err != nil { + return err + } + } + // Redirect the user to the dashboard. - http.Redirect(w, r, "/", http.StatusFound) + http.Redirect(w, r, redirectUri, http.StatusFound) } return nil @@ -110,7 +121,7 @@ func (h *User) Login(ctx context.Context, w http.ResponseWriter, r *http.Request } // handleSessionToken persists the access token to the session for request authentication. -func handleSessionToken(ctx context.Context, db *sqlx.DB, w http.ResponseWriter, r *http.Request, token user.Token) error { +func handleSessionToken(ctx context.Context, db *sqlx.DB, w http.ResponseWriter, r *http.Request, token user_auth.Token) error { if token.AccessToken == "" { return errors.New("accessToken is required.") } @@ -252,7 +263,7 @@ func (h *User) ResetConfirm(ctx context.Context, w http.ResponseWriter, r *http. } // Authenticated the user. Probably should use the default session TTL from UserLogin. - token, err := user.Authenticate(ctx, h.MasterDB, h.Authenticator, u.Email, req.Password, time.Hour, ctxValues.Now) + token, err := user_auth.Authenticate(ctx, h.MasterDB, h.Authenticator, u.Email, req.Password, time.Hour, ctxValues.Now) if err != nil { switch errors.Cause(err) { case account.ErrForbidden: @@ -306,7 +317,7 @@ func (h *User) View(ctx context.Context, w http.ResponseWriter, r *http.Request, return err } - usr, err := user.Read(ctx, claims, h.MasterDB, claims.Subject, false) + usr, err := user.ReadByID(ctx, claims, h.MasterDB, claims.Subject) if err != nil { return err } @@ -343,16 +354,15 @@ func (h *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Reques return err } + claims, err := auth.ClaimsFromContext(ctx) + if err != nil { + return err + } + // req := new(user.UserUpdateRequest) data := make(map[string]interface{}) f := func() (bool, error) { - - claims, err := auth.ClaimsFromContext(ctx) - if err != nil { - return false, err - } - if r.Method == http.MethodPost { err := r.ParseForm() if err != nil { @@ -415,25 +425,6 @@ func (h *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Reques return true, nil } - usr, err := user.Read(ctx, claims, h.MasterDB, claims.Subject, false) - if err != nil { - return false, err - } - - if req.ID == "" { - req.FirstName = &usr.FirstName - req.LastName = &usr.LastName - req.Email = &usr.Email - req.Timezone = &usr.Timezone - } - - data["user"] = usr.Response(ctx) - - data["timezones"], err = geonames.ListTimezones(ctx, h.MasterDB) - if err != nil { - return false, err - } - return false, nil } @@ -444,6 +435,25 @@ func (h *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Reques return nil } + usr, err := user.ReadByID(ctx, claims, h.MasterDB, claims.Subject) + if err != nil { + return err + } + + if req.ID == "" { + req.FirstName = &usr.FirstName + req.LastName = &usr.LastName + req.Email = &usr.Email + req.Timezone = &usr.Timezone + } + + data["user"] = usr.Response(ctx) + + data["timezones"], err = geonames.ListTimezones(ctx, h.MasterDB) + if err != nil { + return err + } + data["form"] = req if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(user.UserUpdateRequest{})); ok { @@ -468,7 +478,7 @@ func (h *User) Account(ctx context.Context, w http.ResponseWriter, r *http.Reque return err } - acc, err := account.Read(ctx, claims, h.MasterDB, claims.Audience, false) + acc, err := account.ReadByID(ctx, claims, h.MasterDB, claims.Audience) if err != nil { return err } @@ -483,3 +493,352 @@ func (h *User) Account(ctx context.Context, w http.ResponseWriter, r *http.Reque return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "user-account.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data) } + +// VirtualLogin handles switching the scope of the context to another user. +func (h *User) VirtualLogin(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { + + ctxValues, err := webcontext.ContextValues(ctx) + if err != nil { + return err + } + + claims, err := auth.ClaimsFromContext(ctx) + if err != nil { + return err + } + + // + req := new(user_auth.VirtualLoginRequest) + data := make(map[string]interface{}) + f := func() (bool, error) { + if r.Method == http.MethodPost { + err := r.ParseForm() + if err != nil { + return false, err + } + + decoder := schema.NewDecoder() + decoder.IgnoreUnknownKeys(true) + + if err := decoder.Decode(req, r.PostForm); err != nil { + return false, err + } + } else { + if pv, ok := params["user_id"]; ok && pv != "" { + req.UserID = pv + } + } + + if qv := r.URL.Query().Get("account_id"); qv != "" { + req.AccountID = qv + } else { + req.AccountID = claims.Audience + } + + if req.UserID != "" { + sess := webcontext.ContextSession(ctx) + var expires time.Duration + if sess != nil && sess.Options != nil { + expires = time.Second * time.Duration(sess.Options.MaxAge) + } else { + expires = time.Hour + } + + // Perform the account switch. + tkn, err := user_auth.VirtualLogin(ctx, h.MasterDB, h.Authenticator, claims, *req, expires, ctxValues.Now) + if err != nil { + if verr, ok := weberror.NewValidationError(ctx, err); ok { + data["validationErrors"] = verr.(*weberror.Error) + return false, nil + } else { + return false, err + } + } + + // Update the access token in the session. + sess = webcontext.SessionUpdateAccessToken(sess, tkn.AccessToken) + + // Read the account for a flash message. + usr, err := user.ReadByID(ctx, claims, h.MasterDB, tkn.UserID) + if err != nil { + return false, err + } + webcontext.SessionFlashSuccess(ctx, + "User Switched", + fmt.Sprintf("You are now virtually logged into user %s.", + usr.Response(ctx).Name)) + + // Write the session to the client. + err = webcontext.ContextSession(ctx).Save(r, w) + if err != nil { + return false, err + } + + // Redirect the user to the dashboard with the new credentials. + http.Redirect(w, r, "/", http.StatusFound) + + return true, nil + } + + return false, nil + } + + end, err := f() + if err != nil { + return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8) + } else if end { + return nil + } + + usrAccFilter := "account_id = ?" + usrAccs, err := user_account.Find(ctx, claims, h.MasterDB, user_account.UserAccountFindRequest{ + Where: &usrAccFilter, + Args: []interface{}{claims.Audience}, + }) + if err != nil { + return err + } + + var userIDs []interface{} + var userPhs []string + for _, usrAcc := range usrAccs { + if usrAcc.UserID == claims.Subject { + // Skip the current authenticated user. + continue + } + userIDs = append(userIDs, usrAcc.UserID) + userPhs = append(userPhs, "?") + } + + if len(userIDs) == 0 { + userIDs = append(userIDs, "") + userPhs = append(userPhs, "?") + } + + usrFilter := fmt.Sprintf("id IN (%s)", strings.Join(userPhs, ", ")) + users, err := user.Find(ctx, claims, h.MasterDB, user.UserFindRequest{ + Where: &usrFilter, + Args: userIDs, + }) + if err != nil { + return err + } + data["users"] = users.Response(ctx) + + if req.AccountID == "" { + req.AccountID = claims.Audience + } + + data["form"] = req + + if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(user_auth.VirtualLoginRequest{})); ok { + data["validationDefaults"] = verr.(*weberror.Error) + } + + return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "user-virtual-login.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data) +} + +// VirtualLogout handles switching the scope back to the user who initiated the virtual login. +func (h *User) VirtualLogout(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { + + ctxValues, err := webcontext.ContextValues(ctx) + if err != nil { + return err + } + + claims, err := auth.ClaimsFromContext(ctx) + if err != nil { + return err + } + + sess := webcontext.ContextSession(ctx) + + var expires time.Duration + if sess != nil && sess.Options != nil { + expires = time.Second * time.Duration(sess.Options.MaxAge) + } else { + expires = time.Hour + } + + tkn, err := user_auth.VirtualLogout(ctx, h.MasterDB, h.Authenticator, claims, expires, ctxValues.Now) + if err != nil { + return err + } + + // Update the access token in the session. + sess = webcontext.SessionUpdateAccessToken(sess, tkn.AccessToken) + + // Display a success message to verify the user has switched contexts. + if claims.Subject != tkn.UserID && claims.Audience != tkn.AccountID { + usr, err := user.ReadByID(ctx, claims, h.MasterDB, tkn.UserID) + if err != nil { + return err + } + acc, err := account.ReadByID(ctx, claims, h.MasterDB, tkn.AccountID) + if err != nil { + return err + } + webcontext.SessionFlashSuccess(ctx, + "Context Switched", + fmt.Sprintf("You are now virtually logged back into account %s user %s.", + acc.Response(ctx).Name, usr.Response(ctx).Name)) + } else if claims.Audience != tkn.AccountID { + acc, err := account.ReadByID(ctx, claims, h.MasterDB, tkn.AccountID) + if err != nil { + return err + } + webcontext.SessionFlashSuccess(ctx, + "Context Switched", + fmt.Sprintf("You are now virtually logged back into account %s.", + acc.Response(ctx).Name)) + } else { + usr, err := user.ReadByID(ctx, claims, h.MasterDB, tkn.UserID) + if err != nil { + return err + } + webcontext.SessionFlashSuccess(ctx, + "Context Switched", + fmt.Sprintf("You are now virtually logged back into user %s.", + usr.Response(ctx).Name)) + } + + // Write the session to the client. + err = webcontext.ContextSession(ctx).Save(r, w) + if err != nil { + return err + } + + // Redirect the user to the dashboard with the new credentials. + http.Redirect(w, r, "/", http.StatusFound) + + return nil +} + +// VirtualLogin handles switching the scope of the context to another user. +func (h *User) SwitchAccount(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { + + ctxValues, err := webcontext.ContextValues(ctx) + if err != nil { + return err + } + + claims, err := auth.ClaimsFromContext(ctx) + if err != nil { + return err + } + + // + req := new(user_auth.SwitchAccountRequest) + data := make(map[string]interface{}) + f := func() (bool, error) { + + if r.Method == http.MethodPost { + err := r.ParseForm() + if err != nil { + return false, err + } + + decoder := schema.NewDecoder() + decoder.IgnoreUnknownKeys(true) + + if err := decoder.Decode(req, r.PostForm); err != nil { + return false, err + } + } else { + if pv, ok := params["account_id"]; ok && pv != "" { + req.AccountID = pv + } else if qv := r.URL.Query().Get("account_id"); qv != "" { + req.AccountID = qv + } + } + + if req.AccountID != "" { + sess := webcontext.ContextSession(ctx) + var expires time.Duration + if sess != nil && sess.Options != nil { + expires = time.Second * time.Duration(sess.Options.MaxAge) + } else { + expires = time.Hour + } + + // Perform the account switch. + tkn, err := user_auth.SwitchAccount(ctx, h.MasterDB, h.Authenticator, claims, *req, expires, ctxValues.Now) + if err != nil { + if verr, ok := weberror.NewValidationError(ctx, err); ok { + data["validationErrors"] = verr.(*weberror.Error) + return false, nil + } else { + return false, err + } + } + + // Update the access token in the session. + sess = webcontext.SessionUpdateAccessToken(sess, tkn.AccessToken) + + // Read the account for a flash message. + acc, err := account.ReadByID(ctx, claims, h.MasterDB, tkn.AccountID) + if err != nil { + return false, err + } + webcontext.SessionFlashSuccess(ctx, + "Account Switched", + fmt.Sprintf("You are now logged into account %s.", + acc.Response(ctx).Name)) + + // Write the session to the client. + err = webcontext.ContextSession(ctx).Save(r, w) + if err != nil { + return false, err + } + + // Redirect the user to the dashboard with the new credentials. + http.Redirect(w, r, "/", http.StatusFound) + + return true, nil + } + + return false, nil + } + + end, err := f() + if err != nil { + return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8) + } else if end { + return nil + } + + accounts, err := account.Find(ctx, claims, h.MasterDB, account.AccountFindRequest{ + Order: []string{"name"}, + }) + if err != nil { + return err + } + data["accounts"] = accounts.Response(ctx) + + if req.AccountID == "" { + req.AccountID = claims.Audience + } + + data["form"] = req + + if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(user_auth.SwitchAccountRequest{})); ok { + data["validationDefaults"] = verr.(*weberror.Error) + } + + return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "user-switch-account.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data) +} + +// updateContextClaims updates the claims in the context. +func updateContextClaims(ctx context.Context, authenticator *auth.Authenticator, claims auth.Claims) (context.Context, error) { + tkn, err := authenticator.GenerateToken(claims) + if err != nil { + return ctx, err + } + + sess := webcontext.ContextSession(ctx) + sess = webcontext.SessionUpdateAccessToken(sess, tkn) + + ctx = context.WithValue(ctx, auth.Key, claims) + + return ctx, nil +} diff --git a/cmd/web-app/main.go b/cmd/web-app/main.go index d1a9d35..c456d82 100644 --- a/cmd/web-app/main.go +++ b/cmd/web-app/main.go @@ -698,7 +698,7 @@ func main() { return nil } - usr, err := user.Read(ctx, auth.Claims{}, masterDb, claims.Subject, false) + usr, err := user.ReadByID(ctx, auth.Claims{}, masterDb, claims.Subject) if err != nil { return nil } @@ -733,7 +733,7 @@ func main() { return nil } - acc, err := account.Read(ctx, auth.Claims{}, masterDb, claims.Audience, false) + acc, err := account.ReadByID(ctx, auth.Claims{}, masterDb, claims.Audience) if err != nil { return nil } @@ -746,6 +746,26 @@ func main() { return a }, + "ContextCanSwitchAccount": func(ctx context.Context) bool { + claims, err := auth.ClaimsFromContext(ctx) + if err != nil || len(claims.AccountIDs) < 2 { + return false + } + return true + }, + "ContextIsVirtualSession": func(ctx context.Context) bool { + claims, err := auth.ClaimsFromContext(ctx) + if err != nil { + return false + } + if claims.RootUserID != "" && claims.RootUserID != claims.Subject { + return true + } + if claims.RootAccountID != "" && claims.RootAccountID != claims.Audience { + return true + } + return false + }, } imgUrlFormatter := staticUrlFormatter @@ -827,11 +847,8 @@ func main() { switch statusCode { case http.StatusUnauthorized: - // Handle expired sessions that are returned from the auth middleware. - if strings.Contains(errors.Cause(er).Error(), "token is expired") { - http.Redirect(w, r, "/user/login", http.StatusFound) - return nil - } + http.Redirect(w, r, "/user/login?redirect="+url.QueryEscape(r.RequestURI), http.StatusFound) + return nil } return web.RenderError(ctx, w, r, er, renderer, handlers.TmplLayoutBase, handlers.TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8) diff --git a/cmd/web-app/static/assets/images/user-default.jpg b/cmd/web-app/static/assets/images/user-default.jpg new file mode 100644 index 0000000..d70c406 Binary files /dev/null and b/cmd/web-app/static/assets/images/user-default.jpg differ diff --git a/cmd/web-app/templates/content/user-switch-account.gohtml b/cmd/web-app/templates/content/user-switch-account.gohtml new file mode 100644 index 0000000..460ab05 --- /dev/null +++ b/cmd/web-app/templates/content/user-switch-account.gohtml @@ -0,0 +1,60 @@ +{{define "title"}}Switch Account{{end}} +{{define "style"}} + +{{end}} +{{ define "partials/page-wrapper" }} +
+ + +
+ +
+ +
+
+ +
+ +
+
+ {{ template "app-flashes" . }} + +
+

Switch Account

+
+ + {{ template "validation-error" . }} + +
+
+ + {{template "invalid-feedback" dict "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors "fieldName" "AccountID" }} +
+ +
+
+
+
+
+
+
+ +
+ +
+ +
+{{end}} +{{define "js"}} + +{{end}} diff --git a/cmd/web-app/templates/content/user-virtual-login.gohtml b/cmd/web-app/templates/content/user-virtual-login.gohtml new file mode 100644 index 0000000..db24a4a --- /dev/null +++ b/cmd/web-app/templates/content/user-virtual-login.gohtml @@ -0,0 +1,60 @@ +{{define "title"}}Switch User{{end}} +{{define "style"}} + +{{end}} +{{ define "partials/page-wrapper" }} +
+ + +
+ +
+ +
+
+ +
+ +
+
+ {{ template "app-flashes" . }} + +
+

Switch User

+
+ + {{ template "validation-error" . }} + +
+
+ + {{template "invalid-feedback" dict "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors "fieldName" "UserID" }} +
+ +
+
+
+
+
+
+
+ +
+ +
+ +
+{{end}} +{{define "js"}} + +{{end}} diff --git a/cmd/web-app/templates/partials/topbar.tmpl b/cmd/web-app/templates/partials/topbar.tmpl index cbdbac6..e27cc26 100644 --- a/cmd/web-app/templates/partials/topbar.tmpl +++ b/cmd/web-app/templates/partials/topbar.tmpl @@ -177,7 +177,6 @@ Account Settings - Manage Users @@ -197,8 +196,22 @@ Support - + + {{ if ContextCanSwitchAccount $._Ctx }} + + + Switch Account + + {{ end }} + + {{ if ContextIsVirtualSession $._Ctx }} + + + Switch Back + + {{ end }} + Logout diff --git a/internal/account/account.go b/internal/account/account.go index 549ad00..9e21ac0 100644 --- a/internal/account/account.go +++ b/internal/account/account.go @@ -150,7 +150,7 @@ func selectQuery() *sqlbuilder.SelectBuilder { // Find gets all the accounts from the database based on the request params. // TODO: Need to figure out why can't parse the args when appending the where // to the query. -func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountFindRequest) ([]*Account, error) { +func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountFindRequest) (Accounts, error) { query := selectQuery() if req.Where != nil { @@ -170,7 +170,7 @@ func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountF } // find internal method for getting all the accounts from the database using a select query. -func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*Account, error) { +func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (Accounts, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Find") defer span.Finish() diff --git a/internal/account/models.go b/internal/account/models.go index f0d1b1b..bbb6724 100644 --- a/internal/account/models.go +++ b/internal/account/models.go @@ -99,6 +99,21 @@ func (m *AccountResponse) MarshalBinary() ([]byte, error) { return json.Marshal(m) } +// Accounts a list of Accounts. +type Accounts []*Account + +// Response transforms a list of Accounts to a list of AccountResponses. +func (m *Accounts) Response(ctx context.Context) []*AccountResponse { + var l []*AccountResponse + if m != nil && len(*m) > 0 { + for _, n := range *m { + l = append(l, n.Response(ctx)) + } + } + + return l +} + // AccountCreateRequest contains information needed to create a new Account. type AccountCreateRequest struct { Name string `json:"name" validate:"required,unique" example:"Company Name"` diff --git a/internal/platform/auth/claims.go b/internal/platform/auth/claims.go index a77d0ad..6cd92e9 100644 --- a/internal/platform/auth/claims.go +++ b/internal/platform/auth/claims.go @@ -23,9 +23,11 @@ const Key ctxKey = 1 // Claims represents the authorization claims transmitted via a JWT. type Claims struct { - AccountIds []string `json:"accounts"` - Roles []string `json:"roles"` - Preferences ClaimPreferences `json:"prefs"` + RootUserID string `json:"root_user_id"` + RootAccountID string `json:"root_account_id"` + AccountIDs []string `json:"accounts"` + Roles []string `json:"roles"` + Preferences ClaimPreferences `json:"prefs"` jwt.StandardClaims } @@ -41,14 +43,16 @@ type ClaimPreferences struct { // NewClaims constructs a Claims value for the identified user. The Claims // expire within a specified duration of the provided time. Additional fields // of the Claims can be set after calling NewClaims is desired. -func NewClaims(userId, accountId string, accountIds []string, roles []string, prefs ClaimPreferences, now time.Time, expires time.Duration) Claims { +func NewClaims(userID, accountID string, accountIDs []string, roles []string, prefs ClaimPreferences, now time.Time, expires time.Duration) Claims { c := Claims{ - AccountIds: accountIds, - Roles: roles, - Preferences: prefs, + AccountIDs: accountIDs, + RootAccountID: accountID, + RootUserID: userID, + Roles: roles, + Preferences: prefs, StandardClaims: jwt.StandardClaims{ - Subject: userId, - Audience: accountId, + Subject: userID, + Audience: accountID, IssuedAt: now.Unix(), ExpiresAt: now.Add(expires).Unix(), }, diff --git a/internal/platform/web/webcontext/session.go b/internal/platform/web/webcontext/session.go index ec0d194..7c4e498 100644 --- a/internal/platform/web/webcontext/session.go +++ b/internal/platform/web/webcontext/session.go @@ -14,39 +14,23 @@ const KeySession ctxKeySession = 1 // Session keys used to store values. const ( SessionKeyAccessToken = iota - //SessionKeyPreferenceDatetimeFormat - //SessionKeyPreferenceDateFormat - //SessionKeyPreferenceTimeFormat - //SessionKeyTimezone ) -func init() { - //gob.Register(&Session{}) -} - -// Session represents a user with authentication. -type Session struct { - *sessions.Session -} - // ContextWithSession appends a universal translator to a context. func ContextWithSession(ctx context.Context, session *sessions.Session) context.Context { return context.WithValue(ctx, KeySession, session) } // ContextSession returns the session from a context. -func ContextSession(ctx context.Context) *Session { - if s, ok := ctx.Value(KeySession).(*Session); ok { +func ContextSession(ctx context.Context) *sessions.Session { + if s, ok := ctx.Value(KeySession).(*sessions.Session); ok { return s } return nil } func ContextAccessToken(ctx context.Context) (string, bool) { - return ContextSession(ctx).AccessToken() -} - -func (sess *Session) AccessToken() (string, bool) { + sess := ContextSession(ctx) if sess == nil { return "", false } @@ -56,60 +40,19 @@ func (sess *Session) AccessToken() (string, bool) { return "", false } -/* -func(sess *Session) PreferenceDatetimeFormat() (string, bool) { - if sess == nil { - return "", false - } - if sv, ok := sess.Values[SessionKeyPreferenceDatetimeFormat].(string); ok { - return sv, true - } - return "", false -} - -func(sess *Session) PreferenceDateFormat() (string, bool) { - if sess == nil { - return "", false - } - if sv, ok := sess.Values[SessionKeyPreferenceDateFormat].(string); ok { - return sv, true - } - return "", false -} - -func(sess *Session) PreferenceTimeFormat() (string, bool) { - if sess == nil { - return "", false - } - if sv, ok := sess.Values[SessionKeyPreferenceTimeFormat].(string); ok { - return sv, true - } - return "", false -} - -func(sess *Session) Timezone() (*time.Location, bool) { - if sess != nil { - if sv, ok := sess.Values[SessionKeyTimezone].(*time.Location); ok { - return sv, true - } - } - - return nil, false -} -*/ - -func SessionInit(session *Session, accessToken string) *Session { +func SessionInit(session *sessions.Session, accessToken string) *sessions.Session { session.Values[SessionKeyAccessToken] = accessToken - //session.Values[SessionKeyPreferenceDatetimeFormat] = datetimeFormat - //session.Values[SessionKeyPreferenceDateFormat] = dateFormat - //session.Values[SessionKeyPreferenceTimeFormat] = timeFormat - //session.Values[SessionKeyTimezone] = timezone return session } -func SessionDestroy(session *Session) *Session { +func SessionUpdateAccessToken(session *sessions.Session, accessToken string) *sessions.Session { + session.Values[SessionKeyAccessToken] = accessToken + return session +} + +func SessionDestroy(session *sessions.Session) *sessions.Session { delete(session.Values, SessionKeyAccessToken) diff --git a/internal/project/models.go b/internal/project/models.go index 48d146e..14ef059 100644 --- a/internal/project/models.go +++ b/internal/project/models.go @@ -56,6 +56,21 @@ func (m *Project) Response(ctx context.Context) *ProjectResponse { return r } +// Projects a list of Projects. +type Projects []*Project + +// Response transforms a list of Projects to a list of ProjectResponses. +func (m *Projects) Response(ctx context.Context) []*ProjectResponse { + var l []*ProjectResponse + if m != nil && len(*m) > 0 { + for _, n := range *m { + l = append(l, n.Response(ctx)) + } + } + + return l +} + // ProjectCreateRequest contains information needed to create a new Project. type ProjectCreateRequest struct { AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"` diff --git a/internal/project/project.go b/internal/project/project.go index 28237c7..7e1c53d 100644 --- a/internal/project/project.go +++ b/internal/project/project.go @@ -124,13 +124,13 @@ func findRequestQuery(req ProjectFindRequest) (*sqlbuilder.SelectBuilder, []inte } // Find gets all the projects from the database based on the request params. -func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectFindRequest) ([]*Project, error) { +func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectFindRequest) (Projects, error) { query, args := findRequestQuery(req) return find(ctx, claims, dbConn, query, args, req.IncludeArchived) } // find internal method for getting all the projects from the database using a select query. -func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*Project, error) { +func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (Projects, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Find") defer span.Finish() diff --git a/internal/user/models.go b/internal/user/models.go index 4840e3a..7c00b2d 100644 --- a/internal/user/models.go +++ b/internal/user/models.go @@ -78,6 +78,21 @@ func (m *UserResponse) MarshalBinary() ([]byte, error) { return json.Marshal(m) } +// Users a list of Users. +type Users []*User + +// Response transforms a list of Users to a list of UserResponses. +func (m *Users) Response(ctx context.Context) []*UserResponse { + var l []*UserResponse + if m != nil && len(*m) > 0 { + for _, n := range *m { + l = append(l, n.Response(ctx)) + } + } + + return l +} + // UserCreateRequest contains information needed to create a new User. type UserCreateRequest struct { FirstName string `json:"first_name" validate:"required" example:"Gabi"` diff --git a/internal/user/user.go b/internal/user/user.go index ded3bcc..effddef 100644 --- a/internal/user/user.go +++ b/internal/user/user.go @@ -202,13 +202,13 @@ func findRequestQuery(req UserFindRequest) (*sqlbuilder.SelectBuilder, []interfa } // Find gets all the users from the database based on the request params. -func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindRequest) ([]*User, error) { +func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindRequest) (Users, error) { query, args := findRequestQuery(req) return find(ctx, claims, dbConn, query, args, req.IncludeArchived) } // find internal method for getting all the users from the database using a select query. -func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*User, error) { +func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (Users, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Find") defer span.Finish() diff --git a/internal/user_account/models.go b/internal/user_account/models.go index ba47e16..7c63f29 100644 --- a/internal/user_account/models.go +++ b/internal/user_account/models.go @@ -66,6 +66,34 @@ func (m *UserAccount) Response(ctx context.Context) *UserAccountResponse { return r } +// HasRole checks if the entry has a role. +func (m *UserAccount) HasRole(role UserAccountRole) bool { + if m == nil { + return false + } + for _, r := range m.Roles { + if r == role { + return true + } + } + return false +} + +// UserAccounts a list of UserAccounts. +type UserAccounts []*UserAccount + +// Response transforms a list of UserAccounts to a list of UserAccountResponses. +func (m *UserAccounts) Response(ctx context.Context) []*UserAccountResponse { + var l []*UserAccountResponse + if m != nil && len(*m) > 0 { + for _, n := range *m { + l = append(l, n.Response(ctx)) + } + } + + return l +} + // UserAccountCreateRequest defines the information is needed to associate a user to an // account. Users are global to the application and each users access can be managed // on an account level. If a current entry exists in the database but is archived, diff --git a/internal/user_account/user_account.go b/internal/user_account/user_account.go index beb404d..01e0b28 100644 --- a/internal/user_account/user_account.go +++ b/internal/user_account/user_account.go @@ -131,13 +131,13 @@ func findRequestQuery(req UserAccountFindRequest) (*sqlbuilder.SelectBuilder, [] } // Find gets all the user accounts from the database based on the request params. -func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountFindRequest) ([]*UserAccount, error) { +func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountFindRequest) (UserAccounts, error) { query, args := findRequestQuery(req) return find(ctx, claims, dbConn, query, args, req.IncludeArchived) } // Find gets all the user accounts from the database based on the select query -func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*UserAccount, error) { +func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (UserAccounts, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Find") defer span.Finish() @@ -180,7 +180,7 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu } // Retrieve gets the specified user from the database. -func FindByUserID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string, includedArchived bool) ([]*UserAccount, error) { +func FindByUserID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string, includedArchived bool) (UserAccounts, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.FindByUserID") defer span.Finish() diff --git a/internal/user_auth/auth.go b/internal/user_auth/auth.go index d593541..65f440d 100644 --- a/internal/user_auth/auth.go +++ b/internal/user_auth/auth.go @@ -3,6 +3,7 @@ package user_auth import ( "context" "database/sql" + "geeks-accelerator/oss/saas-starter-kit/internal/user_account" "strings" "time" @@ -22,6 +23,9 @@ var ( // ErrAuthenticationFailure occurs when a user attempts to authenticate but // anything goes wrong. ErrAuthenticationFailure = errors.New("Authentication failed") + + // ErrForbidden occurs when a user tries to do something that is forbidden to them according to our access control policies. + ErrForbidden = errors.New("Attempted action is not allowed") ) const ( @@ -65,22 +69,11 @@ func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, e return generateToken(ctx, dbConn, tknGen, auth.Claims{}, u.ID, "", expires, now, scopes...) } -// Authenticate finds a user by their email and verifies their password. On success -// it returns a Token that can be used to authenticate access to the application in -// the future. -func SwitchAccount(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, accountID string, expires time.Duration, now time.Time, scopes ...string) (Token, error) { +// SwitchAccount allows users to switch between multiple accounts, this changes the claim audience. +func SwitchAccount(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, req SwitchAccountRequest, expires time.Duration, now time.Time, scopes ...string) (Token, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.SwitchAccount") defer span.Finish() - // Defines struct to apply validation for the supplied claims and account ID. - req := struct { - UserID string `json:"user_id" validate:"required,uuid"` - AccountID string `json:"account_id" validate:"required,uuid"` - }{ - UserID: claims.Subject, - AccountID: accountID, - } - // Validate the request. v := webcontext.Validator() err := v.Struct(req) @@ -88,12 +81,74 @@ func SwitchAccount(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, return Token{}, err } + claims.RootAccountID = req.AccountID + + if claims.RootUserID == "" { + claims.RootUserID = claims.Subject + } + + // Generate a token for the user ID in supplied in claims as the Subject. Pass + // in the supplied claims as well to enforce ACLs when finding the current + // list of accounts for the user. + return generateToken(ctx, dbConn, tknGen, claims, claims.Subject, req.AccountID, expires, now, scopes...) +} + +// VirtualLogin allows users to mock being logged in as other users. +func VirtualLogin(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, req VirtualLoginRequest, expires time.Duration, now time.Time, scopes ...string) (Token, error) { + span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.VirtualLogin") + defer span.Finish() + + // Validate the request. + v := webcontext.Validator() + err := v.Struct(req) + if err != nil { + return Token{}, err + } + + // Find all the accounts that the current user has access to. + usrAccs, err := user_account.FindByUserID(ctx, claims, dbConn, claims.Subject, false) + if err != nil { + return Token{}, err + } + + // The user must have the role of admin to login any other user. + var hasAccountAdminRole bool + for _, usrAcc := range usrAccs { + if usrAcc.HasRole(user_account.UserAccountRole_Admin) { + if usrAcc.AccountID == req.AccountID { + hasAccountAdminRole = true + break + } + } + } + if !hasAccountAdminRole { + return Token{}, errors.WithMessagef(ErrForbidden, "User %s does not have correct access to account %s ", claims.Subject, req.AccountID) + } + + if claims.RootAccountID == "" { + claims.RootAccountID = claims.Audience + } + if claims.RootUserID == "" { + claims.RootUserID = claims.Subject + } + // Generate a token for the user ID in supplied in claims as the Subject. Pass // in the supplied claims as well to enforce ACLs when finding the current // list of accounts for the user. return generateToken(ctx, dbConn, tknGen, claims, req.UserID, req.AccountID, expires, now, scopes...) } +// VirtualLogout allows switch back to their root user/account. +func VirtualLogout(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, expires time.Duration, now time.Time, scopes ...string) (Token, error) { + span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.VirtualLogout") + defer span.Finish() + + // Generate a token for the user ID in supplied in claims as the Subject. Pass + // in the supplied claims as well to enforce ACLs when finding the current + // list of accounts for the user. + return generateToken(ctx, dbConn, tknGen, claims, claims.RootUserID, claims.RootAccountID, expires, now, scopes...) +} + // generateToken generates claims for the supplied user ID and account ID and then // returns the token for the generated claims used for authentication. func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, userID, accountID string, expires time.Duration, now time.Time, scopes ...string) (Token, error) { @@ -250,7 +305,7 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, if scopeValid { roles = append(roles, s) } else { - err := errors.Errorf("invalid scope '%s'", s) + err := errors.Wrapf(ErrForbidden, "invalid scope '%s'", s) return Token{}, err } } @@ -259,7 +314,7 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, } if len(roles) == 0 { - err := errors.New("no roles defined for user") + err := errors.Wrapf(ErrForbidden, "no roles defined for user") return Token{}, err } } @@ -314,14 +369,24 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claimPref = auth.NewClaimPreferences(tz, preferenceDatetimeFormat, preferenceDateFormat, preferenceTimeFormat) } + // Ensure the current claims has the root values set. + if (claims.RootAccountID == "" && claims.Audience != "") || (claims.RootUserID == "" && claims.Subject != "") { + claims.RootAccountID = claims.Audience + claims.RootUserID = claims.Subject + } + // JWT claims requires both an audience and a subject. For this application: // Subject: The ID of the user authenticated. // Audience: The ID of the account the user is accessing. A list of account IDs // will also be included to support the user switching between them. - claims = auth.NewClaims(userID, accountID, accountIds, roles, claimPref, now, expires) + newClaims := auth.NewClaims(userID, accountID, accountIds, roles, claimPref, now, expires) + + // Copy the original root account/user ID. + newClaims.RootAccountID = claims.RootAccountID + newClaims.RootUserID = claims.RootUserID // Generate a token for the user with the defined claims. - tknStr, err := tknGen.GenerateToken(claims) + tknStr, err := tknGen.GenerateToken(newClaims) if err != nil { return Token{}, errors.Wrap(err, "generating token") } @@ -329,9 +394,9 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, tkn := Token{ AccessToken: tknStr, TokenType: "Bearer", - claims: claims, - UserID: claims.Subject, - AccountID: claims.Audience, + claims: newClaims, + UserID: newClaims.Subject, + AccountID: newClaims.Audience, } if expires.Seconds() > 0 { diff --git a/internal/user_auth/auth_test.go b/internal/user_auth/auth_test.go index 34fbbb5..13e901b 100644 --- a/internal/user_auth/auth_test.go +++ b/internal/user_auth/auth_test.go @@ -2,6 +2,7 @@ package user_auth import ( "encoding/json" + "fmt" "os" "testing" "time" @@ -63,6 +64,9 @@ func TestAuthenticate(t *testing.T) { } t.Logf("\t%s\tCreate user account ok.", tests.Success) + // Add 30 minutes to now to simulate time passing. + now = now.Add(time.Minute * 30) + acc2, err := account.MockAccount(ctx, test.MasterDB, now) if err != nil { t.Log("\t\tGot :", err) @@ -81,7 +85,7 @@ func TestAuthenticate(t *testing.T) { }, now) // Add 30 minutes to now to simulate time passing. - now = now.Add(time.Minute * 30) + now = now.Add(time.Minute * 5) // Try to authenticate valid user with invalid password. _, err = Authenticate(ctx, test.MasterDB, tknGen, usrAcc.User.Email, "xy7", time.Hour, now) @@ -106,17 +110,20 @@ func TestAuthenticate(t *testing.T) { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) } + expectClaims := tkn1.claims + expectClaims.RootUserID = "" + expectClaims.RootAccountID = "" + expectClaims.Subject = usrAcc.UserID + expectClaims.Audience = usrAcc.AccountID - // Hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229 - resClaims, _ := json.Marshal(claims1) - expectClaims, _ := json.Marshal(tkn1.claims) - if diff := cmp.Diff(string(resClaims), string(expectClaims)); diff != "" { + if diff := cmpClaims(claims1, expectClaims); diff != "" { t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff) } t.Logf("\t%s\tAuthenticate parse claims from token ok.", tests.Success) // Try switching to a second account using the first set of claims. - tkn2, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims1, acc2.ID, time.Hour, now) + tkn2, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims1, + SwitchAccountRequest{AccountID: acc2.ID}, time.Hour, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tSwitchAccount user failed.", tests.Failed) @@ -129,11 +136,13 @@ func TestAuthenticate(t *testing.T) { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) } + expectClaims = tkn2.claims + expectClaims.RootUserID = usrAcc.UserID + expectClaims.RootAccountID = acc2.ID + expectClaims.Subject = usrAcc.UserID + expectClaims.Audience = acc2.ID - // Hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229 - resClaims, _ = json.Marshal(claims2) - expectClaims, _ = json.Marshal(tkn2.claims) - if diff := cmp.Diff(string(resClaims), string(expectClaims)); diff != "" { + if diff := cmpClaims(claims2, expectClaims); diff != "" { t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff) } t.Logf("\t%s\tSwitchAccount parse claims from token ok.", tests.Success) @@ -256,3 +265,724 @@ func TestUserResetPassword(t *testing.T) { t.Logf("\t%s\tAuthenticate ok.", tests.Success) } } + +// TestSwitchAccount validates the behavior around allowing users to switch between their accounts. +func TestSwitchAccount(t *testing.T) { + defer tests.Recover(t) + + // Auth tokens are valid for an our and is verified against current time. + // Issue the token one hour ago. + now := time.Now().Add(time.Hour * -1) + + ctx := tests.Context() + + type authTest struct { + name string + root *user_account.MockUserAccountResponse + switch1Req SwitchAccountRequest + switch1Roles []user_account.UserAccountRole + switch1Scopes []string + switch1Err error + switch2Req SwitchAccountRequest + switch2Roles []user_account.UserAccountRole + switch2Scopes []string + switch2Err error + } + + var authTests []authTest + + // Test all the combinations there the user has access to all three accounts. + if true { + for _, roles := range [][]user_account.UserAccountRole{ + []user_account.UserAccountRole{user_account.UserAccountRole_Admin, user_account.UserAccountRole_Admin, user_account.UserAccountRole_Admin}, + []user_account.UserAccountRole{user_account.UserAccountRole_User, user_account.UserAccountRole_User, user_account.UserAccountRole_User}, + []user_account.UserAccountRole{user_account.UserAccountRole_Admin, user_account.UserAccountRole_User, user_account.UserAccountRole_Admin}, + []user_account.UserAccountRole{user_account.UserAccountRole_User, user_account.UserAccountRole_Admin, user_account.UserAccountRole_User}, + } { + // Create a new user for testing. + usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, roles[0]) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tCreate user account failed.", tests.Failed) + } + + // Create the second account. + now = now.Add(time.Minute) + acc2, err := account.MockAccount(ctx, test.MasterDB, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tCreate second account failed.", tests.Failed) + } + + // Associate the second account with root user. + usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + UserID: usrAcc.UserID, + AccountID: acc2.ID, + Roles: []user_account.UserAccountRole{user_account.UserAccountRole(roles[1])}, + }, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tLinking second account to user failed.", tests.Failed) + } + + // Create the third account. + now = now.Add(time.Minute) + acc3, err := account.MockAccount(ctx, test.MasterDB, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tCreate third account failed.", tests.Failed) + } + + // Associate the third account with root user. + usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + UserID: usrAcc.UserID, + AccountID: acc3.ID, + Roles: []user_account.UserAccountRole{user_account.UserAccountRole(roles[2])}, + }, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tLinking third account to user failed.", tests.Failed) + } + + authTests = append(authTests, authTest{ + name: fmt.Sprintf("Root account role %s -> role %s account 2 -> role %s account 3.", + roles[0], roles[1], roles[2]), + root: usrAcc, + switch1Req: SwitchAccountRequest{AccountID: acc2.ID}, + switch1Roles: usrAcc2.Roles, + switch1Err: nil, + switch2Req: SwitchAccountRequest{AccountID: acc3.ID}, + switch2Err: nil, + switch2Roles: usrAcc3.Roles, + }) + } + } + + // Root account 1 -> invalid account 2 + if true { + // Create a new user for testing. + usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_Admin) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tCreate user account failed.", tests.Failed) + } + + // Create the second account and don't associate it with the root user. + now = now.Add(time.Minute) + acc2, err := account.MockAccount(ctx, test.MasterDB, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tCreate second account failed.", tests.Failed) + } + + authTests = append(authTests, authTest{ + name: "Root account 1 -> invalid account 2.", + root: usrAcc, + switch1Req: SwitchAccountRequest{AccountID: acc2.ID}, + switch1Err: ErrAuthenticationFailure, + }) + } + + // Root account 1 -> valid account 2 with scopes -> valid account 3 with invalid scope. + if true { + // Create a new user for testing. + usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_Admin) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tCreate user account failed.", tests.Failed) + } + + // Create the second account. + now = now.Add(time.Minute) + acc2, err := account.MockAccount(ctx, test.MasterDB, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tCreate second account failed.", tests.Failed) + } + + // Associate the second account with root user. + usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + UserID: usrAcc.UserID, + AccountID: acc2.ID, + Roles: []user_account.UserAccountRole{user_account.UserAccountRole_Admin}, + }, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tLinking second account to user failed.", tests.Failed) + } + + // Create the third account. + now = now.Add(time.Minute) + acc3, err := account.MockAccount(ctx, test.MasterDB, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tCreate third account failed.", tests.Failed) + } + + // Associate the third account with root user. + usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + UserID: usrAcc.UserID, + AccountID: acc3.ID, + Roles: []user_account.UserAccountRole{user_account.UserAccountRole_User}, + }, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tLinking third account to user failed.", tests.Failed) + } + + authTests = append(authTests, authTest{ + name: "Root account 1 -> valid account 2 with scopes -> valid account 3 with invalid scope.", + root: usrAcc, + switch1Req: SwitchAccountRequest{AccountID: acc2.ID}, + switch1Roles: usrAcc2.Roles, + switch1Scopes: []string{user_account.UserAccountRole_User.String()}, + switch1Err: nil, + switch2Req: SwitchAccountRequest{AccountID: acc3.ID}, + switch2Roles: usrAcc3.Roles, + switch2Scopes: []string{user_account.UserAccountRole_Admin.String()}, + switch2Err: ErrForbidden, + }) + } + + // Add 30 minutes to now to simulate time passing. + now = now.Add(time.Minute * 5) + + tknGen := &auth.MockTokenGenerator{} + + t.Log("Given the need to switch accounts.") + { + for i, authTest := range authTests { + t.Logf("\tTest: %d\tWhen running test: %s", i, authTest.name) + { + // Verify that the user can be authenticated with the created user. + var claims1 auth.Claims + tkn1, err := Authenticate(ctx, test.MasterDB, tknGen, authTest.root.User.Email, authTest.root.User.Password, time.Hour, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tAuthenticate user failed.", tests.Failed) + } else { + // Ensure the token string was correctly generated. + claims1, err = tknGen.ParseClaims(tkn1.AccessToken) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) + } + expectClaims := tkn1.claims + expectClaims.RootUserID = "" + expectClaims.RootAccountID = "" + expectClaims.Subject = authTest.root.UserID + expectClaims.Audience = authTest.root.AccountID + expectClaims.Roles = rolesStringSlice(authTest.root.Roles) + + if diff := cmpClaims(claims1, expectClaims); diff != "" { + t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff) + } + } + t.Logf("\t%s\tAuthenticate root user with role %v ok.", tests.Success, authTest.root.Roles) + + // Try to switch to account 2. + var claims2 auth.Claims + tkn2, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims1, authTest.switch1Req, time.Hour, now, authTest.switch1Scopes...) + if err != authTest.switch1Err { + if errors.Cause(err) != authTest.switch1Err { + t.Log("\t\tExpected :", authTest.switch1Err) + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tSwitchAccount account 1 with role %v failed.", tests.Failed, authTest.switch1Roles) + } + } else { + // Ensure the token string was correctly generated. + claims2, err = tknGen.ParseClaims(tkn2.AccessToken) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) + } + expectClaims := tkn2.claims + expectClaims.RootUserID = authTest.root.UserID + expectClaims.RootAccountID = authTest.switch1Req.AccountID + expectClaims.Subject = authTest.root.UserID + expectClaims.Audience = authTest.switch1Req.AccountID + + if len(authTest.switch1Scopes) > 0 { + expectClaims.Roles = authTest.switch1Scopes + } else { + expectClaims.Roles = rolesStringSlice(authTest.switch1Roles) + } + + if diff := cmpClaims(claims2, expectClaims); diff != "" { + t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff) + } + } + t.Logf("\t%s\tSwitchAccount account 1 with role %v ok.", tests.Success, authTest.switch1Roles) + + // If the user can't login, don't need to test any further. + if authTest.switch1Err != nil || authTest.switch2Req.AccountID == "" { + continue + } + + // Try to switch to account 3. + tkn3, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims2, authTest.switch2Req, time.Hour, now, authTest.switch2Scopes...) + if err != authTest.switch2Err { + if errors.Cause(err) != authTest.switch2Err { + t.Log("\t\tExpected :", authTest.switch2Err) + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tSwitchAccount account 2 with role %v failed.", tests.Failed, authTest.switch2Roles) + } + } else { + // Ensure the token string was correctly generated. + claims3, err := tknGen.ParseClaims(tkn3.AccessToken) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) + } + expectClaims := tkn3.claims + expectClaims.RootUserID = authTest.root.UserID + expectClaims.RootAccountID = authTest.switch2Req.AccountID + expectClaims.Subject = authTest.root.UserID + expectClaims.Audience = authTest.switch2Req.AccountID + + if len(authTest.switch2Scopes) > 0 { + expectClaims.Roles = authTest.switch2Scopes + } else { + expectClaims.Roles = rolesStringSlice(authTest.switch2Roles) + } + + if diff := cmpClaims(claims3, expectClaims); diff != "" { + t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff) + } + } + t.Logf("\t%s\tSwitchAccount account 2 with role %v ok.", tests.Success, authTest.switch2Roles) + } + } + } +} + +// TestVirtualLogin validates the behavior around allowing users to virtual login users. +func TestVirtualLogin(t *testing.T) { + defer tests.Recover(t) + + // Auth tokens are valid for an our and is verified against current time. + // Issue the token one hour ago. + now := time.Now().Add(time.Hour * -1) + + ctx := tests.Context() + + type authTest struct { + name string + root *user_account.MockUserAccountResponse + login1Req VirtualLoginRequest + login1Err error + login1Role user_account.UserAccountRole + login2Req VirtualLoginRequest + login2Err error + login2Role user_account.UserAccountRole + login2Logout bool + } + + var authTests []authTest + + // Root admin -> role admin -> role admin + if true { + // Create a new user for testing. + usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_Admin) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tCreate user account failed.", tests.Failed) + } + + usr2, err := user.MockUser(ctx, test.MasterDB, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tCreate second account failed.", tests.Failed) + } + + // Associate second user with basic role associated with the same account. + usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + UserID: usr2.ID, + AccountID: usrAcc.AccountID, + Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)}, + }, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tLinking second user to account failed.", tests.Failed) + } + + usr3, err := user.MockUser(ctx, test.MasterDB, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tCreate second account failed.", tests.Failed) + } + + // Associate second user with basic role associated with the same account. + usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + UserID: usr3.ID, + AccountID: usrAcc.AccountID, + Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)}, + }, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tLinking third user to account failed.", tests.Failed) + } + + authTests = append(authTests, authTest{ + name: "Root admin -> role admin -> role admin", + root: usrAcc, + login1Req: VirtualLoginRequest{ + UserID: usr2.ID, + AccountID: usrAcc.AccountID, + }, + login1Role: usrAcc2.Roles[0], + login1Err: nil, + login2Req: VirtualLoginRequest{ + UserID: usr3.ID, + AccountID: usrAcc.AccountID, + }, + login2Err: nil, + login2Role: usrAcc3.Roles[0], + login2Logout: true, + }) + } + + // Root admin -> role admin -> role user + if true { + // Create a new user for testing. + usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_Admin) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tCreate user account failed.", tests.Failed) + } + + usr2, err := user.MockUser(ctx, test.MasterDB, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tCreate second account failed.", tests.Failed) + } + + // Associate second user with basic role associated with the same account. + usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + UserID: usr2.ID, + AccountID: usrAcc.AccountID, + Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)}, + }, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tLinking second user to account failed.", tests.Failed) + } + + usr3, err := user.MockUser(ctx, test.MasterDB, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tCreate second account failed.", tests.Failed) + } + + // Associate second user with basic role associated with the same account. + usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + UserID: usr3.ID, + AccountID: usrAcc.AccountID, + Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_User)}, + }, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tLinking third user to account failed.", tests.Failed) + } + + authTests = append(authTests, authTest{ + name: "Root admin -> role admin -> role user", + root: usrAcc, + login1Req: VirtualLoginRequest{ + UserID: usr2.ID, + AccountID: usrAcc.AccountID, + }, + login1Err: nil, + login1Role: usrAcc2.Roles[0], + login2Req: VirtualLoginRequest{ + UserID: usr3.ID, + AccountID: usrAcc.AccountID, + }, + login2Err: nil, + login2Role: usrAcc3.Roles[0], + login2Logout: true, + }) + } + + // Root admin -> role user -> role admin + if true { + // Create a new user for testing. + usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_Admin) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tCreate user account failed.", tests.Failed) + } + + usr2, err := user.MockUser(ctx, test.MasterDB, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tCreate second account failed.", tests.Failed) + } + + // Associate second user with basic role associated with the same account. + usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + UserID: usr2.ID, + AccountID: usrAcc.AccountID, + Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_User)}, + }, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tLinking second user to account failed.", tests.Failed) + } + + usr3, err := user.MockUser(ctx, test.MasterDB, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tCreate second account failed.", tests.Failed) + } + + // Associate second user with basic role associated with the same account. + usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + UserID: usr3.ID, + AccountID: usrAcc.AccountID, + Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)}, + }, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tLinking third user to account failed.", tests.Failed) + } + + authTests = append(authTests, authTest{ + name: "Root admin -> role user -> role admin", + root: usrAcc, + login1Req: VirtualLoginRequest{ + UserID: usr2.ID, + AccountID: usrAcc.AccountID, + }, + login1Err: nil, + login1Role: usrAcc2.Roles[0], + login2Req: VirtualLoginRequest{ + UserID: usr3.ID, + AccountID: usrAcc.AccountID, + }, + login2Err: ErrForbidden, + login2Role: usrAcc3.Roles[0], + login2Logout: true, + }) + } + + // Root user -> role admin + if true { + // Create a new user for testing. + usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_User) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tCreate user account failed.", tests.Failed) + } + + usr2, err := user.MockUser(ctx, test.MasterDB, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tCreate second account failed.", tests.Failed) + } + + // Associate second user with basic role associated with the same account. + usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + UserID: usr2.ID, + AccountID: usrAcc.AccountID, + Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)}, + }, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tLinking second user to account failed.", tests.Failed) + } + + authTests = append(authTests, authTest{ + name: "Root user -> role admin", + root: usrAcc, + login1Req: VirtualLoginRequest{ + UserID: usr2.ID, + AccountID: usrAcc.AccountID, + }, + login1Err: ErrForbidden, + login1Role: usrAcc2.Roles[0], + login2Logout: true, + }) + } + + // Root user -> role user + if true { + // Create a new user for testing. + usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_User) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tCreate user account failed.", tests.Failed) + } + + usr2, err := user.MockUser(ctx, test.MasterDB, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tCreate second account failed.", tests.Failed) + } + + // Associate second user with basic role associated with the same account. + usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + UserID: usr2.ID, + AccountID: usrAcc.AccountID, + Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_User)}, + }, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tLinking second user to account failed.", tests.Failed) + } + + authTests = append(authTests, authTest{ + name: "Root user -> role admin", + root: usrAcc, + login1Req: VirtualLoginRequest{ + UserID: usr2.ID, + AccountID: usrAcc.AccountID, + }, + login1Err: ErrForbidden, + login1Role: usrAcc2.Roles[0], + login2Logout: true, + }) + } + + // Add 30 minutes to now to simulate time passing. + now = now.Add(time.Minute * 5) + + tknGen := &auth.MockTokenGenerator{} + + t.Log("Given the need to virtual login.") + { + for i, authTest := range authTests { + t.Logf("\tTest: %d\tWhen running test: %s", i, authTest.name) + { + // Verify that the user can be authenticated with the created user. + var claims1 auth.Claims + tkn1, err := Authenticate(ctx, test.MasterDB, tknGen, authTest.root.User.Email, authTest.root.User.Password, time.Hour, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tAuthenticate user failed.", tests.Failed) + } else { + // Ensure the token string was correctly generated. + claims1, err = tknGen.ParseClaims(tkn1.AccessToken) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) + } + expectClaims := tkn1.claims + expectClaims.RootUserID = "" + expectClaims.RootAccountID = "" + expectClaims.Subject = authTest.root.UserID + expectClaims.Audience = authTest.root.AccountID + + // Hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229 + if diff := cmpClaims(claims1, expectClaims); diff != "" { + t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff) + } + } + t.Logf("\t%s\tAuthenticate root user with role %s ok.", tests.Success, authTest.root.Roles[0]) + + // Try virtual login to user 2. + var claims2 auth.Claims + tkn2, err := VirtualLogin(ctx, test.MasterDB, tknGen, claims1, authTest.login1Req, time.Hour, now) + if err != authTest.login1Err { + if errors.Cause(err) != authTest.login1Err { + t.Log("\t\tExpected :", authTest.login1Err) + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tVirtualLogin user 1 with role %s failed.", tests.Failed, authTest.login1Role) + } + } else { + // Ensure the token string was correctly generated. + claims2, err = tknGen.ParseClaims(tkn2.AccessToken) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) + } + expectClaims := tkn2.claims + expectClaims.RootUserID = authTest.root.UserID + expectClaims.RootAccountID = authTest.root.AccountID + expectClaims.Subject = authTest.login1Req.UserID + expectClaims.Audience = authTest.login1Req.AccountID + + // Hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229 + if diff := cmpClaims(claims2, expectClaims); diff != "" { + t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff) + } + } + t.Logf("\t%s\tVirtualLogin user 1 with role %s ok.", tests.Success, authTest.login1Role) + + // If the user can't login, don't need to test any further. + if authTest.login1Err != nil { + continue + } + + // Try virtual login to user 3. + tkn3, err := VirtualLogin(ctx, test.MasterDB, tknGen, claims2, authTest.login2Req, time.Hour, now) + if err != authTest.login2Err { + if errors.Cause(err) != authTest.login2Err { + t.Log("\t\tExpected :", authTest.login2Err) + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tVirtualLogin user 2 with role %s failed.", tests.Failed, authTest.login2Role) + } + } else { + // Ensure the token string was correctly generated. + claims3, err := tknGen.ParseClaims(tkn3.AccessToken) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) + } + expectClaims := tkn3.claims + expectClaims.RootUserID = authTest.root.UserID + expectClaims.RootAccountID = authTest.root.AccountID + expectClaims.Subject = authTest.login2Req.UserID + expectClaims.Audience = authTest.login2Req.AccountID + + // Hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229 + if diff := cmpClaims(claims3, expectClaims); diff != "" { + t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff) + } + } + t.Logf("\t%s\tVirtualLogin user 2 with role %s ok.", tests.Success, authTest.login2Role) + + if authTest.login2Logout { + tknOut, err := VirtualLogout(ctx, test.MasterDB, tknGen, claims2, time.Hour, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tVirtualLogout user 2 failed.", tests.Failed) + } + + // Ensure the token string was correctly generated. + claimsOut, err := tknGen.ParseClaims(tknOut.AccessToken) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) + } + expectClaims := tknOut.claims + expectClaims.RootUserID = authTest.root.UserID + expectClaims.RootAccountID = authTest.root.AccountID + expectClaims.Subject = authTest.root.UserID + expectClaims.Audience = authTest.root.AccountID + + if diff := cmpClaims(claimsOut, expectClaims); diff != "" { + t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff) + } + t.Logf("\t%s\tVirtualLogout user 2 with role %s ok.", tests.Success, authTest.login2Role) + } + } + } + } +} + +// rolesStringSlice converts a list of roles to a string slice. +func rolesStringSlice(roles []user_account.UserAccountRole) []string { + var l []string + for _, r := range roles { + l = append(l, string(r)) + } + return l +} + +// cmpClaims is a hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229 +func cmpClaims(actualClaims, expectedclaims auth.Claims) string { + dat1, _ := json.Marshal(actualClaims) + dat2, _ := json.Marshal(expectedclaims) + return cmp.Diff(string(dat1), string(dat2)) +} diff --git a/internal/user_auth/models.go b/internal/user_auth/models.go index 09411d3..31bd1ef 100644 --- a/internal/user_auth/models.go +++ b/internal/user_auth/models.go @@ -35,6 +35,17 @@ type Token struct { AccountID string `json:"account_id"example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"` } +// SwitchAccountRequest defines the information for the current user to switch between their accounts +type SwitchAccountRequest struct { + AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"` +} + +// VirtualLoginRequest defines the information virtual login to a user / account. +type VirtualLoginRequest struct { + UserID string `json:"user_id" validate:"required,uuid" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"` + AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"` +} + // AuthorizationHeader returns the header authorization value. func (t Token) AuthorizationHeader() string { return "Bearer " + t.AccessToken