You've already forked golang-saas-starter-kit
mirror of
https://github.com/raseels-repos/golang-saas-starter-kit.git
synced 2025-06-15 00:15:15 +02:00
Completed virtual user login and switch accounts
This commit is contained in:
@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"geeks-accelerator/oss/saas-starter-kit/internal/user_auth"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
@ -17,6 +16,7 @@ import (
|
||||
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror"
|
||||
"geeks-accelerator/oss/saas-starter-kit/internal/user"
|
||||
"geeks-accelerator/oss/saas-starter-kit/internal/user_account"
|
||||
"geeks-accelerator/oss/saas-starter-kit/internal/user_auth"
|
||||
"github.com/pborman/uuid"
|
||||
)
|
||||
|
||||
|
@ -21,6 +21,7 @@ import (
|
||||
type Account struct {
|
||||
MasterDB *sqlx.DB
|
||||
Renderer web.Renderer
|
||||
Authenticator *auth.Authenticator
|
||||
}
|
||||
|
||||
// View handles displaying the current account profile.
|
||||
@ -34,7 +35,7 @@ func (h *Account) View(ctx context.Context, w http.ResponseWriter, r *http.Reque
|
||||
return err
|
||||
}
|
||||
|
||||
acc, err := account.Read(ctx, claims, h.MasterDB, claims.Audience, false)
|
||||
acc, err := account.ReadByID(ctx, claims, h.MasterDB, claims.Audience)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -127,7 +128,11 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req
|
||||
}
|
||||
}
|
||||
|
||||
sess := webcontext.ContextSession(ctx)
|
||||
var updateClaims bool
|
||||
if req.Timezone != nil && claims.Preferences.Timezone != *req.Timezone {
|
||||
claims.Preferences.Timezone = *req.Timezone
|
||||
updateClaims = true
|
||||
}
|
||||
|
||||
if preferenceDatetimeFormat != req.PreferenceDatetimeFormat {
|
||||
err = account_preference.Set(ctx, claims, h.MasterDB, account_preference.AccountPreferenceSetRequest{
|
||||
@ -144,7 +149,10 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req
|
||||
}
|
||||
}
|
||||
|
||||
sess.Values[webcontext.SessionKeyPreferenceDatetimeFormat] = req.PreferenceDatetimeFormat
|
||||
if claims.Preferences.DatetimeFormat != req.PreferenceDatetimeFormat {
|
||||
claims.Preferences.DatetimeFormat = req.PreferenceDatetimeFormat
|
||||
updateClaims = true
|
||||
}
|
||||
}
|
||||
|
||||
if preferenceDateFormat != req.PreferenceDateFormat {
|
||||
@ -162,7 +170,10 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req
|
||||
}
|
||||
}
|
||||
|
||||
sess.Values[webcontext.SessionKeyPreferenceDateFormat] = req.PreferenceDateFormat
|
||||
if claims.Preferences.DateFormat != req.PreferenceDateFormat {
|
||||
claims.Preferences.DateFormat = req.PreferenceDateFormat
|
||||
updateClaims = true
|
||||
}
|
||||
}
|
||||
|
||||
if preferenceTimeFormat != req.PreferenceTimeFormat {
|
||||
@ -180,7 +191,18 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req
|
||||
}
|
||||
}
|
||||
|
||||
sess.Values[webcontext.SessionKeyPreferenceTimeFormat] = req.PreferenceTimeFormat
|
||||
if claims.Preferences.TimeFormat != req.PreferenceTimeFormat {
|
||||
claims.Preferences.TimeFormat = req.PreferenceTimeFormat
|
||||
updateClaims = true
|
||||
}
|
||||
}
|
||||
|
||||
// Update the access token to include the updated claims.
|
||||
if updateClaims {
|
||||
ctx, err = updateContextClaims(ctx, h.Authenticator, claims)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
// Display a success message to the user.
|
||||
@ -196,7 +218,7 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req
|
||||
return true, nil
|
||||
}
|
||||
|
||||
acc, err := account.Read(ctx, claims, h.MasterDB, claims.Audience, false)
|
||||
acc, err := account.ReadByID(ctx, claims, h.MasterDB, claims.Audience)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
@ -66,6 +66,13 @@ func APP(shutdown chan os.Signal, log *log.Logger, env webcontext.Env, staticDir
|
||||
app.Handle("POST", "/user/update", u.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
|
||||
app.Handle("GET", "/user/update", u.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
|
||||
app.Handle("GET", "/user/account", u.Account, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
|
||||
app.Handle("GET", "/user/virtual-login/:user_id", u.VirtualLogin, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin))
|
||||
app.Handle("POST", "/user/virtual-login", u.VirtualLogin, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin))
|
||||
app.Handle("GET", "/user/virtual-login", u.VirtualLogin, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin))
|
||||
app.Handle("GET", "/user/virtual-logout", u.VirtualLogout, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
|
||||
app.Handle("GET", "/user/switch-account/:account_id", u.SwitchAccount, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
|
||||
app.Handle("POST", "/user/switch-account", u.SwitchAccount, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
|
||||
app.Handle("GET", "/user/switch-account", u.SwitchAccount, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
|
||||
app.Handle("POST", "/user", u.View, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
|
||||
app.Handle("GET", "/user", u.View, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
|
||||
|
||||
@ -73,6 +80,7 @@ func APP(shutdown chan os.Signal, log *log.Logger, env webcontext.Env, staticDir
|
||||
acc := Account{
|
||||
MasterDB: masterDB,
|
||||
Renderer: renderer,
|
||||
Authenticator: authenticator,
|
||||
}
|
||||
app.Handle("POST", "/account/update", acc.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin))
|
||||
app.Handle("GET", "/account/update", acc.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin))
|
||||
|
@ -12,7 +12,7 @@ import (
|
||||
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
|
||||
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror"
|
||||
"geeks-accelerator/oss/saas-starter-kit/internal/signup"
|
||||
"geeks-accelerator/oss/saas-starter-kit/internal/user"
|
||||
"geeks-accelerator/oss/saas-starter-kit/internal/user_auth"
|
||||
"github.com/gorilla/schema"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/pkg/errors"
|
||||
@ -68,7 +68,7 @@ func (h *Signup) Step1(ctx context.Context, w http.ResponseWriter, r *http.Reque
|
||||
}
|
||||
|
||||
// Authenticated the new user.
|
||||
token, err := user.Authenticate(ctx, h.MasterDB, h.Authenticator, req.User.Email, req.User.Password, time.Hour, ctxValues.Now)
|
||||
token, err := user_auth.Authenticate(ctx, h.MasterDB, h.Authenticator, req.User.Email, req.User.Password, time.Hour, ctxValues.Now)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -93,13 +93,6 @@ func (h *Signup) Step1(ctx context.Context, w http.ResponseWriter, r *http.Reque
|
||||
return nil
|
||||
}
|
||||
|
||||
data["geonameCountries"] = geonames.ValidGeonameCountries
|
||||
|
||||
data["countries"], err = geonames.FindCountries(ctx, h.MasterDB, "name", "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -107,6 +100,13 @@ func (h *Signup) Step1(ctx context.Context, w http.ResponseWriter, r *http.Reque
|
||||
return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
|
||||
}
|
||||
|
||||
data["geonameCountries"] = geonames.ValidGeonameCountries
|
||||
|
||||
data["countries"], err = geonames.FindCountries(ctx, h.MasterDB, "name", "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data["form"] = req
|
||||
|
||||
if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(signup.SignupRequest{})); ok {
|
||||
|
@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/internal/account"
|
||||
@ -16,6 +18,7 @@ import (
|
||||
project_routes "geeks-accelerator/oss/saas-starter-kit/internal/project-routes"
|
||||
"geeks-accelerator/oss/saas-starter-kit/internal/user"
|
||||
"geeks-accelerator/oss/saas-starter-kit/internal/user_account"
|
||||
"geeks-accelerator/oss/saas-starter-kit/internal/user_auth"
|
||||
"github.com/gorilla/schema"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/jmoiron/sqlx"
|
||||
@ -34,7 +37,7 @@ type User struct {
|
||||
}
|
||||
|
||||
type UserLoginRequest struct {
|
||||
user.AuthenticateRequest
|
||||
user_auth.AuthenticateRequest
|
||||
RememberMe bool
|
||||
}
|
||||
|
||||
@ -68,7 +71,7 @@ func (h *User) Login(ctx context.Context, w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
// Authenticated the user.
|
||||
token, err := user.Authenticate(ctx, h.MasterDB, h.Authenticator, req.Email, req.Password, sessionTTL, ctxValues.Now)
|
||||
token, err := user_auth.Authenticate(ctx, h.MasterDB, h.Authenticator, req.Email, req.Password, sessionTTL, ctxValues.Now)
|
||||
if err != nil {
|
||||
switch errors.Cause(err) {
|
||||
case user.ErrForbidden:
|
||||
@ -89,8 +92,16 @@ func (h *User) Login(ctx context.Context, w http.ResponseWriter, r *http.Request
|
||||
return err
|
||||
}
|
||||
|
||||
redirectUri := "/"
|
||||
if qv := r.URL.Query().Get("redirect"); qv != "" {
|
||||
redirectUri, err = url.QueryUnescape(qv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Redirect the user to the dashboard.
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
http.Redirect(w, r, redirectUri, http.StatusFound)
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -110,7 +121,7 @@ func (h *User) Login(ctx context.Context, w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
// handleSessionToken persists the access token to the session for request authentication.
|
||||
func handleSessionToken(ctx context.Context, db *sqlx.DB, w http.ResponseWriter, r *http.Request, token user.Token) error {
|
||||
func handleSessionToken(ctx context.Context, db *sqlx.DB, w http.ResponseWriter, r *http.Request, token user_auth.Token) error {
|
||||
if token.AccessToken == "" {
|
||||
return errors.New("accessToken is required.")
|
||||
}
|
||||
@ -252,7 +263,7 @@ func (h *User) ResetConfirm(ctx context.Context, w http.ResponseWriter, r *http.
|
||||
}
|
||||
|
||||
// Authenticated the user. Probably should use the default session TTL from UserLogin.
|
||||
token, err := user.Authenticate(ctx, h.MasterDB, h.Authenticator, u.Email, req.Password, time.Hour, ctxValues.Now)
|
||||
token, err := user_auth.Authenticate(ctx, h.MasterDB, h.Authenticator, u.Email, req.Password, time.Hour, ctxValues.Now)
|
||||
if err != nil {
|
||||
switch errors.Cause(err) {
|
||||
case account.ErrForbidden:
|
||||
@ -306,7 +317,7 @@ func (h *User) View(ctx context.Context, w http.ResponseWriter, r *http.Request,
|
||||
return err
|
||||
}
|
||||
|
||||
usr, err := user.Read(ctx, claims, h.MasterDB, claims.Subject, false)
|
||||
usr, err := user.ReadByID(ctx, claims, h.MasterDB, claims.Subject)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -343,16 +354,15 @@ func (h *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Reques
|
||||
return err
|
||||
}
|
||||
|
||||
claims, err := auth.ClaimsFromContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
//
|
||||
req := new(user.UserUpdateRequest)
|
||||
data := make(map[string]interface{})
|
||||
f := func() (bool, error) {
|
||||
|
||||
claims, err := auth.ClaimsFromContext(ctx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if r.Method == http.MethodPost {
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
@ -415,9 +425,19 @@ func (h *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Reques
|
||||
return true, nil
|
||||
}
|
||||
|
||||
usr, err := user.Read(ctx, claims, h.MasterDB, claims.Subject, false)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
end, err := f()
|
||||
if err != nil {
|
||||
return false, err
|
||||
return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
|
||||
} else if end {
|
||||
return nil
|
||||
}
|
||||
|
||||
usr, err := user.ReadByID(ctx, claims, h.MasterDB, claims.Subject)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if req.ID == "" {
|
||||
@ -431,17 +451,7 @@ func (h *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Reques
|
||||
|
||||
data["timezones"], err = geonames.ListTimezones(ctx, h.MasterDB)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
end, err := f()
|
||||
if err != nil {
|
||||
return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
|
||||
} else if end {
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
data["form"] = req
|
||||
@ -468,7 +478,7 @@ func (h *User) Account(ctx context.Context, w http.ResponseWriter, r *http.Reque
|
||||
return err
|
||||
}
|
||||
|
||||
acc, err := account.Read(ctx, claims, h.MasterDB, claims.Audience, false)
|
||||
acc, err := account.ReadByID(ctx, claims, h.MasterDB, claims.Audience)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -483,3 +493,352 @@ func (h *User) Account(ctx context.Context, w http.ResponseWriter, r *http.Reque
|
||||
|
||||
return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "user-account.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data)
|
||||
}
|
||||
|
||||
// VirtualLogin handles switching the scope of the context to another user.
|
||||
func (h *User) VirtualLogin(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||
|
||||
ctxValues, err := webcontext.ContextValues(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
claims, err := auth.ClaimsFromContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
//
|
||||
req := new(user_auth.VirtualLoginRequest)
|
||||
data := make(map[string]interface{})
|
||||
f := func() (bool, error) {
|
||||
if r.Method == http.MethodPost {
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
decoder := schema.NewDecoder()
|
||||
decoder.IgnoreUnknownKeys(true)
|
||||
|
||||
if err := decoder.Decode(req, r.PostForm); err != nil {
|
||||
return false, err
|
||||
}
|
||||
} else {
|
||||
if pv, ok := params["user_id"]; ok && pv != "" {
|
||||
req.UserID = pv
|
||||
}
|
||||
}
|
||||
|
||||
if qv := r.URL.Query().Get("account_id"); qv != "" {
|
||||
req.AccountID = qv
|
||||
} else {
|
||||
req.AccountID = claims.Audience
|
||||
}
|
||||
|
||||
if req.UserID != "" {
|
||||
sess := webcontext.ContextSession(ctx)
|
||||
var expires time.Duration
|
||||
if sess != nil && sess.Options != nil {
|
||||
expires = time.Second * time.Duration(sess.Options.MaxAge)
|
||||
} else {
|
||||
expires = time.Hour
|
||||
}
|
||||
|
||||
// Perform the account switch.
|
||||
tkn, err := user_auth.VirtualLogin(ctx, h.MasterDB, h.Authenticator, claims, *req, expires, ctxValues.Now)
|
||||
if err != nil {
|
||||
if verr, ok := weberror.NewValidationError(ctx, err); ok {
|
||||
data["validationErrors"] = verr.(*weberror.Error)
|
||||
return false, nil
|
||||
} else {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
// Update the access token in the session.
|
||||
sess = webcontext.SessionUpdateAccessToken(sess, tkn.AccessToken)
|
||||
|
||||
// Read the account for a flash message.
|
||||
usr, err := user.ReadByID(ctx, claims, h.MasterDB, tkn.UserID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
webcontext.SessionFlashSuccess(ctx,
|
||||
"User Switched",
|
||||
fmt.Sprintf("You are now virtually logged into user %s.",
|
||||
usr.Response(ctx).Name))
|
||||
|
||||
// Write the session to the client.
|
||||
err = webcontext.ContextSession(ctx).Save(r, w)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Redirect the user to the dashboard with the new credentials.
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
end, err := f()
|
||||
if err != nil {
|
||||
return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
|
||||
} else if end {
|
||||
return nil
|
||||
}
|
||||
|
||||
usrAccFilter := "account_id = ?"
|
||||
usrAccs, err := user_account.Find(ctx, claims, h.MasterDB, user_account.UserAccountFindRequest{
|
||||
Where: &usrAccFilter,
|
||||
Args: []interface{}{claims.Audience},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var userIDs []interface{}
|
||||
var userPhs []string
|
||||
for _, usrAcc := range usrAccs {
|
||||
if usrAcc.UserID == claims.Subject {
|
||||
// Skip the current authenticated user.
|
||||
continue
|
||||
}
|
||||
userIDs = append(userIDs, usrAcc.UserID)
|
||||
userPhs = append(userPhs, "?")
|
||||
}
|
||||
|
||||
if len(userIDs) == 0 {
|
||||
userIDs = append(userIDs, "")
|
||||
userPhs = append(userPhs, "?")
|
||||
}
|
||||
|
||||
usrFilter := fmt.Sprintf("id IN (%s)", strings.Join(userPhs, ", "))
|
||||
users, err := user.Find(ctx, claims, h.MasterDB, user.UserFindRequest{
|
||||
Where: &usrFilter,
|
||||
Args: userIDs,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data["users"] = users.Response(ctx)
|
||||
|
||||
if req.AccountID == "" {
|
||||
req.AccountID = claims.Audience
|
||||
}
|
||||
|
||||
data["form"] = req
|
||||
|
||||
if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(user_auth.VirtualLoginRequest{})); ok {
|
||||
data["validationDefaults"] = verr.(*weberror.Error)
|
||||
}
|
||||
|
||||
return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "user-virtual-login.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data)
|
||||
}
|
||||
|
||||
// VirtualLogout handles switching the scope back to the user who initiated the virtual login.
|
||||
func (h *User) VirtualLogout(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||
|
||||
ctxValues, err := webcontext.ContextValues(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
claims, err := auth.ClaimsFromContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sess := webcontext.ContextSession(ctx)
|
||||
|
||||
var expires time.Duration
|
||||
if sess != nil && sess.Options != nil {
|
||||
expires = time.Second * time.Duration(sess.Options.MaxAge)
|
||||
} else {
|
||||
expires = time.Hour
|
||||
}
|
||||
|
||||
tkn, err := user_auth.VirtualLogout(ctx, h.MasterDB, h.Authenticator, claims, expires, ctxValues.Now)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the access token in the session.
|
||||
sess = webcontext.SessionUpdateAccessToken(sess, tkn.AccessToken)
|
||||
|
||||
// Display a success message to verify the user has switched contexts.
|
||||
if claims.Subject != tkn.UserID && claims.Audience != tkn.AccountID {
|
||||
usr, err := user.ReadByID(ctx, claims, h.MasterDB, tkn.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
acc, err := account.ReadByID(ctx, claims, h.MasterDB, tkn.AccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
webcontext.SessionFlashSuccess(ctx,
|
||||
"Context Switched",
|
||||
fmt.Sprintf("You are now virtually logged back into account %s user %s.",
|
||||
acc.Response(ctx).Name, usr.Response(ctx).Name))
|
||||
} else if claims.Audience != tkn.AccountID {
|
||||
acc, err := account.ReadByID(ctx, claims, h.MasterDB, tkn.AccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
webcontext.SessionFlashSuccess(ctx,
|
||||
"Context Switched",
|
||||
fmt.Sprintf("You are now virtually logged back into account %s.",
|
||||
acc.Response(ctx).Name))
|
||||
} else {
|
||||
usr, err := user.ReadByID(ctx, claims, h.MasterDB, tkn.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
webcontext.SessionFlashSuccess(ctx,
|
||||
"Context Switched",
|
||||
fmt.Sprintf("You are now virtually logged back into user %s.",
|
||||
usr.Response(ctx).Name))
|
||||
}
|
||||
|
||||
// Write the session to the client.
|
||||
err = webcontext.ContextSession(ctx).Save(r, w)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Redirect the user to the dashboard with the new credentials.
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// VirtualLogin handles switching the scope of the context to another user.
|
||||
func (h *User) SwitchAccount(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||
|
||||
ctxValues, err := webcontext.ContextValues(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
claims, err := auth.ClaimsFromContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
//
|
||||
req := new(user_auth.SwitchAccountRequest)
|
||||
data := make(map[string]interface{})
|
||||
f := func() (bool, error) {
|
||||
|
||||
if r.Method == http.MethodPost {
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
decoder := schema.NewDecoder()
|
||||
decoder.IgnoreUnknownKeys(true)
|
||||
|
||||
if err := decoder.Decode(req, r.PostForm); err != nil {
|
||||
return false, err
|
||||
}
|
||||
} else {
|
||||
if pv, ok := params["account_id"]; ok && pv != "" {
|
||||
req.AccountID = pv
|
||||
} else if qv := r.URL.Query().Get("account_id"); qv != "" {
|
||||
req.AccountID = qv
|
||||
}
|
||||
}
|
||||
|
||||
if req.AccountID != "" {
|
||||
sess := webcontext.ContextSession(ctx)
|
||||
var expires time.Duration
|
||||
if sess != nil && sess.Options != nil {
|
||||
expires = time.Second * time.Duration(sess.Options.MaxAge)
|
||||
} else {
|
||||
expires = time.Hour
|
||||
}
|
||||
|
||||
// Perform the account switch.
|
||||
tkn, err := user_auth.SwitchAccount(ctx, h.MasterDB, h.Authenticator, claims, *req, expires, ctxValues.Now)
|
||||
if err != nil {
|
||||
if verr, ok := weberror.NewValidationError(ctx, err); ok {
|
||||
data["validationErrors"] = verr.(*weberror.Error)
|
||||
return false, nil
|
||||
} else {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
// Update the access token in the session.
|
||||
sess = webcontext.SessionUpdateAccessToken(sess, tkn.AccessToken)
|
||||
|
||||
// Read the account for a flash message.
|
||||
acc, err := account.ReadByID(ctx, claims, h.MasterDB, tkn.AccountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
webcontext.SessionFlashSuccess(ctx,
|
||||
"Account Switched",
|
||||
fmt.Sprintf("You are now logged into account %s.",
|
||||
acc.Response(ctx).Name))
|
||||
|
||||
// Write the session to the client.
|
||||
err = webcontext.ContextSession(ctx).Save(r, w)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Redirect the user to the dashboard with the new credentials.
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
end, err := f()
|
||||
if err != nil {
|
||||
return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
|
||||
} else if end {
|
||||
return nil
|
||||
}
|
||||
|
||||
accounts, err := account.Find(ctx, claims, h.MasterDB, account.AccountFindRequest{
|
||||
Order: []string{"name"},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data["accounts"] = accounts.Response(ctx)
|
||||
|
||||
if req.AccountID == "" {
|
||||
req.AccountID = claims.Audience
|
||||
}
|
||||
|
||||
data["form"] = req
|
||||
|
||||
if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(user_auth.SwitchAccountRequest{})); ok {
|
||||
data["validationDefaults"] = verr.(*weberror.Error)
|
||||
}
|
||||
|
||||
return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "user-switch-account.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data)
|
||||
}
|
||||
|
||||
// updateContextClaims updates the claims in the context.
|
||||
func updateContextClaims(ctx context.Context, authenticator *auth.Authenticator, claims auth.Claims) (context.Context, error) {
|
||||
tkn, err := authenticator.GenerateToken(claims)
|
||||
if err != nil {
|
||||
return ctx, err
|
||||
}
|
||||
|
||||
sess := webcontext.ContextSession(ctx)
|
||||
sess = webcontext.SessionUpdateAccessToken(sess, tkn)
|
||||
|
||||
ctx = context.WithValue(ctx, auth.Key, claims)
|
||||
|
||||
return ctx, nil
|
||||
}
|
||||
|
@ -698,7 +698,7 @@ func main() {
|
||||
return nil
|
||||
}
|
||||
|
||||
usr, err := user.Read(ctx, auth.Claims{}, masterDb, claims.Subject, false)
|
||||
usr, err := user.ReadByID(ctx, auth.Claims{}, masterDb, claims.Subject)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
@ -733,7 +733,7 @@ func main() {
|
||||
return nil
|
||||
}
|
||||
|
||||
acc, err := account.Read(ctx, auth.Claims{}, masterDb, claims.Audience, false)
|
||||
acc, err := account.ReadByID(ctx, auth.Claims{}, masterDb, claims.Audience)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
@ -746,6 +746,26 @@ func main() {
|
||||
|
||||
return a
|
||||
},
|
||||
"ContextCanSwitchAccount": func(ctx context.Context) bool {
|
||||
claims, err := auth.ClaimsFromContext(ctx)
|
||||
if err != nil || len(claims.AccountIDs) < 2 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
},
|
||||
"ContextIsVirtualSession": func(ctx context.Context) bool {
|
||||
claims, err := auth.ClaimsFromContext(ctx)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if claims.RootUserID != "" && claims.RootUserID != claims.Subject {
|
||||
return true
|
||||
}
|
||||
if claims.RootAccountID != "" && claims.RootAccountID != claims.Audience {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
},
|
||||
}
|
||||
|
||||
imgUrlFormatter := staticUrlFormatter
|
||||
@ -827,12 +847,9 @@ func main() {
|
||||
|
||||
switch statusCode {
|
||||
case http.StatusUnauthorized:
|
||||
// Handle expired sessions that are returned from the auth middleware.
|
||||
if strings.Contains(errors.Cause(er).Error(), "token is expired") {
|
||||
http.Redirect(w, r, "/user/login", http.StatusFound)
|
||||
http.Redirect(w, r, "/user/login?redirect="+url.QueryEscape(r.RequestURI), http.StatusFound)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return web.RenderError(ctx, w, r, er, renderer, handlers.TmplLayoutBase, handlers.TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
|
||||
}
|
||||
|
BIN
cmd/web-app/static/assets/images/user-default.jpg
Normal file
BIN
cmd/web-app/static/assets/images/user-default.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 40 KiB |
60
cmd/web-app/templates/content/user-switch-account.gohtml
Normal file
60
cmd/web-app/templates/content/user-switch-account.gohtml
Normal file
@ -0,0 +1,60 @@
|
||||
{{define "title"}}Switch Account{{end}}
|
||||
{{define "style"}}
|
||||
|
||||
{{end}}
|
||||
{{ define "partials/page-wrapper" }}
|
||||
<div class="container" id="page-content">
|
||||
|
||||
<!-- Outer Row -->
|
||||
<div class="row justify-content-center">
|
||||
|
||||
<div class="col-xl-10 col-lg-12 col-md-9">
|
||||
|
||||
<div class="card o-hidden border-0 shadow-lg my-5">
|
||||
<div class="card-body p-0">
|
||||
<!-- Nested Row within Card Body -->
|
||||
<div class="row">
|
||||
<div class="col-lg-6 d-none d-lg-block bg-login-image"></div>
|
||||
<div class="col-lg-6">
|
||||
<div class="p-5">
|
||||
{{ template "app-flashes" . }}
|
||||
|
||||
<div class="text-center">
|
||||
<h1 class="h4 text-gray-900 mb-4">Switch Account</h1>
|
||||
</div>
|
||||
|
||||
{{ template "validation-error" . }}
|
||||
|
||||
<form class="user" method="post" novalidate>
|
||||
<div class="form-group">
|
||||
<select class="form-control form-control-select-box {{ ValidationFieldClass $.validationErrors "AccountID" }}" name="AccountID" placeholder="AccountID" required>
|
||||
{{ range $i := $.accounts }}
|
||||
<option value="{{ $i.ID }}" {{ if eq $.form.AccountID $i.ID }}selected="selected"{{ end }}>{{ $i.Name }}</option>
|
||||
{{ end }}
|
||||
</select>
|
||||
{{template "invalid-feedback" dict "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors "fieldName" "AccountID" }}
|
||||
</div>
|
||||
<button class="btn btn-primary btn-user btn-block">
|
||||
Login
|
||||
</button>
|
||||
<hr>
|
||||
</form>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
{{end}}
|
||||
{{define "js"}}
|
||||
<script>
|
||||
$(document).ready(function() {
|
||||
$(document).find('body').addClass('bg-gradient-primary');
|
||||
});
|
||||
</script>
|
||||
{{end}}
|
60
cmd/web-app/templates/content/user-virtual-login.gohtml
Normal file
60
cmd/web-app/templates/content/user-virtual-login.gohtml
Normal file
@ -0,0 +1,60 @@
|
||||
{{define "title"}}Switch User{{end}}
|
||||
{{define "style"}}
|
||||
|
||||
{{end}}
|
||||
{{ define "partials/page-wrapper" }}
|
||||
<div class="container" id="page-content">
|
||||
|
||||
<!-- Outer Row -->
|
||||
<div class="row justify-content-center">
|
||||
|
||||
<div class="col-xl-10 col-lg-12 col-md-9">
|
||||
|
||||
<div class="card o-hidden border-0 shadow-lg my-5">
|
||||
<div class="card-body p-0">
|
||||
<!-- Nested Row within Card Body -->
|
||||
<div class="row">
|
||||
<div class="col-lg-6 d-none d-lg-block bg-login-image"></div>
|
||||
<div class="col-lg-6">
|
||||
<div class="p-5">
|
||||
{{ template "app-flashes" . }}
|
||||
|
||||
<div class="text-center">
|
||||
<h1 class="h4 text-gray-900 mb-4">Switch User</h1>
|
||||
</div>
|
||||
|
||||
{{ template "validation-error" . }}
|
||||
|
||||
<form class="user" method="post" novalidate>
|
||||
<div class="form-group">
|
||||
<select class="form-control form-control-select-box {{ ValidationFieldClass $.validationErrors "User" }}" name="UserID" placeholder="UserID" required>
|
||||
{{ range $i := $.users }}
|
||||
<option value="{{ $i.ID }}" {{ if eq $.form.UserID $i.ID }}selected="selected"{{ end }}>{{ $i.Name }}</option>
|
||||
{{ end }}
|
||||
</select>
|
||||
{{template "invalid-feedback" dict "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors "fieldName" "UserID" }}
|
||||
</div>
|
||||
<button class="btn btn-primary btn-user btn-block">
|
||||
Login
|
||||
</button>
|
||||
<hr>
|
||||
</form>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
{{end}}
|
||||
{{define "js"}}
|
||||
<script>
|
||||
$(document).ready(function() {
|
||||
$(document).find('body').addClass('bg-gradient-primary');
|
||||
});
|
||||
</script>
|
||||
{{end}}
|
@ -177,7 +177,6 @@
|
||||
<i class="fas fa-cogs fa-sm fa-fw mr-2 text-gray-400"></i>
|
||||
Account Settings
|
||||
</a>
|
||||
|
||||
<a class="dropdown-item" href="/users">
|
||||
<i class="fas fa-cogs fa-sm fa-fw mr-2 text-gray-400"></i>
|
||||
Manage Users
|
||||
@ -197,8 +196,22 @@
|
||||
<i class="fas fa-cogs fa-sm fa-fw mr-2 text-gray-400"></i>
|
||||
Support
|
||||
</a>
|
||||
|
||||
<div class="dropdown-divider"></div>
|
||||
|
||||
{{ if ContextCanSwitchAccount $._Ctx }}
|
||||
<a class="dropdown-item" href="/user/switch-account?modal=1">
|
||||
<i class="far fa-sign-in fa-sm fa-fw mr-2 text-gray-400"></i>
|
||||
Switch Account
|
||||
</a>
|
||||
{{ end }}
|
||||
|
||||
{{ if ContextIsVirtualSession $._Ctx }}
|
||||
<a class="dropdown-item" href="/user/virtual-logout">
|
||||
<i class="far fa-sign-out fa-sm fa-fw mr-2 text-gray-400"></i>
|
||||
Switch Back
|
||||
</a>
|
||||
{{ end }}
|
||||
|
||||
<a class="dropdown-item" href="/user/logout" data-toggle="modal" data-target="#logoutModal">
|
||||
<i class="fas fa-sign-out-alt fa-sm fa-fw mr-2 text-gray-400"></i>
|
||||
Logout
|
||||
|
@ -150,7 +150,7 @@ func selectQuery() *sqlbuilder.SelectBuilder {
|
||||
// Find gets all the accounts from the database based on the request params.
|
||||
// TODO: Need to figure out why can't parse the args when appending the where
|
||||
// to the query.
|
||||
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountFindRequest) ([]*Account, error) {
|
||||
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountFindRequest) (Accounts, error) {
|
||||
query := selectQuery()
|
||||
|
||||
if req.Where != nil {
|
||||
@ -170,7 +170,7 @@ func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountF
|
||||
}
|
||||
|
||||
// find internal method for getting all the accounts from the database using a select query.
|
||||
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*Account, error) {
|
||||
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (Accounts, error) {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Find")
|
||||
defer span.Finish()
|
||||
|
||||
|
@ -99,6 +99,21 @@ func (m *AccountResponse) MarshalBinary() ([]byte, error) {
|
||||
return json.Marshal(m)
|
||||
}
|
||||
|
||||
// Accounts a list of Accounts.
|
||||
type Accounts []*Account
|
||||
|
||||
// Response transforms a list of Accounts to a list of AccountResponses.
|
||||
func (m *Accounts) Response(ctx context.Context) []*AccountResponse {
|
||||
var l []*AccountResponse
|
||||
if m != nil && len(*m) > 0 {
|
||||
for _, n := range *m {
|
||||
l = append(l, n.Response(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
return l
|
||||
}
|
||||
|
||||
// AccountCreateRequest contains information needed to create a new Account.
|
||||
type AccountCreateRequest struct {
|
||||
Name string `json:"name" validate:"required,unique" example:"Company Name"`
|
||||
|
@ -23,7 +23,9 @@ const Key ctxKey = 1
|
||||
|
||||
// Claims represents the authorization claims transmitted via a JWT.
|
||||
type Claims struct {
|
||||
AccountIds []string `json:"accounts"`
|
||||
RootUserID string `json:"root_user_id"`
|
||||
RootAccountID string `json:"root_account_id"`
|
||||
AccountIDs []string `json:"accounts"`
|
||||
Roles []string `json:"roles"`
|
||||
Preferences ClaimPreferences `json:"prefs"`
|
||||
jwt.StandardClaims
|
||||
@ -41,14 +43,16 @@ type ClaimPreferences struct {
|
||||
// NewClaims constructs a Claims value for the identified user. The Claims
|
||||
// expire within a specified duration of the provided time. Additional fields
|
||||
// of the Claims can be set after calling NewClaims is desired.
|
||||
func NewClaims(userId, accountId string, accountIds []string, roles []string, prefs ClaimPreferences, now time.Time, expires time.Duration) Claims {
|
||||
func NewClaims(userID, accountID string, accountIDs []string, roles []string, prefs ClaimPreferences, now time.Time, expires time.Duration) Claims {
|
||||
c := Claims{
|
||||
AccountIds: accountIds,
|
||||
AccountIDs: accountIDs,
|
||||
RootAccountID: accountID,
|
||||
RootUserID: userID,
|
||||
Roles: roles,
|
||||
Preferences: prefs,
|
||||
StandardClaims: jwt.StandardClaims{
|
||||
Subject: userId,
|
||||
Audience: accountId,
|
||||
Subject: userID,
|
||||
Audience: accountID,
|
||||
IssuedAt: now.Unix(),
|
||||
ExpiresAt: now.Add(expires).Unix(),
|
||||
},
|
||||
|
@ -14,39 +14,23 @@ const KeySession ctxKeySession = 1
|
||||
// Session keys used to store values.
|
||||
const (
|
||||
SessionKeyAccessToken = iota
|
||||
//SessionKeyPreferenceDatetimeFormat
|
||||
//SessionKeyPreferenceDateFormat
|
||||
//SessionKeyPreferenceTimeFormat
|
||||
//SessionKeyTimezone
|
||||
)
|
||||
|
||||
func init() {
|
||||
//gob.Register(&Session{})
|
||||
}
|
||||
|
||||
// Session represents a user with authentication.
|
||||
type Session struct {
|
||||
*sessions.Session
|
||||
}
|
||||
|
||||
// ContextWithSession appends a universal translator to a context.
|
||||
func ContextWithSession(ctx context.Context, session *sessions.Session) context.Context {
|
||||
return context.WithValue(ctx, KeySession, session)
|
||||
}
|
||||
|
||||
// ContextSession returns the session from a context.
|
||||
func ContextSession(ctx context.Context) *Session {
|
||||
if s, ok := ctx.Value(KeySession).(*Session); ok {
|
||||
func ContextSession(ctx context.Context) *sessions.Session {
|
||||
if s, ok := ctx.Value(KeySession).(*sessions.Session); ok {
|
||||
return s
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ContextAccessToken(ctx context.Context) (string, bool) {
|
||||
return ContextSession(ctx).AccessToken()
|
||||
}
|
||||
|
||||
func (sess *Session) AccessToken() (string, bool) {
|
||||
sess := ContextSession(ctx)
|
||||
if sess == nil {
|
||||
return "", false
|
||||
}
|
||||
@ -56,60 +40,19 @@ func (sess *Session) AccessToken() (string, bool) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
/*
|
||||
func(sess *Session) PreferenceDatetimeFormat() (string, bool) {
|
||||
if sess == nil {
|
||||
return "", false
|
||||
}
|
||||
if sv, ok := sess.Values[SessionKeyPreferenceDatetimeFormat].(string); ok {
|
||||
return sv, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func(sess *Session) PreferenceDateFormat() (string, bool) {
|
||||
if sess == nil {
|
||||
return "", false
|
||||
}
|
||||
if sv, ok := sess.Values[SessionKeyPreferenceDateFormat].(string); ok {
|
||||
return sv, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func(sess *Session) PreferenceTimeFormat() (string, bool) {
|
||||
if sess == nil {
|
||||
return "", false
|
||||
}
|
||||
if sv, ok := sess.Values[SessionKeyPreferenceTimeFormat].(string); ok {
|
||||
return sv, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func(sess *Session) Timezone() (*time.Location, bool) {
|
||||
if sess != nil {
|
||||
if sv, ok := sess.Values[SessionKeyTimezone].(*time.Location); ok {
|
||||
return sv, true
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
*/
|
||||
|
||||
func SessionInit(session *Session, accessToken string) *Session {
|
||||
func SessionInit(session *sessions.Session, accessToken string) *sessions.Session {
|
||||
|
||||
session.Values[SessionKeyAccessToken] = accessToken
|
||||
//session.Values[SessionKeyPreferenceDatetimeFormat] = datetimeFormat
|
||||
//session.Values[SessionKeyPreferenceDateFormat] = dateFormat
|
||||
//session.Values[SessionKeyPreferenceTimeFormat] = timeFormat
|
||||
//session.Values[SessionKeyTimezone] = timezone
|
||||
|
||||
return session
|
||||
}
|
||||
|
||||
func SessionDestroy(session *Session) *Session {
|
||||
func SessionUpdateAccessToken(session *sessions.Session, accessToken string) *sessions.Session {
|
||||
session.Values[SessionKeyAccessToken] = accessToken
|
||||
return session
|
||||
}
|
||||
|
||||
func SessionDestroy(session *sessions.Session) *sessions.Session {
|
||||
|
||||
delete(session.Values, SessionKeyAccessToken)
|
||||
|
||||
|
@ -56,6 +56,21 @@ func (m *Project) Response(ctx context.Context) *ProjectResponse {
|
||||
return r
|
||||
}
|
||||
|
||||
// Projects a list of Projects.
|
||||
type Projects []*Project
|
||||
|
||||
// Response transforms a list of Projects to a list of ProjectResponses.
|
||||
func (m *Projects) Response(ctx context.Context) []*ProjectResponse {
|
||||
var l []*ProjectResponse
|
||||
if m != nil && len(*m) > 0 {
|
||||
for _, n := range *m {
|
||||
l = append(l, n.Response(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
return l
|
||||
}
|
||||
|
||||
// ProjectCreateRequest contains information needed to create a new Project.
|
||||
type ProjectCreateRequest struct {
|
||||
AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
|
||||
|
@ -124,13 +124,13 @@ func findRequestQuery(req ProjectFindRequest) (*sqlbuilder.SelectBuilder, []inte
|
||||
}
|
||||
|
||||
// Find gets all the projects from the database based on the request params.
|
||||
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectFindRequest) ([]*Project, error) {
|
||||
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectFindRequest) (Projects, error) {
|
||||
query, args := findRequestQuery(req)
|
||||
return find(ctx, claims, dbConn, query, args, req.IncludeArchived)
|
||||
}
|
||||
|
||||
// find internal method for getting all the projects from the database using a select query.
|
||||
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*Project, error) {
|
||||
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (Projects, error) {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Find")
|
||||
defer span.Finish()
|
||||
|
||||
|
@ -78,6 +78,21 @@ func (m *UserResponse) MarshalBinary() ([]byte, error) {
|
||||
return json.Marshal(m)
|
||||
}
|
||||
|
||||
// Users a list of Users.
|
||||
type Users []*User
|
||||
|
||||
// Response transforms a list of Users to a list of UserResponses.
|
||||
func (m *Users) Response(ctx context.Context) []*UserResponse {
|
||||
var l []*UserResponse
|
||||
if m != nil && len(*m) > 0 {
|
||||
for _, n := range *m {
|
||||
l = append(l, n.Response(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
return l
|
||||
}
|
||||
|
||||
// UserCreateRequest contains information needed to create a new User.
|
||||
type UserCreateRequest struct {
|
||||
FirstName string `json:"first_name" validate:"required" example:"Gabi"`
|
||||
|
@ -202,13 +202,13 @@ func findRequestQuery(req UserFindRequest) (*sqlbuilder.SelectBuilder, []interfa
|
||||
}
|
||||
|
||||
// Find gets all the users from the database based on the request params.
|
||||
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindRequest) ([]*User, error) {
|
||||
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindRequest) (Users, error) {
|
||||
query, args := findRequestQuery(req)
|
||||
return find(ctx, claims, dbConn, query, args, req.IncludeArchived)
|
||||
}
|
||||
|
||||
// find internal method for getting all the users from the database using a select query.
|
||||
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*User, error) {
|
||||
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (Users, error) {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Find")
|
||||
defer span.Finish()
|
||||
|
||||
|
@ -66,6 +66,34 @@ func (m *UserAccount) Response(ctx context.Context) *UserAccountResponse {
|
||||
return r
|
||||
}
|
||||
|
||||
// HasRole checks if the entry has a role.
|
||||
func (m *UserAccount) HasRole(role UserAccountRole) bool {
|
||||
if m == nil {
|
||||
return false
|
||||
}
|
||||
for _, r := range m.Roles {
|
||||
if r == role {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// UserAccounts a list of UserAccounts.
|
||||
type UserAccounts []*UserAccount
|
||||
|
||||
// Response transforms a list of UserAccounts to a list of UserAccountResponses.
|
||||
func (m *UserAccounts) Response(ctx context.Context) []*UserAccountResponse {
|
||||
var l []*UserAccountResponse
|
||||
if m != nil && len(*m) > 0 {
|
||||
for _, n := range *m {
|
||||
l = append(l, n.Response(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
return l
|
||||
}
|
||||
|
||||
// UserAccountCreateRequest defines the information is needed to associate a user to an
|
||||
// account. Users are global to the application and each users access can be managed
|
||||
// on an account level. If a current entry exists in the database but is archived,
|
||||
|
@ -131,13 +131,13 @@ func findRequestQuery(req UserAccountFindRequest) (*sqlbuilder.SelectBuilder, []
|
||||
}
|
||||
|
||||
// Find gets all the user accounts from the database based on the request params.
|
||||
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountFindRequest) ([]*UserAccount, error) {
|
||||
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountFindRequest) (UserAccounts, error) {
|
||||
query, args := findRequestQuery(req)
|
||||
return find(ctx, claims, dbConn, query, args, req.IncludeArchived)
|
||||
}
|
||||
|
||||
// Find gets all the user accounts from the database based on the select query
|
||||
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*UserAccount, error) {
|
||||
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (UserAccounts, error) {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Find")
|
||||
defer span.Finish()
|
||||
|
||||
@ -180,7 +180,7 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu
|
||||
}
|
||||
|
||||
// Retrieve gets the specified user from the database.
|
||||
func FindByUserID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string, includedArchived bool) ([]*UserAccount, error) {
|
||||
func FindByUserID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string, includedArchived bool) (UserAccounts, error) {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.FindByUserID")
|
||||
defer span.Finish()
|
||||
|
||||
|
@ -3,6 +3,7 @@ package user_auth
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"geeks-accelerator/oss/saas-starter-kit/internal/user_account"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -22,6 +23,9 @@ var (
|
||||
// ErrAuthenticationFailure occurs when a user attempts to authenticate but
|
||||
// anything goes wrong.
|
||||
ErrAuthenticationFailure = errors.New("Authentication failed")
|
||||
|
||||
// ErrForbidden occurs when a user tries to do something that is forbidden to them according to our access control policies.
|
||||
ErrForbidden = errors.New("Attempted action is not allowed")
|
||||
)
|
||||
|
||||
const (
|
||||
@ -65,22 +69,11 @@ func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, e
|
||||
return generateToken(ctx, dbConn, tknGen, auth.Claims{}, u.ID, "", expires, now, scopes...)
|
||||
}
|
||||
|
||||
// Authenticate finds a user by their email and verifies their password. On success
|
||||
// it returns a Token that can be used to authenticate access to the application in
|
||||
// the future.
|
||||
func SwitchAccount(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, accountID string, expires time.Duration, now time.Time, scopes ...string) (Token, error) {
|
||||
// SwitchAccount allows users to switch between multiple accounts, this changes the claim audience.
|
||||
func SwitchAccount(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, req SwitchAccountRequest, expires time.Duration, now time.Time, scopes ...string) (Token, error) {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.SwitchAccount")
|
||||
defer span.Finish()
|
||||
|
||||
// Defines struct to apply validation for the supplied claims and account ID.
|
||||
req := struct {
|
||||
UserID string `json:"user_id" validate:"required,uuid"`
|
||||
AccountID string `json:"account_id" validate:"required,uuid"`
|
||||
}{
|
||||
UserID: claims.Subject,
|
||||
AccountID: accountID,
|
||||
}
|
||||
|
||||
// Validate the request.
|
||||
v := webcontext.Validator()
|
||||
err := v.Struct(req)
|
||||
@ -88,12 +81,74 @@ func SwitchAccount(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator,
|
||||
return Token{}, err
|
||||
}
|
||||
|
||||
claims.RootAccountID = req.AccountID
|
||||
|
||||
if claims.RootUserID == "" {
|
||||
claims.RootUserID = claims.Subject
|
||||
}
|
||||
|
||||
// Generate a token for the user ID in supplied in claims as the Subject. Pass
|
||||
// in the supplied claims as well to enforce ACLs when finding the current
|
||||
// list of accounts for the user.
|
||||
return generateToken(ctx, dbConn, tknGen, claims, claims.Subject, req.AccountID, expires, now, scopes...)
|
||||
}
|
||||
|
||||
// VirtualLogin allows users to mock being logged in as other users.
|
||||
func VirtualLogin(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, req VirtualLoginRequest, expires time.Duration, now time.Time, scopes ...string) (Token, error) {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.VirtualLogin")
|
||||
defer span.Finish()
|
||||
|
||||
// Validate the request.
|
||||
v := webcontext.Validator()
|
||||
err := v.Struct(req)
|
||||
if err != nil {
|
||||
return Token{}, err
|
||||
}
|
||||
|
||||
// Find all the accounts that the current user has access to.
|
||||
usrAccs, err := user_account.FindByUserID(ctx, claims, dbConn, claims.Subject, false)
|
||||
if err != nil {
|
||||
return Token{}, err
|
||||
}
|
||||
|
||||
// The user must have the role of admin to login any other user.
|
||||
var hasAccountAdminRole bool
|
||||
for _, usrAcc := range usrAccs {
|
||||
if usrAcc.HasRole(user_account.UserAccountRole_Admin) {
|
||||
if usrAcc.AccountID == req.AccountID {
|
||||
hasAccountAdminRole = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !hasAccountAdminRole {
|
||||
return Token{}, errors.WithMessagef(ErrForbidden, "User %s does not have correct access to account %s ", claims.Subject, req.AccountID)
|
||||
}
|
||||
|
||||
if claims.RootAccountID == "" {
|
||||
claims.RootAccountID = claims.Audience
|
||||
}
|
||||
if claims.RootUserID == "" {
|
||||
claims.RootUserID = claims.Subject
|
||||
}
|
||||
|
||||
// Generate a token for the user ID in supplied in claims as the Subject. Pass
|
||||
// in the supplied claims as well to enforce ACLs when finding the current
|
||||
// list of accounts for the user.
|
||||
return generateToken(ctx, dbConn, tknGen, claims, req.UserID, req.AccountID, expires, now, scopes...)
|
||||
}
|
||||
|
||||
// VirtualLogout allows switch back to their root user/account.
|
||||
func VirtualLogout(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, expires time.Duration, now time.Time, scopes ...string) (Token, error) {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.VirtualLogout")
|
||||
defer span.Finish()
|
||||
|
||||
// Generate a token for the user ID in supplied in claims as the Subject. Pass
|
||||
// in the supplied claims as well to enforce ACLs when finding the current
|
||||
// list of accounts for the user.
|
||||
return generateToken(ctx, dbConn, tknGen, claims, claims.RootUserID, claims.RootAccountID, expires, now, scopes...)
|
||||
}
|
||||
|
||||
// generateToken generates claims for the supplied user ID and account ID and then
|
||||
// returns the token for the generated claims used for authentication.
|
||||
func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, userID, accountID string, expires time.Duration, now time.Time, scopes ...string) (Token, error) {
|
||||
@ -250,7 +305,7 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator,
|
||||
if scopeValid {
|
||||
roles = append(roles, s)
|
||||
} else {
|
||||
err := errors.Errorf("invalid scope '%s'", s)
|
||||
err := errors.Wrapf(ErrForbidden, "invalid scope '%s'", s)
|
||||
return Token{}, err
|
||||
}
|
||||
}
|
||||
@ -259,7 +314,7 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator,
|
||||
}
|
||||
|
||||
if len(roles) == 0 {
|
||||
err := errors.New("no roles defined for user")
|
||||
err := errors.Wrapf(ErrForbidden, "no roles defined for user")
|
||||
return Token{}, err
|
||||
}
|
||||
}
|
||||
@ -314,14 +369,24 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator,
|
||||
claimPref = auth.NewClaimPreferences(tz, preferenceDatetimeFormat, preferenceDateFormat, preferenceTimeFormat)
|
||||
}
|
||||
|
||||
// Ensure the current claims has the root values set.
|
||||
if (claims.RootAccountID == "" && claims.Audience != "") || (claims.RootUserID == "" && claims.Subject != "") {
|
||||
claims.RootAccountID = claims.Audience
|
||||
claims.RootUserID = claims.Subject
|
||||
}
|
||||
|
||||
// JWT claims requires both an audience and a subject. For this application:
|
||||
// Subject: The ID of the user authenticated.
|
||||
// Audience: The ID of the account the user is accessing. A list of account IDs
|
||||
// will also be included to support the user switching between them.
|
||||
claims = auth.NewClaims(userID, accountID, accountIds, roles, claimPref, now, expires)
|
||||
newClaims := auth.NewClaims(userID, accountID, accountIds, roles, claimPref, now, expires)
|
||||
|
||||
// Copy the original root account/user ID.
|
||||
newClaims.RootAccountID = claims.RootAccountID
|
||||
newClaims.RootUserID = claims.RootUserID
|
||||
|
||||
// Generate a token for the user with the defined claims.
|
||||
tknStr, err := tknGen.GenerateToken(claims)
|
||||
tknStr, err := tknGen.GenerateToken(newClaims)
|
||||
if err != nil {
|
||||
return Token{}, errors.Wrap(err, "generating token")
|
||||
}
|
||||
@ -329,9 +394,9 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator,
|
||||
tkn := Token{
|
||||
AccessToken: tknStr,
|
||||
TokenType: "Bearer",
|
||||
claims: claims,
|
||||
UserID: claims.Subject,
|
||||
AccountID: claims.Audience,
|
||||
claims: newClaims,
|
||||
UserID: newClaims.Subject,
|
||||
AccountID: newClaims.Audience,
|
||||
}
|
||||
|
||||
if expires.Seconds() > 0 {
|
||||
|
@ -2,6 +2,7 @@ package user_auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
@ -63,6 +64,9 @@ func TestAuthenticate(t *testing.T) {
|
||||
}
|
||||
t.Logf("\t%s\tCreate user account ok.", tests.Success)
|
||||
|
||||
// Add 30 minutes to now to simulate time passing.
|
||||
now = now.Add(time.Minute * 30)
|
||||
|
||||
acc2, err := account.MockAccount(ctx, test.MasterDB, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
@ -81,7 +85,7 @@ func TestAuthenticate(t *testing.T) {
|
||||
}, now)
|
||||
|
||||
// Add 30 minutes to now to simulate time passing.
|
||||
now = now.Add(time.Minute * 30)
|
||||
now = now.Add(time.Minute * 5)
|
||||
|
||||
// Try to authenticate valid user with invalid password.
|
||||
_, err = Authenticate(ctx, test.MasterDB, tknGen, usrAcc.User.Email, "xy7", time.Hour, now)
|
||||
@ -106,17 +110,20 @@ func TestAuthenticate(t *testing.T) {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
|
||||
}
|
||||
expectClaims := tkn1.claims
|
||||
expectClaims.RootUserID = ""
|
||||
expectClaims.RootAccountID = ""
|
||||
expectClaims.Subject = usrAcc.UserID
|
||||
expectClaims.Audience = usrAcc.AccountID
|
||||
|
||||
// Hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229
|
||||
resClaims, _ := json.Marshal(claims1)
|
||||
expectClaims, _ := json.Marshal(tkn1.claims)
|
||||
if diff := cmp.Diff(string(resClaims), string(expectClaims)); diff != "" {
|
||||
if diff := cmpClaims(claims1, expectClaims); diff != "" {
|
||||
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
|
||||
}
|
||||
t.Logf("\t%s\tAuthenticate parse claims from token ok.", tests.Success)
|
||||
|
||||
// Try switching to a second account using the first set of claims.
|
||||
tkn2, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims1, acc2.ID, time.Hour, now)
|
||||
tkn2, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims1,
|
||||
SwitchAccountRequest{AccountID: acc2.ID}, time.Hour, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tSwitchAccount user failed.", tests.Failed)
|
||||
@ -129,11 +136,13 @@ func TestAuthenticate(t *testing.T) {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
|
||||
}
|
||||
expectClaims = tkn2.claims
|
||||
expectClaims.RootUserID = usrAcc.UserID
|
||||
expectClaims.RootAccountID = acc2.ID
|
||||
expectClaims.Subject = usrAcc.UserID
|
||||
expectClaims.Audience = acc2.ID
|
||||
|
||||
// Hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229
|
||||
resClaims, _ = json.Marshal(claims2)
|
||||
expectClaims, _ = json.Marshal(tkn2.claims)
|
||||
if diff := cmp.Diff(string(resClaims), string(expectClaims)); diff != "" {
|
||||
if diff := cmpClaims(claims2, expectClaims); diff != "" {
|
||||
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
|
||||
}
|
||||
t.Logf("\t%s\tSwitchAccount parse claims from token ok.", tests.Success)
|
||||
@ -256,3 +265,724 @@ func TestUserResetPassword(t *testing.T) {
|
||||
t.Logf("\t%s\tAuthenticate ok.", tests.Success)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSwitchAccount validates the behavior around allowing users to switch between their accounts.
|
||||
func TestSwitchAccount(t *testing.T) {
|
||||
defer tests.Recover(t)
|
||||
|
||||
// Auth tokens are valid for an our and is verified against current time.
|
||||
// Issue the token one hour ago.
|
||||
now := time.Now().Add(time.Hour * -1)
|
||||
|
||||
ctx := tests.Context()
|
||||
|
||||
type authTest struct {
|
||||
name string
|
||||
root *user_account.MockUserAccountResponse
|
||||
switch1Req SwitchAccountRequest
|
||||
switch1Roles []user_account.UserAccountRole
|
||||
switch1Scopes []string
|
||||
switch1Err error
|
||||
switch2Req SwitchAccountRequest
|
||||
switch2Roles []user_account.UserAccountRole
|
||||
switch2Scopes []string
|
||||
switch2Err error
|
||||
}
|
||||
|
||||
var authTests []authTest
|
||||
|
||||
// Test all the combinations there the user has access to all three accounts.
|
||||
if true {
|
||||
for _, roles := range [][]user_account.UserAccountRole{
|
||||
[]user_account.UserAccountRole{user_account.UserAccountRole_Admin, user_account.UserAccountRole_Admin, user_account.UserAccountRole_Admin},
|
||||
[]user_account.UserAccountRole{user_account.UserAccountRole_User, user_account.UserAccountRole_User, user_account.UserAccountRole_User},
|
||||
[]user_account.UserAccountRole{user_account.UserAccountRole_Admin, user_account.UserAccountRole_User, user_account.UserAccountRole_Admin},
|
||||
[]user_account.UserAccountRole{user_account.UserAccountRole_User, user_account.UserAccountRole_Admin, user_account.UserAccountRole_User},
|
||||
} {
|
||||
// Create a new user for testing.
|
||||
usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, roles[0])
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
|
||||
}
|
||||
|
||||
// Create the second account.
|
||||
now = now.Add(time.Minute)
|
||||
acc2, err := account.MockAccount(ctx, test.MasterDB, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
|
||||
}
|
||||
|
||||
// Associate the second account with root user.
|
||||
usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
|
||||
UserID: usrAcc.UserID,
|
||||
AccountID: acc2.ID,
|
||||
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(roles[1])},
|
||||
}, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tLinking second account to user failed.", tests.Failed)
|
||||
}
|
||||
|
||||
// Create the third account.
|
||||
now = now.Add(time.Minute)
|
||||
acc3, err := account.MockAccount(ctx, test.MasterDB, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tCreate third account failed.", tests.Failed)
|
||||
}
|
||||
|
||||
// Associate the third account with root user.
|
||||
usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
|
||||
UserID: usrAcc.UserID,
|
||||
AccountID: acc3.ID,
|
||||
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(roles[2])},
|
||||
}, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tLinking third account to user failed.", tests.Failed)
|
||||
}
|
||||
|
||||
authTests = append(authTests, authTest{
|
||||
name: fmt.Sprintf("Root account role %s -> role %s account 2 -> role %s account 3.",
|
||||
roles[0], roles[1], roles[2]),
|
||||
root: usrAcc,
|
||||
switch1Req: SwitchAccountRequest{AccountID: acc2.ID},
|
||||
switch1Roles: usrAcc2.Roles,
|
||||
switch1Err: nil,
|
||||
switch2Req: SwitchAccountRequest{AccountID: acc3.ID},
|
||||
switch2Err: nil,
|
||||
switch2Roles: usrAcc3.Roles,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Root account 1 -> invalid account 2
|
||||
if true {
|
||||
// Create a new user for testing.
|
||||
usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_Admin)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
|
||||
}
|
||||
|
||||
// Create the second account and don't associate it with the root user.
|
||||
now = now.Add(time.Minute)
|
||||
acc2, err := account.MockAccount(ctx, test.MasterDB, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
|
||||
}
|
||||
|
||||
authTests = append(authTests, authTest{
|
||||
name: "Root account 1 -> invalid account 2.",
|
||||
root: usrAcc,
|
||||
switch1Req: SwitchAccountRequest{AccountID: acc2.ID},
|
||||
switch1Err: ErrAuthenticationFailure,
|
||||
})
|
||||
}
|
||||
|
||||
// Root account 1 -> valid account 2 with scopes -> valid account 3 with invalid scope.
|
||||
if true {
|
||||
// Create a new user for testing.
|
||||
usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_Admin)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
|
||||
}
|
||||
|
||||
// Create the second account.
|
||||
now = now.Add(time.Minute)
|
||||
acc2, err := account.MockAccount(ctx, test.MasterDB, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
|
||||
}
|
||||
|
||||
// Associate the second account with root user.
|
||||
usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
|
||||
UserID: usrAcc.UserID,
|
||||
AccountID: acc2.ID,
|
||||
Roles: []user_account.UserAccountRole{user_account.UserAccountRole_Admin},
|
||||
}, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tLinking second account to user failed.", tests.Failed)
|
||||
}
|
||||
|
||||
// Create the third account.
|
||||
now = now.Add(time.Minute)
|
||||
acc3, err := account.MockAccount(ctx, test.MasterDB, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tCreate third account failed.", tests.Failed)
|
||||
}
|
||||
|
||||
// Associate the third account with root user.
|
||||
usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
|
||||
UserID: usrAcc.UserID,
|
||||
AccountID: acc3.ID,
|
||||
Roles: []user_account.UserAccountRole{user_account.UserAccountRole_User},
|
||||
}, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tLinking third account to user failed.", tests.Failed)
|
||||
}
|
||||
|
||||
authTests = append(authTests, authTest{
|
||||
name: "Root account 1 -> valid account 2 with scopes -> valid account 3 with invalid scope.",
|
||||
root: usrAcc,
|
||||
switch1Req: SwitchAccountRequest{AccountID: acc2.ID},
|
||||
switch1Roles: usrAcc2.Roles,
|
||||
switch1Scopes: []string{user_account.UserAccountRole_User.String()},
|
||||
switch1Err: nil,
|
||||
switch2Req: SwitchAccountRequest{AccountID: acc3.ID},
|
||||
switch2Roles: usrAcc3.Roles,
|
||||
switch2Scopes: []string{user_account.UserAccountRole_Admin.String()},
|
||||
switch2Err: ErrForbidden,
|
||||
})
|
||||
}
|
||||
|
||||
// Add 30 minutes to now to simulate time passing.
|
||||
now = now.Add(time.Minute * 5)
|
||||
|
||||
tknGen := &auth.MockTokenGenerator{}
|
||||
|
||||
t.Log("Given the need to switch accounts.")
|
||||
{
|
||||
for i, authTest := range authTests {
|
||||
t.Logf("\tTest: %d\tWhen running test: %s", i, authTest.name)
|
||||
{
|
||||
// Verify that the user can be authenticated with the created user.
|
||||
var claims1 auth.Claims
|
||||
tkn1, err := Authenticate(ctx, test.MasterDB, tknGen, authTest.root.User.Email, authTest.root.User.Password, time.Hour, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tAuthenticate user failed.", tests.Failed)
|
||||
} else {
|
||||
// Ensure the token string was correctly generated.
|
||||
claims1, err = tknGen.ParseClaims(tkn1.AccessToken)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
|
||||
}
|
||||
expectClaims := tkn1.claims
|
||||
expectClaims.RootUserID = ""
|
||||
expectClaims.RootAccountID = ""
|
||||
expectClaims.Subject = authTest.root.UserID
|
||||
expectClaims.Audience = authTest.root.AccountID
|
||||
expectClaims.Roles = rolesStringSlice(authTest.root.Roles)
|
||||
|
||||
if diff := cmpClaims(claims1, expectClaims); diff != "" {
|
||||
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
|
||||
}
|
||||
}
|
||||
t.Logf("\t%s\tAuthenticate root user with role %v ok.", tests.Success, authTest.root.Roles)
|
||||
|
||||
// Try to switch to account 2.
|
||||
var claims2 auth.Claims
|
||||
tkn2, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims1, authTest.switch1Req, time.Hour, now, authTest.switch1Scopes...)
|
||||
if err != authTest.switch1Err {
|
||||
if errors.Cause(err) != authTest.switch1Err {
|
||||
t.Log("\t\tExpected :", authTest.switch1Err)
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tSwitchAccount account 1 with role %v failed.", tests.Failed, authTest.switch1Roles)
|
||||
}
|
||||
} else {
|
||||
// Ensure the token string was correctly generated.
|
||||
claims2, err = tknGen.ParseClaims(tkn2.AccessToken)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
|
||||
}
|
||||
expectClaims := tkn2.claims
|
||||
expectClaims.RootUserID = authTest.root.UserID
|
||||
expectClaims.RootAccountID = authTest.switch1Req.AccountID
|
||||
expectClaims.Subject = authTest.root.UserID
|
||||
expectClaims.Audience = authTest.switch1Req.AccountID
|
||||
|
||||
if len(authTest.switch1Scopes) > 0 {
|
||||
expectClaims.Roles = authTest.switch1Scopes
|
||||
} else {
|
||||
expectClaims.Roles = rolesStringSlice(authTest.switch1Roles)
|
||||
}
|
||||
|
||||
if diff := cmpClaims(claims2, expectClaims); diff != "" {
|
||||
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
|
||||
}
|
||||
}
|
||||
t.Logf("\t%s\tSwitchAccount account 1 with role %v ok.", tests.Success, authTest.switch1Roles)
|
||||
|
||||
// If the user can't login, don't need to test any further.
|
||||
if authTest.switch1Err != nil || authTest.switch2Req.AccountID == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Try to switch to account 3.
|
||||
tkn3, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims2, authTest.switch2Req, time.Hour, now, authTest.switch2Scopes...)
|
||||
if err != authTest.switch2Err {
|
||||
if errors.Cause(err) != authTest.switch2Err {
|
||||
t.Log("\t\tExpected :", authTest.switch2Err)
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tSwitchAccount account 2 with role %v failed.", tests.Failed, authTest.switch2Roles)
|
||||
}
|
||||
} else {
|
||||
// Ensure the token string was correctly generated.
|
||||
claims3, err := tknGen.ParseClaims(tkn3.AccessToken)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
|
||||
}
|
||||
expectClaims := tkn3.claims
|
||||
expectClaims.RootUserID = authTest.root.UserID
|
||||
expectClaims.RootAccountID = authTest.switch2Req.AccountID
|
||||
expectClaims.Subject = authTest.root.UserID
|
||||
expectClaims.Audience = authTest.switch2Req.AccountID
|
||||
|
||||
if len(authTest.switch2Scopes) > 0 {
|
||||
expectClaims.Roles = authTest.switch2Scopes
|
||||
} else {
|
||||
expectClaims.Roles = rolesStringSlice(authTest.switch2Roles)
|
||||
}
|
||||
|
||||
if diff := cmpClaims(claims3, expectClaims); diff != "" {
|
||||
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
|
||||
}
|
||||
}
|
||||
t.Logf("\t%s\tSwitchAccount account 2 with role %v ok.", tests.Success, authTest.switch2Roles)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestVirtualLogin validates the behavior around allowing users to virtual login users.
|
||||
func TestVirtualLogin(t *testing.T) {
|
||||
defer tests.Recover(t)
|
||||
|
||||
// Auth tokens are valid for an our and is verified against current time.
|
||||
// Issue the token one hour ago.
|
||||
now := time.Now().Add(time.Hour * -1)
|
||||
|
||||
ctx := tests.Context()
|
||||
|
||||
type authTest struct {
|
||||
name string
|
||||
root *user_account.MockUserAccountResponse
|
||||
login1Req VirtualLoginRequest
|
||||
login1Err error
|
||||
login1Role user_account.UserAccountRole
|
||||
login2Req VirtualLoginRequest
|
||||
login2Err error
|
||||
login2Role user_account.UserAccountRole
|
||||
login2Logout bool
|
||||
}
|
||||
|
||||
var authTests []authTest
|
||||
|
||||
// Root admin -> role admin -> role admin
|
||||
if true {
|
||||
// Create a new user for testing.
|
||||
usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_Admin)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
|
||||
}
|
||||
|
||||
usr2, err := user.MockUser(ctx, test.MasterDB, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
|
||||
}
|
||||
|
||||
// Associate second user with basic role associated with the same account.
|
||||
usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
|
||||
UserID: usr2.ID,
|
||||
AccountID: usrAcc.AccountID,
|
||||
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)},
|
||||
}, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tLinking second user to account failed.", tests.Failed)
|
||||
}
|
||||
|
||||
usr3, err := user.MockUser(ctx, test.MasterDB, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
|
||||
}
|
||||
|
||||
// Associate second user with basic role associated with the same account.
|
||||
usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
|
||||
UserID: usr3.ID,
|
||||
AccountID: usrAcc.AccountID,
|
||||
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)},
|
||||
}, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tLinking third user to account failed.", tests.Failed)
|
||||
}
|
||||
|
||||
authTests = append(authTests, authTest{
|
||||
name: "Root admin -> role admin -> role admin",
|
||||
root: usrAcc,
|
||||
login1Req: VirtualLoginRequest{
|
||||
UserID: usr2.ID,
|
||||
AccountID: usrAcc.AccountID,
|
||||
},
|
||||
login1Role: usrAcc2.Roles[0],
|
||||
login1Err: nil,
|
||||
login2Req: VirtualLoginRequest{
|
||||
UserID: usr3.ID,
|
||||
AccountID: usrAcc.AccountID,
|
||||
},
|
||||
login2Err: nil,
|
||||
login2Role: usrAcc3.Roles[0],
|
||||
login2Logout: true,
|
||||
})
|
||||
}
|
||||
|
||||
// Root admin -> role admin -> role user
|
||||
if true {
|
||||
// Create a new user for testing.
|
||||
usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_Admin)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
|
||||
}
|
||||
|
||||
usr2, err := user.MockUser(ctx, test.MasterDB, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
|
||||
}
|
||||
|
||||
// Associate second user with basic role associated with the same account.
|
||||
usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
|
||||
UserID: usr2.ID,
|
||||
AccountID: usrAcc.AccountID,
|
||||
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)},
|
||||
}, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tLinking second user to account failed.", tests.Failed)
|
||||
}
|
||||
|
||||
usr3, err := user.MockUser(ctx, test.MasterDB, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
|
||||
}
|
||||
|
||||
// Associate second user with basic role associated with the same account.
|
||||
usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
|
||||
UserID: usr3.ID,
|
||||
AccountID: usrAcc.AccountID,
|
||||
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_User)},
|
||||
}, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tLinking third user to account failed.", tests.Failed)
|
||||
}
|
||||
|
||||
authTests = append(authTests, authTest{
|
||||
name: "Root admin -> role admin -> role user",
|
||||
root: usrAcc,
|
||||
login1Req: VirtualLoginRequest{
|
||||
UserID: usr2.ID,
|
||||
AccountID: usrAcc.AccountID,
|
||||
},
|
||||
login1Err: nil,
|
||||
login1Role: usrAcc2.Roles[0],
|
||||
login2Req: VirtualLoginRequest{
|
||||
UserID: usr3.ID,
|
||||
AccountID: usrAcc.AccountID,
|
||||
},
|
||||
login2Err: nil,
|
||||
login2Role: usrAcc3.Roles[0],
|
||||
login2Logout: true,
|
||||
})
|
||||
}
|
||||
|
||||
// Root admin -> role user -> role admin
|
||||
if true {
|
||||
// Create a new user for testing.
|
||||
usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_Admin)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
|
||||
}
|
||||
|
||||
usr2, err := user.MockUser(ctx, test.MasterDB, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
|
||||
}
|
||||
|
||||
// Associate second user with basic role associated with the same account.
|
||||
usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
|
||||
UserID: usr2.ID,
|
||||
AccountID: usrAcc.AccountID,
|
||||
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_User)},
|
||||
}, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tLinking second user to account failed.", tests.Failed)
|
||||
}
|
||||
|
||||
usr3, err := user.MockUser(ctx, test.MasterDB, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
|
||||
}
|
||||
|
||||
// Associate second user with basic role associated with the same account.
|
||||
usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
|
||||
UserID: usr3.ID,
|
||||
AccountID: usrAcc.AccountID,
|
||||
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)},
|
||||
}, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tLinking third user to account failed.", tests.Failed)
|
||||
}
|
||||
|
||||
authTests = append(authTests, authTest{
|
||||
name: "Root admin -> role user -> role admin",
|
||||
root: usrAcc,
|
||||
login1Req: VirtualLoginRequest{
|
||||
UserID: usr2.ID,
|
||||
AccountID: usrAcc.AccountID,
|
||||
},
|
||||
login1Err: nil,
|
||||
login1Role: usrAcc2.Roles[0],
|
||||
login2Req: VirtualLoginRequest{
|
||||
UserID: usr3.ID,
|
||||
AccountID: usrAcc.AccountID,
|
||||
},
|
||||
login2Err: ErrForbidden,
|
||||
login2Role: usrAcc3.Roles[0],
|
||||
login2Logout: true,
|
||||
})
|
||||
}
|
||||
|
||||
// Root user -> role admin
|
||||
if true {
|
||||
// Create a new user for testing.
|
||||
usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_User)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
|
||||
}
|
||||
|
||||
usr2, err := user.MockUser(ctx, test.MasterDB, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
|
||||
}
|
||||
|
||||
// Associate second user with basic role associated with the same account.
|
||||
usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
|
||||
UserID: usr2.ID,
|
||||
AccountID: usrAcc.AccountID,
|
||||
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)},
|
||||
}, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tLinking second user to account failed.", tests.Failed)
|
||||
}
|
||||
|
||||
authTests = append(authTests, authTest{
|
||||
name: "Root user -> role admin",
|
||||
root: usrAcc,
|
||||
login1Req: VirtualLoginRequest{
|
||||
UserID: usr2.ID,
|
||||
AccountID: usrAcc.AccountID,
|
||||
},
|
||||
login1Err: ErrForbidden,
|
||||
login1Role: usrAcc2.Roles[0],
|
||||
login2Logout: true,
|
||||
})
|
||||
}
|
||||
|
||||
// Root user -> role user
|
||||
if true {
|
||||
// Create a new user for testing.
|
||||
usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_User)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
|
||||
}
|
||||
|
||||
usr2, err := user.MockUser(ctx, test.MasterDB, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
|
||||
}
|
||||
|
||||
// Associate second user with basic role associated with the same account.
|
||||
usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
|
||||
UserID: usr2.ID,
|
||||
AccountID: usrAcc.AccountID,
|
||||
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_User)},
|
||||
}, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tLinking second user to account failed.", tests.Failed)
|
||||
}
|
||||
|
||||
authTests = append(authTests, authTest{
|
||||
name: "Root user -> role admin",
|
||||
root: usrAcc,
|
||||
login1Req: VirtualLoginRequest{
|
||||
UserID: usr2.ID,
|
||||
AccountID: usrAcc.AccountID,
|
||||
},
|
||||
login1Err: ErrForbidden,
|
||||
login1Role: usrAcc2.Roles[0],
|
||||
login2Logout: true,
|
||||
})
|
||||
}
|
||||
|
||||
// Add 30 minutes to now to simulate time passing.
|
||||
now = now.Add(time.Minute * 5)
|
||||
|
||||
tknGen := &auth.MockTokenGenerator{}
|
||||
|
||||
t.Log("Given the need to virtual login.")
|
||||
{
|
||||
for i, authTest := range authTests {
|
||||
t.Logf("\tTest: %d\tWhen running test: %s", i, authTest.name)
|
||||
{
|
||||
// Verify that the user can be authenticated with the created user.
|
||||
var claims1 auth.Claims
|
||||
tkn1, err := Authenticate(ctx, test.MasterDB, tknGen, authTest.root.User.Email, authTest.root.User.Password, time.Hour, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tAuthenticate user failed.", tests.Failed)
|
||||
} else {
|
||||
// Ensure the token string was correctly generated.
|
||||
claims1, err = tknGen.ParseClaims(tkn1.AccessToken)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
|
||||
}
|
||||
expectClaims := tkn1.claims
|
||||
expectClaims.RootUserID = ""
|
||||
expectClaims.RootAccountID = ""
|
||||
expectClaims.Subject = authTest.root.UserID
|
||||
expectClaims.Audience = authTest.root.AccountID
|
||||
|
||||
// Hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229
|
||||
if diff := cmpClaims(claims1, expectClaims); diff != "" {
|
||||
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
|
||||
}
|
||||
}
|
||||
t.Logf("\t%s\tAuthenticate root user with role %s ok.", tests.Success, authTest.root.Roles[0])
|
||||
|
||||
// Try virtual login to user 2.
|
||||
var claims2 auth.Claims
|
||||
tkn2, err := VirtualLogin(ctx, test.MasterDB, tknGen, claims1, authTest.login1Req, time.Hour, now)
|
||||
if err != authTest.login1Err {
|
||||
if errors.Cause(err) != authTest.login1Err {
|
||||
t.Log("\t\tExpected :", authTest.login1Err)
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tVirtualLogin user 1 with role %s failed.", tests.Failed, authTest.login1Role)
|
||||
}
|
||||
} else {
|
||||
// Ensure the token string was correctly generated.
|
||||
claims2, err = tknGen.ParseClaims(tkn2.AccessToken)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
|
||||
}
|
||||
expectClaims := tkn2.claims
|
||||
expectClaims.RootUserID = authTest.root.UserID
|
||||
expectClaims.RootAccountID = authTest.root.AccountID
|
||||
expectClaims.Subject = authTest.login1Req.UserID
|
||||
expectClaims.Audience = authTest.login1Req.AccountID
|
||||
|
||||
// Hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229
|
||||
if diff := cmpClaims(claims2, expectClaims); diff != "" {
|
||||
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
|
||||
}
|
||||
}
|
||||
t.Logf("\t%s\tVirtualLogin user 1 with role %s ok.", tests.Success, authTest.login1Role)
|
||||
|
||||
// If the user can't login, don't need to test any further.
|
||||
if authTest.login1Err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Try virtual login to user 3.
|
||||
tkn3, err := VirtualLogin(ctx, test.MasterDB, tknGen, claims2, authTest.login2Req, time.Hour, now)
|
||||
if err != authTest.login2Err {
|
||||
if errors.Cause(err) != authTest.login2Err {
|
||||
t.Log("\t\tExpected :", authTest.login2Err)
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tVirtualLogin user 2 with role %s failed.", tests.Failed, authTest.login2Role)
|
||||
}
|
||||
} else {
|
||||
// Ensure the token string was correctly generated.
|
||||
claims3, err := tknGen.ParseClaims(tkn3.AccessToken)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
|
||||
}
|
||||
expectClaims := tkn3.claims
|
||||
expectClaims.RootUserID = authTest.root.UserID
|
||||
expectClaims.RootAccountID = authTest.root.AccountID
|
||||
expectClaims.Subject = authTest.login2Req.UserID
|
||||
expectClaims.Audience = authTest.login2Req.AccountID
|
||||
|
||||
// Hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229
|
||||
if diff := cmpClaims(claims3, expectClaims); diff != "" {
|
||||
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
|
||||
}
|
||||
}
|
||||
t.Logf("\t%s\tVirtualLogin user 2 with role %s ok.", tests.Success, authTest.login2Role)
|
||||
|
||||
if authTest.login2Logout {
|
||||
tknOut, err := VirtualLogout(ctx, test.MasterDB, tknGen, claims2, time.Hour, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tVirtualLogout user 2 failed.", tests.Failed)
|
||||
}
|
||||
|
||||
// Ensure the token string was correctly generated.
|
||||
claimsOut, err := tknGen.ParseClaims(tknOut.AccessToken)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
|
||||
}
|
||||
expectClaims := tknOut.claims
|
||||
expectClaims.RootUserID = authTest.root.UserID
|
||||
expectClaims.RootAccountID = authTest.root.AccountID
|
||||
expectClaims.Subject = authTest.root.UserID
|
||||
expectClaims.Audience = authTest.root.AccountID
|
||||
|
||||
if diff := cmpClaims(claimsOut, expectClaims); diff != "" {
|
||||
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
|
||||
}
|
||||
t.Logf("\t%s\tVirtualLogout user 2 with role %s ok.", tests.Success, authTest.login2Role)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// rolesStringSlice converts a list of roles to a string slice.
|
||||
func rolesStringSlice(roles []user_account.UserAccountRole) []string {
|
||||
var l []string
|
||||
for _, r := range roles {
|
||||
l = append(l, string(r))
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
// cmpClaims is a hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229
|
||||
func cmpClaims(actualClaims, expectedclaims auth.Claims) string {
|
||||
dat1, _ := json.Marshal(actualClaims)
|
||||
dat2, _ := json.Marshal(expectedclaims)
|
||||
return cmp.Diff(string(dat1), string(dat2))
|
||||
}
|
||||
|
@ -35,6 +35,17 @@ type Token struct {
|
||||
AccountID string `json:"account_id"example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
|
||||
}
|
||||
|
||||
// SwitchAccountRequest defines the information for the current user to switch between their accounts
|
||||
type SwitchAccountRequest struct {
|
||||
AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
|
||||
}
|
||||
|
||||
// VirtualLoginRequest defines the information virtual login to a user / account.
|
||||
type VirtualLoginRequest struct {
|
||||
UserID string `json:"user_id" validate:"required,uuid" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"`
|
||||
AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
|
||||
}
|
||||
|
||||
// AuthorizationHeader returns the header authorization value.
|
||||
func (t Token) AuthorizationHeader() string {
|
||||
return "Bearer " + t.AccessToken
|
||||
|
Reference in New Issue
Block a user