diff --git a/cmd/web-api/tests/user_test.go b/cmd/web-api/tests/user_test.go
index a3f1198..19c3902 100644
--- a/cmd/web-api/tests/user_test.go
+++ b/cmd/web-api/tests/user_test.go
@@ -4,7 +4,6 @@ import (
"context"
"encoding/json"
"fmt"
- "geeks-accelerator/oss/saas-starter-kit/internal/user_auth"
"net/http"
"net/http/httptest"
"testing"
@@ -17,6 +16,7 @@ import (
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror"
"geeks-accelerator/oss/saas-starter-kit/internal/user"
"geeks-accelerator/oss/saas-starter-kit/internal/user_account"
+ "geeks-accelerator/oss/saas-starter-kit/internal/user_auth"
"github.com/pborman/uuid"
)
diff --git a/cmd/web-app/handlers/account.go b/cmd/web-app/handlers/account.go
index 0aa935f..f43a11e 100644
--- a/cmd/web-app/handlers/account.go
+++ b/cmd/web-app/handlers/account.go
@@ -19,8 +19,9 @@ import (
// Account represents the Account API method handler set.
type Account struct {
- MasterDB *sqlx.DB
- Renderer web.Renderer
+ MasterDB *sqlx.DB
+ Renderer web.Renderer
+ Authenticator *auth.Authenticator
}
// View handles displaying the current account profile.
@@ -34,7 +35,7 @@ func (h *Account) View(ctx context.Context, w http.ResponseWriter, r *http.Reque
return err
}
- acc, err := account.Read(ctx, claims, h.MasterDB, claims.Audience, false)
+ acc, err := account.ReadByID(ctx, claims, h.MasterDB, claims.Audience)
if err != nil {
return err
}
@@ -127,7 +128,11 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req
}
}
- sess := webcontext.ContextSession(ctx)
+ var updateClaims bool
+ if req.Timezone != nil && claims.Preferences.Timezone != *req.Timezone {
+ claims.Preferences.Timezone = *req.Timezone
+ updateClaims = true
+ }
if preferenceDatetimeFormat != req.PreferenceDatetimeFormat {
err = account_preference.Set(ctx, claims, h.MasterDB, account_preference.AccountPreferenceSetRequest{
@@ -144,7 +149,10 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req
}
}
- sess.Values[webcontext.SessionKeyPreferenceDatetimeFormat] = req.PreferenceDatetimeFormat
+ if claims.Preferences.DatetimeFormat != req.PreferenceDatetimeFormat {
+ claims.Preferences.DatetimeFormat = req.PreferenceDatetimeFormat
+ updateClaims = true
+ }
}
if preferenceDateFormat != req.PreferenceDateFormat {
@@ -162,7 +170,10 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req
}
}
- sess.Values[webcontext.SessionKeyPreferenceDateFormat] = req.PreferenceDateFormat
+ if claims.Preferences.DateFormat != req.PreferenceDateFormat {
+ claims.Preferences.DateFormat = req.PreferenceDateFormat
+ updateClaims = true
+ }
}
if preferenceTimeFormat != req.PreferenceTimeFormat {
@@ -180,7 +191,18 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req
}
}
- sess.Values[webcontext.SessionKeyPreferenceTimeFormat] = req.PreferenceTimeFormat
+ if claims.Preferences.TimeFormat != req.PreferenceTimeFormat {
+ claims.Preferences.TimeFormat = req.PreferenceTimeFormat
+ updateClaims = true
+ }
+ }
+
+ // Update the access token to include the updated claims.
+ if updateClaims {
+ ctx, err = updateContextClaims(ctx, h.Authenticator, claims)
+ if err != nil {
+ return false, err
+ }
}
// Display a success message to the user.
@@ -196,7 +218,7 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req
return true, nil
}
- acc, err := account.Read(ctx, claims, h.MasterDB, claims.Audience, false)
+ acc, err := account.ReadByID(ctx, claims, h.MasterDB, claims.Audience)
if err != nil {
return false, err
}
diff --git a/cmd/web-app/handlers/routes.go b/cmd/web-app/handlers/routes.go
index 0f86f7e..79013a5 100644
--- a/cmd/web-app/handlers/routes.go
+++ b/cmd/web-app/handlers/routes.go
@@ -66,13 +66,21 @@ func APP(shutdown chan os.Signal, log *log.Logger, env webcontext.Env, staticDir
app.Handle("POST", "/user/update", u.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
app.Handle("GET", "/user/update", u.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
app.Handle("GET", "/user/account", u.Account, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
+ app.Handle("GET", "/user/virtual-login/:user_id", u.VirtualLogin, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin))
+ app.Handle("POST", "/user/virtual-login", u.VirtualLogin, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin))
+ app.Handle("GET", "/user/virtual-login", u.VirtualLogin, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin))
+ app.Handle("GET", "/user/virtual-logout", u.VirtualLogout, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
+ app.Handle("GET", "/user/switch-account/:account_id", u.SwitchAccount, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
+ app.Handle("POST", "/user/switch-account", u.SwitchAccount, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
+ app.Handle("GET", "/user/switch-account", u.SwitchAccount, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
app.Handle("POST", "/user", u.View, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
app.Handle("GET", "/user", u.View, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
// Register account management endpoints.
acc := Account{
- MasterDB: masterDB,
- Renderer: renderer,
+ MasterDB: masterDB,
+ Renderer: renderer,
+ Authenticator: authenticator,
}
app.Handle("POST", "/account/update", acc.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin))
app.Handle("GET", "/account/update", acc.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin))
diff --git a/cmd/web-app/handlers/signup.go b/cmd/web-app/handlers/signup.go
index 0790b4f..9ebeabf 100644
--- a/cmd/web-app/handlers/signup.go
+++ b/cmd/web-app/handlers/signup.go
@@ -12,7 +12,7 @@ import (
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror"
"geeks-accelerator/oss/saas-starter-kit/internal/signup"
- "geeks-accelerator/oss/saas-starter-kit/internal/user"
+ "geeks-accelerator/oss/saas-starter-kit/internal/user_auth"
"github.com/gorilla/schema"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
@@ -68,7 +68,7 @@ func (h *Signup) Step1(ctx context.Context, w http.ResponseWriter, r *http.Reque
}
// Authenticated the new user.
- token, err := user.Authenticate(ctx, h.MasterDB, h.Authenticator, req.User.Email, req.User.Password, time.Hour, ctxValues.Now)
+ token, err := user_auth.Authenticate(ctx, h.MasterDB, h.Authenticator, req.User.Email, req.User.Password, time.Hour, ctxValues.Now)
if err != nil {
return err
}
@@ -93,13 +93,6 @@ func (h *Signup) Step1(ctx context.Context, w http.ResponseWriter, r *http.Reque
return nil
}
- data["geonameCountries"] = geonames.ValidGeonameCountries
-
- data["countries"], err = geonames.FindCountries(ctx, h.MasterDB, "name", "")
- if err != nil {
- return err
- }
-
return nil
}
@@ -107,6 +100,13 @@ func (h *Signup) Step1(ctx context.Context, w http.ResponseWriter, r *http.Reque
return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
}
+ data["geonameCountries"] = geonames.ValidGeonameCountries
+
+ data["countries"], err = geonames.FindCountries(ctx, h.MasterDB, "name", "")
+ if err != nil {
+ return err
+ }
+
data["form"] = req
if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(signup.SignupRequest{})); ok {
diff --git a/cmd/web-app/handlers/user.go b/cmd/web-app/handlers/user.go
index 8bb03a5..a139866 100644
--- a/cmd/web-app/handlers/user.go
+++ b/cmd/web-app/handlers/user.go
@@ -4,6 +4,8 @@ import (
"context"
"fmt"
"net/http"
+ "net/url"
+ "strings"
"time"
"geeks-accelerator/oss/saas-starter-kit/internal/account"
@@ -16,6 +18,7 @@ import (
project_routes "geeks-accelerator/oss/saas-starter-kit/internal/project-routes"
"geeks-accelerator/oss/saas-starter-kit/internal/user"
"geeks-accelerator/oss/saas-starter-kit/internal/user_account"
+ "geeks-accelerator/oss/saas-starter-kit/internal/user_auth"
"github.com/gorilla/schema"
"github.com/gorilla/sessions"
"github.com/jmoiron/sqlx"
@@ -34,7 +37,7 @@ type User struct {
}
type UserLoginRequest struct {
- user.AuthenticateRequest
+ user_auth.AuthenticateRequest
RememberMe bool
}
@@ -68,7 +71,7 @@ func (h *User) Login(ctx context.Context, w http.ResponseWriter, r *http.Request
}
// Authenticated the user.
- token, err := user.Authenticate(ctx, h.MasterDB, h.Authenticator, req.Email, req.Password, sessionTTL, ctxValues.Now)
+ token, err := user_auth.Authenticate(ctx, h.MasterDB, h.Authenticator, req.Email, req.Password, sessionTTL, ctxValues.Now)
if err != nil {
switch errors.Cause(err) {
case user.ErrForbidden:
@@ -89,8 +92,16 @@ func (h *User) Login(ctx context.Context, w http.ResponseWriter, r *http.Request
return err
}
+ redirectUri := "/"
+ if qv := r.URL.Query().Get("redirect"); qv != "" {
+ redirectUri, err = url.QueryUnescape(qv)
+ if err != nil {
+ return err
+ }
+ }
+
// Redirect the user to the dashboard.
- http.Redirect(w, r, "/", http.StatusFound)
+ http.Redirect(w, r, redirectUri, http.StatusFound)
}
return nil
@@ -110,7 +121,7 @@ func (h *User) Login(ctx context.Context, w http.ResponseWriter, r *http.Request
}
// handleSessionToken persists the access token to the session for request authentication.
-func handleSessionToken(ctx context.Context, db *sqlx.DB, w http.ResponseWriter, r *http.Request, token user.Token) error {
+func handleSessionToken(ctx context.Context, db *sqlx.DB, w http.ResponseWriter, r *http.Request, token user_auth.Token) error {
if token.AccessToken == "" {
return errors.New("accessToken is required.")
}
@@ -252,7 +263,7 @@ func (h *User) ResetConfirm(ctx context.Context, w http.ResponseWriter, r *http.
}
// Authenticated the user. Probably should use the default session TTL from UserLogin.
- token, err := user.Authenticate(ctx, h.MasterDB, h.Authenticator, u.Email, req.Password, time.Hour, ctxValues.Now)
+ token, err := user_auth.Authenticate(ctx, h.MasterDB, h.Authenticator, u.Email, req.Password, time.Hour, ctxValues.Now)
if err != nil {
switch errors.Cause(err) {
case account.ErrForbidden:
@@ -306,7 +317,7 @@ func (h *User) View(ctx context.Context, w http.ResponseWriter, r *http.Request,
return err
}
- usr, err := user.Read(ctx, claims, h.MasterDB, claims.Subject, false)
+ usr, err := user.ReadByID(ctx, claims, h.MasterDB, claims.Subject)
if err != nil {
return err
}
@@ -343,16 +354,15 @@ func (h *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Reques
return err
}
+ claims, err := auth.ClaimsFromContext(ctx)
+ if err != nil {
+ return err
+ }
+
//
req := new(user.UserUpdateRequest)
data := make(map[string]interface{})
f := func() (bool, error) {
-
- claims, err := auth.ClaimsFromContext(ctx)
- if err != nil {
- return false, err
- }
-
if r.Method == http.MethodPost {
err := r.ParseForm()
if err != nil {
@@ -415,25 +425,6 @@ func (h *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Reques
return true, nil
}
- usr, err := user.Read(ctx, claims, h.MasterDB, claims.Subject, false)
- if err != nil {
- return false, err
- }
-
- if req.ID == "" {
- req.FirstName = &usr.FirstName
- req.LastName = &usr.LastName
- req.Email = &usr.Email
- req.Timezone = &usr.Timezone
- }
-
- data["user"] = usr.Response(ctx)
-
- data["timezones"], err = geonames.ListTimezones(ctx, h.MasterDB)
- if err != nil {
- return false, err
- }
-
return false, nil
}
@@ -444,6 +435,25 @@ func (h *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Reques
return nil
}
+ usr, err := user.ReadByID(ctx, claims, h.MasterDB, claims.Subject)
+ if err != nil {
+ return err
+ }
+
+ if req.ID == "" {
+ req.FirstName = &usr.FirstName
+ req.LastName = &usr.LastName
+ req.Email = &usr.Email
+ req.Timezone = &usr.Timezone
+ }
+
+ data["user"] = usr.Response(ctx)
+
+ data["timezones"], err = geonames.ListTimezones(ctx, h.MasterDB)
+ if err != nil {
+ return err
+ }
+
data["form"] = req
if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(user.UserUpdateRequest{})); ok {
@@ -468,7 +478,7 @@ func (h *User) Account(ctx context.Context, w http.ResponseWriter, r *http.Reque
return err
}
- acc, err := account.Read(ctx, claims, h.MasterDB, claims.Audience, false)
+ acc, err := account.ReadByID(ctx, claims, h.MasterDB, claims.Audience)
if err != nil {
return err
}
@@ -483,3 +493,352 @@ func (h *User) Account(ctx context.Context, w http.ResponseWriter, r *http.Reque
return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "user-account.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data)
}
+
+// VirtualLogin handles switching the scope of the context to another user.
+func (h *User) VirtualLogin(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
+
+ ctxValues, err := webcontext.ContextValues(ctx)
+ if err != nil {
+ return err
+ }
+
+ claims, err := auth.ClaimsFromContext(ctx)
+ if err != nil {
+ return err
+ }
+
+ //
+ req := new(user_auth.VirtualLoginRequest)
+ data := make(map[string]interface{})
+ f := func() (bool, error) {
+ if r.Method == http.MethodPost {
+ err := r.ParseForm()
+ if err != nil {
+ return false, err
+ }
+
+ decoder := schema.NewDecoder()
+ decoder.IgnoreUnknownKeys(true)
+
+ if err := decoder.Decode(req, r.PostForm); err != nil {
+ return false, err
+ }
+ } else {
+ if pv, ok := params["user_id"]; ok && pv != "" {
+ req.UserID = pv
+ }
+ }
+
+ if qv := r.URL.Query().Get("account_id"); qv != "" {
+ req.AccountID = qv
+ } else {
+ req.AccountID = claims.Audience
+ }
+
+ if req.UserID != "" {
+ sess := webcontext.ContextSession(ctx)
+ var expires time.Duration
+ if sess != nil && sess.Options != nil {
+ expires = time.Second * time.Duration(sess.Options.MaxAge)
+ } else {
+ expires = time.Hour
+ }
+
+ // Perform the account switch.
+ tkn, err := user_auth.VirtualLogin(ctx, h.MasterDB, h.Authenticator, claims, *req, expires, ctxValues.Now)
+ if err != nil {
+ if verr, ok := weberror.NewValidationError(ctx, err); ok {
+ data["validationErrors"] = verr.(*weberror.Error)
+ return false, nil
+ } else {
+ return false, err
+ }
+ }
+
+ // Update the access token in the session.
+ sess = webcontext.SessionUpdateAccessToken(sess, tkn.AccessToken)
+
+ // Read the account for a flash message.
+ usr, err := user.ReadByID(ctx, claims, h.MasterDB, tkn.UserID)
+ if err != nil {
+ return false, err
+ }
+ webcontext.SessionFlashSuccess(ctx,
+ "User Switched",
+ fmt.Sprintf("You are now virtually logged into user %s.",
+ usr.Response(ctx).Name))
+
+ // Write the session to the client.
+ err = webcontext.ContextSession(ctx).Save(r, w)
+ if err != nil {
+ return false, err
+ }
+
+ // Redirect the user to the dashboard with the new credentials.
+ http.Redirect(w, r, "/", http.StatusFound)
+
+ return true, nil
+ }
+
+ return false, nil
+ }
+
+ end, err := f()
+ if err != nil {
+ return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
+ } else if end {
+ return nil
+ }
+
+ usrAccFilter := "account_id = ?"
+ usrAccs, err := user_account.Find(ctx, claims, h.MasterDB, user_account.UserAccountFindRequest{
+ Where: &usrAccFilter,
+ Args: []interface{}{claims.Audience},
+ })
+ if err != nil {
+ return err
+ }
+
+ var userIDs []interface{}
+ var userPhs []string
+ for _, usrAcc := range usrAccs {
+ if usrAcc.UserID == claims.Subject {
+ // Skip the current authenticated user.
+ continue
+ }
+ userIDs = append(userIDs, usrAcc.UserID)
+ userPhs = append(userPhs, "?")
+ }
+
+ if len(userIDs) == 0 {
+ userIDs = append(userIDs, "")
+ userPhs = append(userPhs, "?")
+ }
+
+ usrFilter := fmt.Sprintf("id IN (%s)", strings.Join(userPhs, ", "))
+ users, err := user.Find(ctx, claims, h.MasterDB, user.UserFindRequest{
+ Where: &usrFilter,
+ Args: userIDs,
+ })
+ if err != nil {
+ return err
+ }
+ data["users"] = users.Response(ctx)
+
+ if req.AccountID == "" {
+ req.AccountID = claims.Audience
+ }
+
+ data["form"] = req
+
+ if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(user_auth.VirtualLoginRequest{})); ok {
+ data["validationDefaults"] = verr.(*weberror.Error)
+ }
+
+ return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "user-virtual-login.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data)
+}
+
+// VirtualLogout handles switching the scope back to the user who initiated the virtual login.
+func (h *User) VirtualLogout(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
+
+ ctxValues, err := webcontext.ContextValues(ctx)
+ if err != nil {
+ return err
+ }
+
+ claims, err := auth.ClaimsFromContext(ctx)
+ if err != nil {
+ return err
+ }
+
+ sess := webcontext.ContextSession(ctx)
+
+ var expires time.Duration
+ if sess != nil && sess.Options != nil {
+ expires = time.Second * time.Duration(sess.Options.MaxAge)
+ } else {
+ expires = time.Hour
+ }
+
+ tkn, err := user_auth.VirtualLogout(ctx, h.MasterDB, h.Authenticator, claims, expires, ctxValues.Now)
+ if err != nil {
+ return err
+ }
+
+ // Update the access token in the session.
+ sess = webcontext.SessionUpdateAccessToken(sess, tkn.AccessToken)
+
+ // Display a success message to verify the user has switched contexts.
+ if claims.Subject != tkn.UserID && claims.Audience != tkn.AccountID {
+ usr, err := user.ReadByID(ctx, claims, h.MasterDB, tkn.UserID)
+ if err != nil {
+ return err
+ }
+ acc, err := account.ReadByID(ctx, claims, h.MasterDB, tkn.AccountID)
+ if err != nil {
+ return err
+ }
+ webcontext.SessionFlashSuccess(ctx,
+ "Context Switched",
+ fmt.Sprintf("You are now virtually logged back into account %s user %s.",
+ acc.Response(ctx).Name, usr.Response(ctx).Name))
+ } else if claims.Audience != tkn.AccountID {
+ acc, err := account.ReadByID(ctx, claims, h.MasterDB, tkn.AccountID)
+ if err != nil {
+ return err
+ }
+ webcontext.SessionFlashSuccess(ctx,
+ "Context Switched",
+ fmt.Sprintf("You are now virtually logged back into account %s.",
+ acc.Response(ctx).Name))
+ } else {
+ usr, err := user.ReadByID(ctx, claims, h.MasterDB, tkn.UserID)
+ if err != nil {
+ return err
+ }
+ webcontext.SessionFlashSuccess(ctx,
+ "Context Switched",
+ fmt.Sprintf("You are now virtually logged back into user %s.",
+ usr.Response(ctx).Name))
+ }
+
+ // Write the session to the client.
+ err = webcontext.ContextSession(ctx).Save(r, w)
+ if err != nil {
+ return err
+ }
+
+ // Redirect the user to the dashboard with the new credentials.
+ http.Redirect(w, r, "/", http.StatusFound)
+
+ return nil
+}
+
+// VirtualLogin handles switching the scope of the context to another user.
+func (h *User) SwitchAccount(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
+
+ ctxValues, err := webcontext.ContextValues(ctx)
+ if err != nil {
+ return err
+ }
+
+ claims, err := auth.ClaimsFromContext(ctx)
+ if err != nil {
+ return err
+ }
+
+ //
+ req := new(user_auth.SwitchAccountRequest)
+ data := make(map[string]interface{})
+ f := func() (bool, error) {
+
+ if r.Method == http.MethodPost {
+ err := r.ParseForm()
+ if err != nil {
+ return false, err
+ }
+
+ decoder := schema.NewDecoder()
+ decoder.IgnoreUnknownKeys(true)
+
+ if err := decoder.Decode(req, r.PostForm); err != nil {
+ return false, err
+ }
+ } else {
+ if pv, ok := params["account_id"]; ok && pv != "" {
+ req.AccountID = pv
+ } else if qv := r.URL.Query().Get("account_id"); qv != "" {
+ req.AccountID = qv
+ }
+ }
+
+ if req.AccountID != "" {
+ sess := webcontext.ContextSession(ctx)
+ var expires time.Duration
+ if sess != nil && sess.Options != nil {
+ expires = time.Second * time.Duration(sess.Options.MaxAge)
+ } else {
+ expires = time.Hour
+ }
+
+ // Perform the account switch.
+ tkn, err := user_auth.SwitchAccount(ctx, h.MasterDB, h.Authenticator, claims, *req, expires, ctxValues.Now)
+ if err != nil {
+ if verr, ok := weberror.NewValidationError(ctx, err); ok {
+ data["validationErrors"] = verr.(*weberror.Error)
+ return false, nil
+ } else {
+ return false, err
+ }
+ }
+
+ // Update the access token in the session.
+ sess = webcontext.SessionUpdateAccessToken(sess, tkn.AccessToken)
+
+ // Read the account for a flash message.
+ acc, err := account.ReadByID(ctx, claims, h.MasterDB, tkn.AccountID)
+ if err != nil {
+ return false, err
+ }
+ webcontext.SessionFlashSuccess(ctx,
+ "Account Switched",
+ fmt.Sprintf("You are now logged into account %s.",
+ acc.Response(ctx).Name))
+
+ // Write the session to the client.
+ err = webcontext.ContextSession(ctx).Save(r, w)
+ if err != nil {
+ return false, err
+ }
+
+ // Redirect the user to the dashboard with the new credentials.
+ http.Redirect(w, r, "/", http.StatusFound)
+
+ return true, nil
+ }
+
+ return false, nil
+ }
+
+ end, err := f()
+ if err != nil {
+ return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
+ } else if end {
+ return nil
+ }
+
+ accounts, err := account.Find(ctx, claims, h.MasterDB, account.AccountFindRequest{
+ Order: []string{"name"},
+ })
+ if err != nil {
+ return err
+ }
+ data["accounts"] = accounts.Response(ctx)
+
+ if req.AccountID == "" {
+ req.AccountID = claims.Audience
+ }
+
+ data["form"] = req
+
+ if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(user_auth.SwitchAccountRequest{})); ok {
+ data["validationDefaults"] = verr.(*weberror.Error)
+ }
+
+ return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "user-switch-account.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data)
+}
+
+// updateContextClaims updates the claims in the context.
+func updateContextClaims(ctx context.Context, authenticator *auth.Authenticator, claims auth.Claims) (context.Context, error) {
+ tkn, err := authenticator.GenerateToken(claims)
+ if err != nil {
+ return ctx, err
+ }
+
+ sess := webcontext.ContextSession(ctx)
+ sess = webcontext.SessionUpdateAccessToken(sess, tkn)
+
+ ctx = context.WithValue(ctx, auth.Key, claims)
+
+ return ctx, nil
+}
diff --git a/cmd/web-app/main.go b/cmd/web-app/main.go
index d1a9d35..c456d82 100644
--- a/cmd/web-app/main.go
+++ b/cmd/web-app/main.go
@@ -698,7 +698,7 @@ func main() {
return nil
}
- usr, err := user.Read(ctx, auth.Claims{}, masterDb, claims.Subject, false)
+ usr, err := user.ReadByID(ctx, auth.Claims{}, masterDb, claims.Subject)
if err != nil {
return nil
}
@@ -733,7 +733,7 @@ func main() {
return nil
}
- acc, err := account.Read(ctx, auth.Claims{}, masterDb, claims.Audience, false)
+ acc, err := account.ReadByID(ctx, auth.Claims{}, masterDb, claims.Audience)
if err != nil {
return nil
}
@@ -746,6 +746,26 @@ func main() {
return a
},
+ "ContextCanSwitchAccount": func(ctx context.Context) bool {
+ claims, err := auth.ClaimsFromContext(ctx)
+ if err != nil || len(claims.AccountIDs) < 2 {
+ return false
+ }
+ return true
+ },
+ "ContextIsVirtualSession": func(ctx context.Context) bool {
+ claims, err := auth.ClaimsFromContext(ctx)
+ if err != nil {
+ return false
+ }
+ if claims.RootUserID != "" && claims.RootUserID != claims.Subject {
+ return true
+ }
+ if claims.RootAccountID != "" && claims.RootAccountID != claims.Audience {
+ return true
+ }
+ return false
+ },
}
imgUrlFormatter := staticUrlFormatter
@@ -827,11 +847,8 @@ func main() {
switch statusCode {
case http.StatusUnauthorized:
- // Handle expired sessions that are returned from the auth middleware.
- if strings.Contains(errors.Cause(er).Error(), "token is expired") {
- http.Redirect(w, r, "/user/login", http.StatusFound)
- return nil
- }
+ http.Redirect(w, r, "/user/login?redirect="+url.QueryEscape(r.RequestURI), http.StatusFound)
+ return nil
}
return web.RenderError(ctx, w, r, er, renderer, handlers.TmplLayoutBase, handlers.TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
diff --git a/cmd/web-app/static/assets/images/user-default.jpg b/cmd/web-app/static/assets/images/user-default.jpg
new file mode 100644
index 0000000..d70c406
Binary files /dev/null and b/cmd/web-app/static/assets/images/user-default.jpg differ
diff --git a/cmd/web-app/templates/content/user-switch-account.gohtml b/cmd/web-app/templates/content/user-switch-account.gohtml
new file mode 100644
index 0000000..460ab05
--- /dev/null
+++ b/cmd/web-app/templates/content/user-switch-account.gohtml
@@ -0,0 +1,60 @@
+{{define "title"}}Switch Account{{end}}
+{{define "style"}}
+
+{{end}}
+{{ define "partials/page-wrapper" }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ template "app-flashes" . }}
+
+
+
Switch Account
+
+
+ {{ template "validation-error" . }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+{{end}}
+{{define "js"}}
+
+{{end}}
diff --git a/cmd/web-app/templates/content/user-virtual-login.gohtml b/cmd/web-app/templates/content/user-virtual-login.gohtml
new file mode 100644
index 0000000..db24a4a
--- /dev/null
+++ b/cmd/web-app/templates/content/user-virtual-login.gohtml
@@ -0,0 +1,60 @@
+{{define "title"}}Switch User{{end}}
+{{define "style"}}
+
+{{end}}
+{{ define "partials/page-wrapper" }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ template "app-flashes" . }}
+
+
+
Switch User
+
+
+ {{ template "validation-error" . }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+{{end}}
+{{define "js"}}
+
+{{end}}
diff --git a/cmd/web-app/templates/partials/topbar.tmpl b/cmd/web-app/templates/partials/topbar.tmpl
index cbdbac6..e27cc26 100644
--- a/cmd/web-app/templates/partials/topbar.tmpl
+++ b/cmd/web-app/templates/partials/topbar.tmpl
@@ -177,7 +177,6 @@
Account Settings
-
Manage Users
@@ -197,8 +196,22 @@
Support
-
+
+ {{ if ContextCanSwitchAccount $._Ctx }}
+
+
+ Switch Account
+
+ {{ end }}
+
+ {{ if ContextIsVirtualSession $._Ctx }}
+
+
+ Switch Back
+
+ {{ end }}
+
Logout
diff --git a/internal/account/account.go b/internal/account/account.go
index 549ad00..9e21ac0 100644
--- a/internal/account/account.go
+++ b/internal/account/account.go
@@ -150,7 +150,7 @@ func selectQuery() *sqlbuilder.SelectBuilder {
// Find gets all the accounts from the database based on the request params.
// TODO: Need to figure out why can't parse the args when appending the where
// to the query.
-func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountFindRequest) ([]*Account, error) {
+func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountFindRequest) (Accounts, error) {
query := selectQuery()
if req.Where != nil {
@@ -170,7 +170,7 @@ func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountF
}
// find internal method for getting all the accounts from the database using a select query.
-func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*Account, error) {
+func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (Accounts, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Find")
defer span.Finish()
diff --git a/internal/account/models.go b/internal/account/models.go
index f0d1b1b..bbb6724 100644
--- a/internal/account/models.go
+++ b/internal/account/models.go
@@ -99,6 +99,21 @@ func (m *AccountResponse) MarshalBinary() ([]byte, error) {
return json.Marshal(m)
}
+// Accounts a list of Accounts.
+type Accounts []*Account
+
+// Response transforms a list of Accounts to a list of AccountResponses.
+func (m *Accounts) Response(ctx context.Context) []*AccountResponse {
+ var l []*AccountResponse
+ if m != nil && len(*m) > 0 {
+ for _, n := range *m {
+ l = append(l, n.Response(ctx))
+ }
+ }
+
+ return l
+}
+
// AccountCreateRequest contains information needed to create a new Account.
type AccountCreateRequest struct {
Name string `json:"name" validate:"required,unique" example:"Company Name"`
diff --git a/internal/platform/auth/claims.go b/internal/platform/auth/claims.go
index a77d0ad..6cd92e9 100644
--- a/internal/platform/auth/claims.go
+++ b/internal/platform/auth/claims.go
@@ -23,9 +23,11 @@ const Key ctxKey = 1
// Claims represents the authorization claims transmitted via a JWT.
type Claims struct {
- AccountIds []string `json:"accounts"`
- Roles []string `json:"roles"`
- Preferences ClaimPreferences `json:"prefs"`
+ RootUserID string `json:"root_user_id"`
+ RootAccountID string `json:"root_account_id"`
+ AccountIDs []string `json:"accounts"`
+ Roles []string `json:"roles"`
+ Preferences ClaimPreferences `json:"prefs"`
jwt.StandardClaims
}
@@ -41,14 +43,16 @@ type ClaimPreferences struct {
// NewClaims constructs a Claims value for the identified user. The Claims
// expire within a specified duration of the provided time. Additional fields
// of the Claims can be set after calling NewClaims is desired.
-func NewClaims(userId, accountId string, accountIds []string, roles []string, prefs ClaimPreferences, now time.Time, expires time.Duration) Claims {
+func NewClaims(userID, accountID string, accountIDs []string, roles []string, prefs ClaimPreferences, now time.Time, expires time.Duration) Claims {
c := Claims{
- AccountIds: accountIds,
- Roles: roles,
- Preferences: prefs,
+ AccountIDs: accountIDs,
+ RootAccountID: accountID,
+ RootUserID: userID,
+ Roles: roles,
+ Preferences: prefs,
StandardClaims: jwt.StandardClaims{
- Subject: userId,
- Audience: accountId,
+ Subject: userID,
+ Audience: accountID,
IssuedAt: now.Unix(),
ExpiresAt: now.Add(expires).Unix(),
},
diff --git a/internal/platform/web/webcontext/session.go b/internal/platform/web/webcontext/session.go
index ec0d194..7c4e498 100644
--- a/internal/platform/web/webcontext/session.go
+++ b/internal/platform/web/webcontext/session.go
@@ -14,39 +14,23 @@ const KeySession ctxKeySession = 1
// Session keys used to store values.
const (
SessionKeyAccessToken = iota
- //SessionKeyPreferenceDatetimeFormat
- //SessionKeyPreferenceDateFormat
- //SessionKeyPreferenceTimeFormat
- //SessionKeyTimezone
)
-func init() {
- //gob.Register(&Session{})
-}
-
-// Session represents a user with authentication.
-type Session struct {
- *sessions.Session
-}
-
// ContextWithSession appends a universal translator to a context.
func ContextWithSession(ctx context.Context, session *sessions.Session) context.Context {
return context.WithValue(ctx, KeySession, session)
}
// ContextSession returns the session from a context.
-func ContextSession(ctx context.Context) *Session {
- if s, ok := ctx.Value(KeySession).(*Session); ok {
+func ContextSession(ctx context.Context) *sessions.Session {
+ if s, ok := ctx.Value(KeySession).(*sessions.Session); ok {
return s
}
return nil
}
func ContextAccessToken(ctx context.Context) (string, bool) {
- return ContextSession(ctx).AccessToken()
-}
-
-func (sess *Session) AccessToken() (string, bool) {
+ sess := ContextSession(ctx)
if sess == nil {
return "", false
}
@@ -56,60 +40,19 @@ func (sess *Session) AccessToken() (string, bool) {
return "", false
}
-/*
-func(sess *Session) PreferenceDatetimeFormat() (string, bool) {
- if sess == nil {
- return "", false
- }
- if sv, ok := sess.Values[SessionKeyPreferenceDatetimeFormat].(string); ok {
- return sv, true
- }
- return "", false
-}
-
-func(sess *Session) PreferenceDateFormat() (string, bool) {
- if sess == nil {
- return "", false
- }
- if sv, ok := sess.Values[SessionKeyPreferenceDateFormat].(string); ok {
- return sv, true
- }
- return "", false
-}
-
-func(sess *Session) PreferenceTimeFormat() (string, bool) {
- if sess == nil {
- return "", false
- }
- if sv, ok := sess.Values[SessionKeyPreferenceTimeFormat].(string); ok {
- return sv, true
- }
- return "", false
-}
-
-func(sess *Session) Timezone() (*time.Location, bool) {
- if sess != nil {
- if sv, ok := sess.Values[SessionKeyTimezone].(*time.Location); ok {
- return sv, true
- }
- }
-
- return nil, false
-}
-*/
-
-func SessionInit(session *Session, accessToken string) *Session {
+func SessionInit(session *sessions.Session, accessToken string) *sessions.Session {
session.Values[SessionKeyAccessToken] = accessToken
- //session.Values[SessionKeyPreferenceDatetimeFormat] = datetimeFormat
- //session.Values[SessionKeyPreferenceDateFormat] = dateFormat
- //session.Values[SessionKeyPreferenceTimeFormat] = timeFormat
- //session.Values[SessionKeyTimezone] = timezone
return session
}
-func SessionDestroy(session *Session) *Session {
+func SessionUpdateAccessToken(session *sessions.Session, accessToken string) *sessions.Session {
+ session.Values[SessionKeyAccessToken] = accessToken
+ return session
+}
+
+func SessionDestroy(session *sessions.Session) *sessions.Session {
delete(session.Values, SessionKeyAccessToken)
diff --git a/internal/project/models.go b/internal/project/models.go
index 48d146e..14ef059 100644
--- a/internal/project/models.go
+++ b/internal/project/models.go
@@ -56,6 +56,21 @@ func (m *Project) Response(ctx context.Context) *ProjectResponse {
return r
}
+// Projects a list of Projects.
+type Projects []*Project
+
+// Response transforms a list of Projects to a list of ProjectResponses.
+func (m *Projects) Response(ctx context.Context) []*ProjectResponse {
+ var l []*ProjectResponse
+ if m != nil && len(*m) > 0 {
+ for _, n := range *m {
+ l = append(l, n.Response(ctx))
+ }
+ }
+
+ return l
+}
+
// ProjectCreateRequest contains information needed to create a new Project.
type ProjectCreateRequest struct {
AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
diff --git a/internal/project/project.go b/internal/project/project.go
index 28237c7..7e1c53d 100644
--- a/internal/project/project.go
+++ b/internal/project/project.go
@@ -124,13 +124,13 @@ func findRequestQuery(req ProjectFindRequest) (*sqlbuilder.SelectBuilder, []inte
}
// Find gets all the projects from the database based on the request params.
-func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectFindRequest) ([]*Project, error) {
+func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectFindRequest) (Projects, error) {
query, args := findRequestQuery(req)
return find(ctx, claims, dbConn, query, args, req.IncludeArchived)
}
// find internal method for getting all the projects from the database using a select query.
-func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*Project, error) {
+func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (Projects, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Find")
defer span.Finish()
diff --git a/internal/user/models.go b/internal/user/models.go
index 4840e3a..7c00b2d 100644
--- a/internal/user/models.go
+++ b/internal/user/models.go
@@ -78,6 +78,21 @@ func (m *UserResponse) MarshalBinary() ([]byte, error) {
return json.Marshal(m)
}
+// Users a list of Users.
+type Users []*User
+
+// Response transforms a list of Users to a list of UserResponses.
+func (m *Users) Response(ctx context.Context) []*UserResponse {
+ var l []*UserResponse
+ if m != nil && len(*m) > 0 {
+ for _, n := range *m {
+ l = append(l, n.Response(ctx))
+ }
+ }
+
+ return l
+}
+
// UserCreateRequest contains information needed to create a new User.
type UserCreateRequest struct {
FirstName string `json:"first_name" validate:"required" example:"Gabi"`
diff --git a/internal/user/user.go b/internal/user/user.go
index ded3bcc..effddef 100644
--- a/internal/user/user.go
+++ b/internal/user/user.go
@@ -202,13 +202,13 @@ func findRequestQuery(req UserFindRequest) (*sqlbuilder.SelectBuilder, []interfa
}
// Find gets all the users from the database based on the request params.
-func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindRequest) ([]*User, error) {
+func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindRequest) (Users, error) {
query, args := findRequestQuery(req)
return find(ctx, claims, dbConn, query, args, req.IncludeArchived)
}
// find internal method for getting all the users from the database using a select query.
-func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*User, error) {
+func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (Users, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Find")
defer span.Finish()
diff --git a/internal/user_account/models.go b/internal/user_account/models.go
index ba47e16..7c63f29 100644
--- a/internal/user_account/models.go
+++ b/internal/user_account/models.go
@@ -66,6 +66,34 @@ func (m *UserAccount) Response(ctx context.Context) *UserAccountResponse {
return r
}
+// HasRole checks if the entry has a role.
+func (m *UserAccount) HasRole(role UserAccountRole) bool {
+ if m == nil {
+ return false
+ }
+ for _, r := range m.Roles {
+ if r == role {
+ return true
+ }
+ }
+ return false
+}
+
+// UserAccounts a list of UserAccounts.
+type UserAccounts []*UserAccount
+
+// Response transforms a list of UserAccounts to a list of UserAccountResponses.
+func (m *UserAccounts) Response(ctx context.Context) []*UserAccountResponse {
+ var l []*UserAccountResponse
+ if m != nil && len(*m) > 0 {
+ for _, n := range *m {
+ l = append(l, n.Response(ctx))
+ }
+ }
+
+ return l
+}
+
// UserAccountCreateRequest defines the information is needed to associate a user to an
// account. Users are global to the application and each users access can be managed
// on an account level. If a current entry exists in the database but is archived,
diff --git a/internal/user_account/user_account.go b/internal/user_account/user_account.go
index beb404d..01e0b28 100644
--- a/internal/user_account/user_account.go
+++ b/internal/user_account/user_account.go
@@ -131,13 +131,13 @@ func findRequestQuery(req UserAccountFindRequest) (*sqlbuilder.SelectBuilder, []
}
// Find gets all the user accounts from the database based on the request params.
-func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountFindRequest) ([]*UserAccount, error) {
+func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountFindRequest) (UserAccounts, error) {
query, args := findRequestQuery(req)
return find(ctx, claims, dbConn, query, args, req.IncludeArchived)
}
// Find gets all the user accounts from the database based on the select query
-func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*UserAccount, error) {
+func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (UserAccounts, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Find")
defer span.Finish()
@@ -180,7 +180,7 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu
}
// Retrieve gets the specified user from the database.
-func FindByUserID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string, includedArchived bool) ([]*UserAccount, error) {
+func FindByUserID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string, includedArchived bool) (UserAccounts, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.FindByUserID")
defer span.Finish()
diff --git a/internal/user_auth/auth.go b/internal/user_auth/auth.go
index d593541..65f440d 100644
--- a/internal/user_auth/auth.go
+++ b/internal/user_auth/auth.go
@@ -3,6 +3,7 @@ package user_auth
import (
"context"
"database/sql"
+ "geeks-accelerator/oss/saas-starter-kit/internal/user_account"
"strings"
"time"
@@ -22,6 +23,9 @@ var (
// ErrAuthenticationFailure occurs when a user attempts to authenticate but
// anything goes wrong.
ErrAuthenticationFailure = errors.New("Authentication failed")
+
+ // ErrForbidden occurs when a user tries to do something that is forbidden to them according to our access control policies.
+ ErrForbidden = errors.New("Attempted action is not allowed")
)
const (
@@ -65,22 +69,11 @@ func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, e
return generateToken(ctx, dbConn, tknGen, auth.Claims{}, u.ID, "", expires, now, scopes...)
}
-// Authenticate finds a user by their email and verifies their password. On success
-// it returns a Token that can be used to authenticate access to the application in
-// the future.
-func SwitchAccount(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, accountID string, expires time.Duration, now time.Time, scopes ...string) (Token, error) {
+// SwitchAccount allows users to switch between multiple accounts, this changes the claim audience.
+func SwitchAccount(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, req SwitchAccountRequest, expires time.Duration, now time.Time, scopes ...string) (Token, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.SwitchAccount")
defer span.Finish()
- // Defines struct to apply validation for the supplied claims and account ID.
- req := struct {
- UserID string `json:"user_id" validate:"required,uuid"`
- AccountID string `json:"account_id" validate:"required,uuid"`
- }{
- UserID: claims.Subject,
- AccountID: accountID,
- }
-
// Validate the request.
v := webcontext.Validator()
err := v.Struct(req)
@@ -88,12 +81,74 @@ func SwitchAccount(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator,
return Token{}, err
}
+ claims.RootAccountID = req.AccountID
+
+ if claims.RootUserID == "" {
+ claims.RootUserID = claims.Subject
+ }
+
+ // Generate a token for the user ID in supplied in claims as the Subject. Pass
+ // in the supplied claims as well to enforce ACLs when finding the current
+ // list of accounts for the user.
+ return generateToken(ctx, dbConn, tknGen, claims, claims.Subject, req.AccountID, expires, now, scopes...)
+}
+
+// VirtualLogin allows users to mock being logged in as other users.
+func VirtualLogin(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, req VirtualLoginRequest, expires time.Duration, now time.Time, scopes ...string) (Token, error) {
+ span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.VirtualLogin")
+ defer span.Finish()
+
+ // Validate the request.
+ v := webcontext.Validator()
+ err := v.Struct(req)
+ if err != nil {
+ return Token{}, err
+ }
+
+ // Find all the accounts that the current user has access to.
+ usrAccs, err := user_account.FindByUserID(ctx, claims, dbConn, claims.Subject, false)
+ if err != nil {
+ return Token{}, err
+ }
+
+ // The user must have the role of admin to login any other user.
+ var hasAccountAdminRole bool
+ for _, usrAcc := range usrAccs {
+ if usrAcc.HasRole(user_account.UserAccountRole_Admin) {
+ if usrAcc.AccountID == req.AccountID {
+ hasAccountAdminRole = true
+ break
+ }
+ }
+ }
+ if !hasAccountAdminRole {
+ return Token{}, errors.WithMessagef(ErrForbidden, "User %s does not have correct access to account %s ", claims.Subject, req.AccountID)
+ }
+
+ if claims.RootAccountID == "" {
+ claims.RootAccountID = claims.Audience
+ }
+ if claims.RootUserID == "" {
+ claims.RootUserID = claims.Subject
+ }
+
// Generate a token for the user ID in supplied in claims as the Subject. Pass
// in the supplied claims as well to enforce ACLs when finding the current
// list of accounts for the user.
return generateToken(ctx, dbConn, tknGen, claims, req.UserID, req.AccountID, expires, now, scopes...)
}
+// VirtualLogout allows switch back to their root user/account.
+func VirtualLogout(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, expires time.Duration, now time.Time, scopes ...string) (Token, error) {
+ span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.VirtualLogout")
+ defer span.Finish()
+
+ // Generate a token for the user ID in supplied in claims as the Subject. Pass
+ // in the supplied claims as well to enforce ACLs when finding the current
+ // list of accounts for the user.
+ return generateToken(ctx, dbConn, tknGen, claims, claims.RootUserID, claims.RootAccountID, expires, now, scopes...)
+}
+
// generateToken generates claims for the supplied user ID and account ID and then
// returns the token for the generated claims used for authentication.
func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, userID, accountID string, expires time.Duration, now time.Time, scopes ...string) (Token, error) {
@@ -250,7 +305,7 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator,
if scopeValid {
roles = append(roles, s)
} else {
- err := errors.Errorf("invalid scope '%s'", s)
+ err := errors.Wrapf(ErrForbidden, "invalid scope '%s'", s)
return Token{}, err
}
}
@@ -259,7 +314,7 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator,
}
if len(roles) == 0 {
- err := errors.New("no roles defined for user")
+ err := errors.Wrapf(ErrForbidden, "no roles defined for user")
return Token{}, err
}
}
@@ -314,14 +369,24 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator,
claimPref = auth.NewClaimPreferences(tz, preferenceDatetimeFormat, preferenceDateFormat, preferenceTimeFormat)
}
+ // Ensure the current claims has the root values set.
+ if (claims.RootAccountID == "" && claims.Audience != "") || (claims.RootUserID == "" && claims.Subject != "") {
+ claims.RootAccountID = claims.Audience
+ claims.RootUserID = claims.Subject
+ }
+
// JWT claims requires both an audience and a subject. For this application:
// Subject: The ID of the user authenticated.
// Audience: The ID of the account the user is accessing. A list of account IDs
// will also be included to support the user switching between them.
- claims = auth.NewClaims(userID, accountID, accountIds, roles, claimPref, now, expires)
+ newClaims := auth.NewClaims(userID, accountID, accountIds, roles, claimPref, now, expires)
+
+ // Copy the original root account/user ID.
+ newClaims.RootAccountID = claims.RootAccountID
+ newClaims.RootUserID = claims.RootUserID
// Generate a token for the user with the defined claims.
- tknStr, err := tknGen.GenerateToken(claims)
+ tknStr, err := tknGen.GenerateToken(newClaims)
if err != nil {
return Token{}, errors.Wrap(err, "generating token")
}
@@ -329,9 +394,9 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator,
tkn := Token{
AccessToken: tknStr,
TokenType: "Bearer",
- claims: claims,
- UserID: claims.Subject,
- AccountID: claims.Audience,
+ claims: newClaims,
+ UserID: newClaims.Subject,
+ AccountID: newClaims.Audience,
}
if expires.Seconds() > 0 {
diff --git a/internal/user_auth/auth_test.go b/internal/user_auth/auth_test.go
index 34fbbb5..13e901b 100644
--- a/internal/user_auth/auth_test.go
+++ b/internal/user_auth/auth_test.go
@@ -2,6 +2,7 @@ package user_auth
import (
"encoding/json"
+ "fmt"
"os"
"testing"
"time"
@@ -63,6 +64,9 @@ func TestAuthenticate(t *testing.T) {
}
t.Logf("\t%s\tCreate user account ok.", tests.Success)
+ // Add 30 minutes to now to simulate time passing.
+ now = now.Add(time.Minute * 30)
+
acc2, err := account.MockAccount(ctx, test.MasterDB, now)
if err != nil {
t.Log("\t\tGot :", err)
@@ -81,7 +85,7 @@ func TestAuthenticate(t *testing.T) {
}, now)
// Add 30 minutes to now to simulate time passing.
- now = now.Add(time.Minute * 30)
+ now = now.Add(time.Minute * 5)
// Try to authenticate valid user with invalid password.
_, err = Authenticate(ctx, test.MasterDB, tknGen, usrAcc.User.Email, "xy7", time.Hour, now)
@@ -106,17 +110,20 @@ func TestAuthenticate(t *testing.T) {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
}
+ expectClaims := tkn1.claims
+ expectClaims.RootUserID = ""
+ expectClaims.RootAccountID = ""
+ expectClaims.Subject = usrAcc.UserID
+ expectClaims.Audience = usrAcc.AccountID
- // Hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229
- resClaims, _ := json.Marshal(claims1)
- expectClaims, _ := json.Marshal(tkn1.claims)
- if diff := cmp.Diff(string(resClaims), string(expectClaims)); diff != "" {
+ if diff := cmpClaims(claims1, expectClaims); diff != "" {
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
}
t.Logf("\t%s\tAuthenticate parse claims from token ok.", tests.Success)
// Try switching to a second account using the first set of claims.
- tkn2, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims1, acc2.ID, time.Hour, now)
+ tkn2, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims1,
+ SwitchAccountRequest{AccountID: acc2.ID}, time.Hour, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tSwitchAccount user failed.", tests.Failed)
@@ -129,11 +136,13 @@ func TestAuthenticate(t *testing.T) {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
}
+ expectClaims = tkn2.claims
+ expectClaims.RootUserID = usrAcc.UserID
+ expectClaims.RootAccountID = acc2.ID
+ expectClaims.Subject = usrAcc.UserID
+ expectClaims.Audience = acc2.ID
- // Hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229
- resClaims, _ = json.Marshal(claims2)
- expectClaims, _ = json.Marshal(tkn2.claims)
- if diff := cmp.Diff(string(resClaims), string(expectClaims)); diff != "" {
+ if diff := cmpClaims(claims2, expectClaims); diff != "" {
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
}
t.Logf("\t%s\tSwitchAccount parse claims from token ok.", tests.Success)
@@ -256,3 +265,724 @@ func TestUserResetPassword(t *testing.T) {
t.Logf("\t%s\tAuthenticate ok.", tests.Success)
}
}
+
+// TestSwitchAccount validates the behavior around allowing users to switch between their accounts.
+func TestSwitchAccount(t *testing.T) {
+ defer tests.Recover(t)
+
+ // Auth tokens are valid for an our and is verified against current time.
+ // Issue the token one hour ago.
+ now := time.Now().Add(time.Hour * -1)
+
+ ctx := tests.Context()
+
+ type authTest struct {
+ name string
+ root *user_account.MockUserAccountResponse
+ switch1Req SwitchAccountRequest
+ switch1Roles []user_account.UserAccountRole
+ switch1Scopes []string
+ switch1Err error
+ switch2Req SwitchAccountRequest
+ switch2Roles []user_account.UserAccountRole
+ switch2Scopes []string
+ switch2Err error
+ }
+
+ var authTests []authTest
+
+ // Test all the combinations there the user has access to all three accounts.
+ if true {
+ for _, roles := range [][]user_account.UserAccountRole{
+ []user_account.UserAccountRole{user_account.UserAccountRole_Admin, user_account.UserAccountRole_Admin, user_account.UserAccountRole_Admin},
+ []user_account.UserAccountRole{user_account.UserAccountRole_User, user_account.UserAccountRole_User, user_account.UserAccountRole_User},
+ []user_account.UserAccountRole{user_account.UserAccountRole_Admin, user_account.UserAccountRole_User, user_account.UserAccountRole_Admin},
+ []user_account.UserAccountRole{user_account.UserAccountRole_User, user_account.UserAccountRole_Admin, user_account.UserAccountRole_User},
+ } {
+ // Create a new user for testing.
+ usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, roles[0])
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
+ }
+
+ // Create the second account.
+ now = now.Add(time.Minute)
+ acc2, err := account.MockAccount(ctx, test.MasterDB, now)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
+ }
+
+ // Associate the second account with root user.
+ usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
+ UserID: usrAcc.UserID,
+ AccountID: acc2.ID,
+ Roles: []user_account.UserAccountRole{user_account.UserAccountRole(roles[1])},
+ }, now)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tLinking second account to user failed.", tests.Failed)
+ }
+
+ // Create the third account.
+ now = now.Add(time.Minute)
+ acc3, err := account.MockAccount(ctx, test.MasterDB, now)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tCreate third account failed.", tests.Failed)
+ }
+
+ // Associate the third account with root user.
+ usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
+ UserID: usrAcc.UserID,
+ AccountID: acc3.ID,
+ Roles: []user_account.UserAccountRole{user_account.UserAccountRole(roles[2])},
+ }, now)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tLinking third account to user failed.", tests.Failed)
+ }
+
+ authTests = append(authTests, authTest{
+ name: fmt.Sprintf("Root account role %s -> role %s account 2 -> role %s account 3.",
+ roles[0], roles[1], roles[2]),
+ root: usrAcc,
+ switch1Req: SwitchAccountRequest{AccountID: acc2.ID},
+ switch1Roles: usrAcc2.Roles,
+ switch1Err: nil,
+ switch2Req: SwitchAccountRequest{AccountID: acc3.ID},
+ switch2Err: nil,
+ switch2Roles: usrAcc3.Roles,
+ })
+ }
+ }
+
+ // Root account 1 -> invalid account 2
+ if true {
+ // Create a new user for testing.
+ usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_Admin)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
+ }
+
+ // Create the second account and don't associate it with the root user.
+ now = now.Add(time.Minute)
+ acc2, err := account.MockAccount(ctx, test.MasterDB, now)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
+ }
+
+ authTests = append(authTests, authTest{
+ name: "Root account 1 -> invalid account 2.",
+ root: usrAcc,
+ switch1Req: SwitchAccountRequest{AccountID: acc2.ID},
+ switch1Err: ErrAuthenticationFailure,
+ })
+ }
+
+ // Root account 1 -> valid account 2 with scopes -> valid account 3 with invalid scope.
+ if true {
+ // Create a new user for testing.
+ usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_Admin)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
+ }
+
+ // Create the second account.
+ now = now.Add(time.Minute)
+ acc2, err := account.MockAccount(ctx, test.MasterDB, now)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
+ }
+
+ // Associate the second account with root user.
+ usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
+ UserID: usrAcc.UserID,
+ AccountID: acc2.ID,
+ Roles: []user_account.UserAccountRole{user_account.UserAccountRole_Admin},
+ }, now)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tLinking second account to user failed.", tests.Failed)
+ }
+
+ // Create the third account.
+ now = now.Add(time.Minute)
+ acc3, err := account.MockAccount(ctx, test.MasterDB, now)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tCreate third account failed.", tests.Failed)
+ }
+
+ // Associate the third account with root user.
+ usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
+ UserID: usrAcc.UserID,
+ AccountID: acc3.ID,
+ Roles: []user_account.UserAccountRole{user_account.UserAccountRole_User},
+ }, now)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tLinking third account to user failed.", tests.Failed)
+ }
+
+ authTests = append(authTests, authTest{
+ name: "Root account 1 -> valid account 2 with scopes -> valid account 3 with invalid scope.",
+ root: usrAcc,
+ switch1Req: SwitchAccountRequest{AccountID: acc2.ID},
+ switch1Roles: usrAcc2.Roles,
+ switch1Scopes: []string{user_account.UserAccountRole_User.String()},
+ switch1Err: nil,
+ switch2Req: SwitchAccountRequest{AccountID: acc3.ID},
+ switch2Roles: usrAcc3.Roles,
+ switch2Scopes: []string{user_account.UserAccountRole_Admin.String()},
+ switch2Err: ErrForbidden,
+ })
+ }
+
+ // Add 30 minutes to now to simulate time passing.
+ now = now.Add(time.Minute * 5)
+
+ tknGen := &auth.MockTokenGenerator{}
+
+ t.Log("Given the need to switch accounts.")
+ {
+ for i, authTest := range authTests {
+ t.Logf("\tTest: %d\tWhen running test: %s", i, authTest.name)
+ {
+ // Verify that the user can be authenticated with the created user.
+ var claims1 auth.Claims
+ tkn1, err := Authenticate(ctx, test.MasterDB, tknGen, authTest.root.User.Email, authTest.root.User.Password, time.Hour, now)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tAuthenticate user failed.", tests.Failed)
+ } else {
+ // Ensure the token string was correctly generated.
+ claims1, err = tknGen.ParseClaims(tkn1.AccessToken)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
+ }
+ expectClaims := tkn1.claims
+ expectClaims.RootUserID = ""
+ expectClaims.RootAccountID = ""
+ expectClaims.Subject = authTest.root.UserID
+ expectClaims.Audience = authTest.root.AccountID
+ expectClaims.Roles = rolesStringSlice(authTest.root.Roles)
+
+ if diff := cmpClaims(claims1, expectClaims); diff != "" {
+ t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
+ }
+ }
+ t.Logf("\t%s\tAuthenticate root user with role %v ok.", tests.Success, authTest.root.Roles)
+
+ // Try to switch to account 2.
+ var claims2 auth.Claims
+ tkn2, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims1, authTest.switch1Req, time.Hour, now, authTest.switch1Scopes...)
+ if err != authTest.switch1Err {
+ if errors.Cause(err) != authTest.switch1Err {
+ t.Log("\t\tExpected :", authTest.switch1Err)
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tSwitchAccount account 1 with role %v failed.", tests.Failed, authTest.switch1Roles)
+ }
+ } else {
+ // Ensure the token string was correctly generated.
+ claims2, err = tknGen.ParseClaims(tkn2.AccessToken)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
+ }
+ expectClaims := tkn2.claims
+ expectClaims.RootUserID = authTest.root.UserID
+ expectClaims.RootAccountID = authTest.switch1Req.AccountID
+ expectClaims.Subject = authTest.root.UserID
+ expectClaims.Audience = authTest.switch1Req.AccountID
+
+ if len(authTest.switch1Scopes) > 0 {
+ expectClaims.Roles = authTest.switch1Scopes
+ } else {
+ expectClaims.Roles = rolesStringSlice(authTest.switch1Roles)
+ }
+
+ if diff := cmpClaims(claims2, expectClaims); diff != "" {
+ t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
+ }
+ }
+ t.Logf("\t%s\tSwitchAccount account 1 with role %v ok.", tests.Success, authTest.switch1Roles)
+
+ // If the user can't login, don't need to test any further.
+ if authTest.switch1Err != nil || authTest.switch2Req.AccountID == "" {
+ continue
+ }
+
+ // Try to switch to account 3.
+ tkn3, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims2, authTest.switch2Req, time.Hour, now, authTest.switch2Scopes...)
+ if err != authTest.switch2Err {
+ if errors.Cause(err) != authTest.switch2Err {
+ t.Log("\t\tExpected :", authTest.switch2Err)
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tSwitchAccount account 2 with role %v failed.", tests.Failed, authTest.switch2Roles)
+ }
+ } else {
+ // Ensure the token string was correctly generated.
+ claims3, err := tknGen.ParseClaims(tkn3.AccessToken)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
+ }
+ expectClaims := tkn3.claims
+ expectClaims.RootUserID = authTest.root.UserID
+ expectClaims.RootAccountID = authTest.switch2Req.AccountID
+ expectClaims.Subject = authTest.root.UserID
+ expectClaims.Audience = authTest.switch2Req.AccountID
+
+ if len(authTest.switch2Scopes) > 0 {
+ expectClaims.Roles = authTest.switch2Scopes
+ } else {
+ expectClaims.Roles = rolesStringSlice(authTest.switch2Roles)
+ }
+
+ if diff := cmpClaims(claims3, expectClaims); diff != "" {
+ t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
+ }
+ }
+ t.Logf("\t%s\tSwitchAccount account 2 with role %v ok.", tests.Success, authTest.switch2Roles)
+ }
+ }
+ }
+}
+
+// TestVirtualLogin validates the behavior around allowing users to virtual login users.
+func TestVirtualLogin(t *testing.T) {
+ defer tests.Recover(t)
+
+ // Auth tokens are valid for an our and is verified against current time.
+ // Issue the token one hour ago.
+ now := time.Now().Add(time.Hour * -1)
+
+ ctx := tests.Context()
+
+ type authTest struct {
+ name string
+ root *user_account.MockUserAccountResponse
+ login1Req VirtualLoginRequest
+ login1Err error
+ login1Role user_account.UserAccountRole
+ login2Req VirtualLoginRequest
+ login2Err error
+ login2Role user_account.UserAccountRole
+ login2Logout bool
+ }
+
+ var authTests []authTest
+
+ // Root admin -> role admin -> role admin
+ if true {
+ // Create a new user for testing.
+ usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_Admin)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
+ }
+
+ usr2, err := user.MockUser(ctx, test.MasterDB, now)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
+ }
+
+ // Associate second user with basic role associated with the same account.
+ usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
+ UserID: usr2.ID,
+ AccountID: usrAcc.AccountID,
+ Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)},
+ }, now)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tLinking second user to account failed.", tests.Failed)
+ }
+
+ usr3, err := user.MockUser(ctx, test.MasterDB, now)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
+ }
+
+ // Associate second user with basic role associated with the same account.
+ usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
+ UserID: usr3.ID,
+ AccountID: usrAcc.AccountID,
+ Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)},
+ }, now)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tLinking third user to account failed.", tests.Failed)
+ }
+
+ authTests = append(authTests, authTest{
+ name: "Root admin -> role admin -> role admin",
+ root: usrAcc,
+ login1Req: VirtualLoginRequest{
+ UserID: usr2.ID,
+ AccountID: usrAcc.AccountID,
+ },
+ login1Role: usrAcc2.Roles[0],
+ login1Err: nil,
+ login2Req: VirtualLoginRequest{
+ UserID: usr3.ID,
+ AccountID: usrAcc.AccountID,
+ },
+ login2Err: nil,
+ login2Role: usrAcc3.Roles[0],
+ login2Logout: true,
+ })
+ }
+
+ // Root admin -> role admin -> role user
+ if true {
+ // Create a new user for testing.
+ usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_Admin)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
+ }
+
+ usr2, err := user.MockUser(ctx, test.MasterDB, now)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
+ }
+
+ // Associate second user with basic role associated with the same account.
+ usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
+ UserID: usr2.ID,
+ AccountID: usrAcc.AccountID,
+ Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)},
+ }, now)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tLinking second user to account failed.", tests.Failed)
+ }
+
+ usr3, err := user.MockUser(ctx, test.MasterDB, now)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
+ }
+
+ // Associate second user with basic role associated with the same account.
+ usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
+ UserID: usr3.ID,
+ AccountID: usrAcc.AccountID,
+ Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_User)},
+ }, now)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tLinking third user to account failed.", tests.Failed)
+ }
+
+ authTests = append(authTests, authTest{
+ name: "Root admin -> role admin -> role user",
+ root: usrAcc,
+ login1Req: VirtualLoginRequest{
+ UserID: usr2.ID,
+ AccountID: usrAcc.AccountID,
+ },
+ login1Err: nil,
+ login1Role: usrAcc2.Roles[0],
+ login2Req: VirtualLoginRequest{
+ UserID: usr3.ID,
+ AccountID: usrAcc.AccountID,
+ },
+ login2Err: nil,
+ login2Role: usrAcc3.Roles[0],
+ login2Logout: true,
+ })
+ }
+
+ // Root admin -> role user -> role admin
+ if true {
+ // Create a new user for testing.
+ usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_Admin)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
+ }
+
+ usr2, err := user.MockUser(ctx, test.MasterDB, now)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
+ }
+
+ // Associate second user with basic role associated with the same account.
+ usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
+ UserID: usr2.ID,
+ AccountID: usrAcc.AccountID,
+ Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_User)},
+ }, now)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tLinking second user to account failed.", tests.Failed)
+ }
+
+ usr3, err := user.MockUser(ctx, test.MasterDB, now)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
+ }
+
+ // Associate second user with basic role associated with the same account.
+ usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
+ UserID: usr3.ID,
+ AccountID: usrAcc.AccountID,
+ Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)},
+ }, now)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tLinking third user to account failed.", tests.Failed)
+ }
+
+ authTests = append(authTests, authTest{
+ name: "Root admin -> role user -> role admin",
+ root: usrAcc,
+ login1Req: VirtualLoginRequest{
+ UserID: usr2.ID,
+ AccountID: usrAcc.AccountID,
+ },
+ login1Err: nil,
+ login1Role: usrAcc2.Roles[0],
+ login2Req: VirtualLoginRequest{
+ UserID: usr3.ID,
+ AccountID: usrAcc.AccountID,
+ },
+ login2Err: ErrForbidden,
+ login2Role: usrAcc3.Roles[0],
+ login2Logout: true,
+ })
+ }
+
+ // Root user -> role admin
+ if true {
+ // Create a new user for testing.
+ usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_User)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
+ }
+
+ usr2, err := user.MockUser(ctx, test.MasterDB, now)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
+ }
+
+ // Associate second user with basic role associated with the same account.
+ usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
+ UserID: usr2.ID,
+ AccountID: usrAcc.AccountID,
+ Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)},
+ }, now)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tLinking second user to account failed.", tests.Failed)
+ }
+
+ authTests = append(authTests, authTest{
+ name: "Root user -> role admin",
+ root: usrAcc,
+ login1Req: VirtualLoginRequest{
+ UserID: usr2.ID,
+ AccountID: usrAcc.AccountID,
+ },
+ login1Err: ErrForbidden,
+ login1Role: usrAcc2.Roles[0],
+ login2Logout: true,
+ })
+ }
+
+ // Root user -> role user
+ if true {
+ // Create a new user for testing.
+ usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_User)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
+ }
+
+ usr2, err := user.MockUser(ctx, test.MasterDB, now)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
+ }
+
+ // Associate second user with basic role associated with the same account.
+ usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
+ UserID: usr2.ID,
+ AccountID: usrAcc.AccountID,
+ Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_User)},
+ }, now)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tLinking second user to account failed.", tests.Failed)
+ }
+
+ authTests = append(authTests, authTest{
+ name: "Root user -> role admin",
+ root: usrAcc,
+ login1Req: VirtualLoginRequest{
+ UserID: usr2.ID,
+ AccountID: usrAcc.AccountID,
+ },
+ login1Err: ErrForbidden,
+ login1Role: usrAcc2.Roles[0],
+ login2Logout: true,
+ })
+ }
+
+ // Add 30 minutes to now to simulate time passing.
+ now = now.Add(time.Minute * 5)
+
+ tknGen := &auth.MockTokenGenerator{}
+
+ t.Log("Given the need to virtual login.")
+ {
+ for i, authTest := range authTests {
+ t.Logf("\tTest: %d\tWhen running test: %s", i, authTest.name)
+ {
+ // Verify that the user can be authenticated with the created user.
+ var claims1 auth.Claims
+ tkn1, err := Authenticate(ctx, test.MasterDB, tknGen, authTest.root.User.Email, authTest.root.User.Password, time.Hour, now)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tAuthenticate user failed.", tests.Failed)
+ } else {
+ // Ensure the token string was correctly generated.
+ claims1, err = tknGen.ParseClaims(tkn1.AccessToken)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
+ }
+ expectClaims := tkn1.claims
+ expectClaims.RootUserID = ""
+ expectClaims.RootAccountID = ""
+ expectClaims.Subject = authTest.root.UserID
+ expectClaims.Audience = authTest.root.AccountID
+
+ // Hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229
+ if diff := cmpClaims(claims1, expectClaims); diff != "" {
+ t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
+ }
+ }
+ t.Logf("\t%s\tAuthenticate root user with role %s ok.", tests.Success, authTest.root.Roles[0])
+
+ // Try virtual login to user 2.
+ var claims2 auth.Claims
+ tkn2, err := VirtualLogin(ctx, test.MasterDB, tknGen, claims1, authTest.login1Req, time.Hour, now)
+ if err != authTest.login1Err {
+ if errors.Cause(err) != authTest.login1Err {
+ t.Log("\t\tExpected :", authTest.login1Err)
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tVirtualLogin user 1 with role %s failed.", tests.Failed, authTest.login1Role)
+ }
+ } else {
+ // Ensure the token string was correctly generated.
+ claims2, err = tknGen.ParseClaims(tkn2.AccessToken)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
+ }
+ expectClaims := tkn2.claims
+ expectClaims.RootUserID = authTest.root.UserID
+ expectClaims.RootAccountID = authTest.root.AccountID
+ expectClaims.Subject = authTest.login1Req.UserID
+ expectClaims.Audience = authTest.login1Req.AccountID
+
+ // Hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229
+ if diff := cmpClaims(claims2, expectClaims); diff != "" {
+ t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
+ }
+ }
+ t.Logf("\t%s\tVirtualLogin user 1 with role %s ok.", tests.Success, authTest.login1Role)
+
+ // If the user can't login, don't need to test any further.
+ if authTest.login1Err != nil {
+ continue
+ }
+
+ // Try virtual login to user 3.
+ tkn3, err := VirtualLogin(ctx, test.MasterDB, tknGen, claims2, authTest.login2Req, time.Hour, now)
+ if err != authTest.login2Err {
+ if errors.Cause(err) != authTest.login2Err {
+ t.Log("\t\tExpected :", authTest.login2Err)
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tVirtualLogin user 2 with role %s failed.", tests.Failed, authTest.login2Role)
+ }
+ } else {
+ // Ensure the token string was correctly generated.
+ claims3, err := tknGen.ParseClaims(tkn3.AccessToken)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
+ }
+ expectClaims := tkn3.claims
+ expectClaims.RootUserID = authTest.root.UserID
+ expectClaims.RootAccountID = authTest.root.AccountID
+ expectClaims.Subject = authTest.login2Req.UserID
+ expectClaims.Audience = authTest.login2Req.AccountID
+
+ // Hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229
+ if diff := cmpClaims(claims3, expectClaims); diff != "" {
+ t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
+ }
+ }
+ t.Logf("\t%s\tVirtualLogin user 2 with role %s ok.", tests.Success, authTest.login2Role)
+
+ if authTest.login2Logout {
+ tknOut, err := VirtualLogout(ctx, test.MasterDB, tknGen, claims2, time.Hour, now)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tVirtualLogout user 2 failed.", tests.Failed)
+ }
+
+ // Ensure the token string was correctly generated.
+ claimsOut, err := tknGen.ParseClaims(tknOut.AccessToken)
+ if err != nil {
+ t.Log("\t\tGot :", err)
+ t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
+ }
+ expectClaims := tknOut.claims
+ expectClaims.RootUserID = authTest.root.UserID
+ expectClaims.RootAccountID = authTest.root.AccountID
+ expectClaims.Subject = authTest.root.UserID
+ expectClaims.Audience = authTest.root.AccountID
+
+ if diff := cmpClaims(claimsOut, expectClaims); diff != "" {
+ t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
+ }
+ t.Logf("\t%s\tVirtualLogout user 2 with role %s ok.", tests.Success, authTest.login2Role)
+ }
+ }
+ }
+ }
+}
+
+// rolesStringSlice converts a list of roles to a string slice.
+func rolesStringSlice(roles []user_account.UserAccountRole) []string {
+ var l []string
+ for _, r := range roles {
+ l = append(l, string(r))
+ }
+ return l
+}
+
+// cmpClaims is a hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229
+func cmpClaims(actualClaims, expectedclaims auth.Claims) string {
+ dat1, _ := json.Marshal(actualClaims)
+ dat2, _ := json.Marshal(expectedclaims)
+ return cmp.Diff(string(dat1), string(dat2))
+}
diff --git a/internal/user_auth/models.go b/internal/user_auth/models.go
index 09411d3..31bd1ef 100644
--- a/internal/user_auth/models.go
+++ b/internal/user_auth/models.go
@@ -35,6 +35,17 @@ type Token struct {
AccountID string `json:"account_id"example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
}
+// SwitchAccountRequest defines the information for the current user to switch between their accounts
+type SwitchAccountRequest struct {
+ AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
+}
+
+// VirtualLoginRequest defines the information virtual login to a user / account.
+type VirtualLoginRequest struct {
+ UserID string `json:"user_id" validate:"required,uuid" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"`
+ AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
+}
+
// AuthorizationHeader returns the header authorization value.
func (t Token) AuthorizationHeader() string {
return "Bearer " + t.AccessToken