1
0
mirror of https://github.com/raseels-repos/golang-saas-starter-kit.git synced 2025-06-15 00:15:15 +02:00

Completed virtual user login and switch accounts

This commit is contained in:
Lee Brown
2019-08-04 21:28:02 -08:00
parent bb9820ffcc
commit ba96e8b367
23 changed files with 1541 additions and 176 deletions

View File

@ -4,7 +4,6 @@ import (
"context"
"encoding/json"
"fmt"
"geeks-accelerator/oss/saas-starter-kit/internal/user_auth"
"net/http"
"net/http/httptest"
"testing"
@ -17,6 +16,7 @@ import (
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror"
"geeks-accelerator/oss/saas-starter-kit/internal/user"
"geeks-accelerator/oss/saas-starter-kit/internal/user_account"
"geeks-accelerator/oss/saas-starter-kit/internal/user_auth"
"github.com/pborman/uuid"
)

View File

@ -19,8 +19,9 @@ import (
// Account represents the Account API method handler set.
type Account struct {
MasterDB *sqlx.DB
Renderer web.Renderer
MasterDB *sqlx.DB
Renderer web.Renderer
Authenticator *auth.Authenticator
}
// View handles displaying the current account profile.
@ -34,7 +35,7 @@ func (h *Account) View(ctx context.Context, w http.ResponseWriter, r *http.Reque
return err
}
acc, err := account.Read(ctx, claims, h.MasterDB, claims.Audience, false)
acc, err := account.ReadByID(ctx, claims, h.MasterDB, claims.Audience)
if err != nil {
return err
}
@ -127,7 +128,11 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req
}
}
sess := webcontext.ContextSession(ctx)
var updateClaims bool
if req.Timezone != nil && claims.Preferences.Timezone != *req.Timezone {
claims.Preferences.Timezone = *req.Timezone
updateClaims = true
}
if preferenceDatetimeFormat != req.PreferenceDatetimeFormat {
err = account_preference.Set(ctx, claims, h.MasterDB, account_preference.AccountPreferenceSetRequest{
@ -144,7 +149,10 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req
}
}
sess.Values[webcontext.SessionKeyPreferenceDatetimeFormat] = req.PreferenceDatetimeFormat
if claims.Preferences.DatetimeFormat != req.PreferenceDatetimeFormat {
claims.Preferences.DatetimeFormat = req.PreferenceDatetimeFormat
updateClaims = true
}
}
if preferenceDateFormat != req.PreferenceDateFormat {
@ -162,7 +170,10 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req
}
}
sess.Values[webcontext.SessionKeyPreferenceDateFormat] = req.PreferenceDateFormat
if claims.Preferences.DateFormat != req.PreferenceDateFormat {
claims.Preferences.DateFormat = req.PreferenceDateFormat
updateClaims = true
}
}
if preferenceTimeFormat != req.PreferenceTimeFormat {
@ -180,7 +191,18 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req
}
}
sess.Values[webcontext.SessionKeyPreferenceTimeFormat] = req.PreferenceTimeFormat
if claims.Preferences.TimeFormat != req.PreferenceTimeFormat {
claims.Preferences.TimeFormat = req.PreferenceTimeFormat
updateClaims = true
}
}
// Update the access token to include the updated claims.
if updateClaims {
ctx, err = updateContextClaims(ctx, h.Authenticator, claims)
if err != nil {
return false, err
}
}
// Display a success message to the user.
@ -196,7 +218,7 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req
return true, nil
}
acc, err := account.Read(ctx, claims, h.MasterDB, claims.Audience, false)
acc, err := account.ReadByID(ctx, claims, h.MasterDB, claims.Audience)
if err != nil {
return false, err
}

View File

@ -66,13 +66,21 @@ func APP(shutdown chan os.Signal, log *log.Logger, env webcontext.Env, staticDir
app.Handle("POST", "/user/update", u.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
app.Handle("GET", "/user/update", u.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
app.Handle("GET", "/user/account", u.Account, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
app.Handle("GET", "/user/virtual-login/:user_id", u.VirtualLogin, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin))
app.Handle("POST", "/user/virtual-login", u.VirtualLogin, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin))
app.Handle("GET", "/user/virtual-login", u.VirtualLogin, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin))
app.Handle("GET", "/user/virtual-logout", u.VirtualLogout, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
app.Handle("GET", "/user/switch-account/:account_id", u.SwitchAccount, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
app.Handle("POST", "/user/switch-account", u.SwitchAccount, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
app.Handle("GET", "/user/switch-account", u.SwitchAccount, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
app.Handle("POST", "/user", u.View, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
app.Handle("GET", "/user", u.View, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
// Register account management endpoints.
acc := Account{
MasterDB: masterDB,
Renderer: renderer,
MasterDB: masterDB,
Renderer: renderer,
Authenticator: authenticator,
}
app.Handle("POST", "/account/update", acc.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin))
app.Handle("GET", "/account/update", acc.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin))

View File

@ -12,7 +12,7 @@ import (
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror"
"geeks-accelerator/oss/saas-starter-kit/internal/signup"
"geeks-accelerator/oss/saas-starter-kit/internal/user"
"geeks-accelerator/oss/saas-starter-kit/internal/user_auth"
"github.com/gorilla/schema"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
@ -68,7 +68,7 @@ func (h *Signup) Step1(ctx context.Context, w http.ResponseWriter, r *http.Reque
}
// Authenticated the new user.
token, err := user.Authenticate(ctx, h.MasterDB, h.Authenticator, req.User.Email, req.User.Password, time.Hour, ctxValues.Now)
token, err := user_auth.Authenticate(ctx, h.MasterDB, h.Authenticator, req.User.Email, req.User.Password, time.Hour, ctxValues.Now)
if err != nil {
return err
}
@ -93,13 +93,6 @@ func (h *Signup) Step1(ctx context.Context, w http.ResponseWriter, r *http.Reque
return nil
}
data["geonameCountries"] = geonames.ValidGeonameCountries
data["countries"], err = geonames.FindCountries(ctx, h.MasterDB, "name", "")
if err != nil {
return err
}
return nil
}
@ -107,6 +100,13 @@ func (h *Signup) Step1(ctx context.Context, w http.ResponseWriter, r *http.Reque
return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
}
data["geonameCountries"] = geonames.ValidGeonameCountries
data["countries"], err = geonames.FindCountries(ctx, h.MasterDB, "name", "")
if err != nil {
return err
}
data["form"] = req
if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(signup.SignupRequest{})); ok {

View File

@ -4,6 +4,8 @@ import (
"context"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"geeks-accelerator/oss/saas-starter-kit/internal/account"
@ -16,6 +18,7 @@ import (
project_routes "geeks-accelerator/oss/saas-starter-kit/internal/project-routes"
"geeks-accelerator/oss/saas-starter-kit/internal/user"
"geeks-accelerator/oss/saas-starter-kit/internal/user_account"
"geeks-accelerator/oss/saas-starter-kit/internal/user_auth"
"github.com/gorilla/schema"
"github.com/gorilla/sessions"
"github.com/jmoiron/sqlx"
@ -34,7 +37,7 @@ type User struct {
}
type UserLoginRequest struct {
user.AuthenticateRequest
user_auth.AuthenticateRequest
RememberMe bool
}
@ -68,7 +71,7 @@ func (h *User) Login(ctx context.Context, w http.ResponseWriter, r *http.Request
}
// Authenticated the user.
token, err := user.Authenticate(ctx, h.MasterDB, h.Authenticator, req.Email, req.Password, sessionTTL, ctxValues.Now)
token, err := user_auth.Authenticate(ctx, h.MasterDB, h.Authenticator, req.Email, req.Password, sessionTTL, ctxValues.Now)
if err != nil {
switch errors.Cause(err) {
case user.ErrForbidden:
@ -89,8 +92,16 @@ func (h *User) Login(ctx context.Context, w http.ResponseWriter, r *http.Request
return err
}
redirectUri := "/"
if qv := r.URL.Query().Get("redirect"); qv != "" {
redirectUri, err = url.QueryUnescape(qv)
if err != nil {
return err
}
}
// Redirect the user to the dashboard.
http.Redirect(w, r, "/", http.StatusFound)
http.Redirect(w, r, redirectUri, http.StatusFound)
}
return nil
@ -110,7 +121,7 @@ func (h *User) Login(ctx context.Context, w http.ResponseWriter, r *http.Request
}
// handleSessionToken persists the access token to the session for request authentication.
func handleSessionToken(ctx context.Context, db *sqlx.DB, w http.ResponseWriter, r *http.Request, token user.Token) error {
func handleSessionToken(ctx context.Context, db *sqlx.DB, w http.ResponseWriter, r *http.Request, token user_auth.Token) error {
if token.AccessToken == "" {
return errors.New("accessToken is required.")
}
@ -252,7 +263,7 @@ func (h *User) ResetConfirm(ctx context.Context, w http.ResponseWriter, r *http.
}
// Authenticated the user. Probably should use the default session TTL from UserLogin.
token, err := user.Authenticate(ctx, h.MasterDB, h.Authenticator, u.Email, req.Password, time.Hour, ctxValues.Now)
token, err := user_auth.Authenticate(ctx, h.MasterDB, h.Authenticator, u.Email, req.Password, time.Hour, ctxValues.Now)
if err != nil {
switch errors.Cause(err) {
case account.ErrForbidden:
@ -306,7 +317,7 @@ func (h *User) View(ctx context.Context, w http.ResponseWriter, r *http.Request,
return err
}
usr, err := user.Read(ctx, claims, h.MasterDB, claims.Subject, false)
usr, err := user.ReadByID(ctx, claims, h.MasterDB, claims.Subject)
if err != nil {
return err
}
@ -343,16 +354,15 @@ func (h *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Reques
return err
}
claims, err := auth.ClaimsFromContext(ctx)
if err != nil {
return err
}
//
req := new(user.UserUpdateRequest)
data := make(map[string]interface{})
f := func() (bool, error) {
claims, err := auth.ClaimsFromContext(ctx)
if err != nil {
return false, err
}
if r.Method == http.MethodPost {
err := r.ParseForm()
if err != nil {
@ -415,25 +425,6 @@ func (h *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Reques
return true, nil
}
usr, err := user.Read(ctx, claims, h.MasterDB, claims.Subject, false)
if err != nil {
return false, err
}
if req.ID == "" {
req.FirstName = &usr.FirstName
req.LastName = &usr.LastName
req.Email = &usr.Email
req.Timezone = &usr.Timezone
}
data["user"] = usr.Response(ctx)
data["timezones"], err = geonames.ListTimezones(ctx, h.MasterDB)
if err != nil {
return false, err
}
return false, nil
}
@ -444,6 +435,25 @@ func (h *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Reques
return nil
}
usr, err := user.ReadByID(ctx, claims, h.MasterDB, claims.Subject)
if err != nil {
return err
}
if req.ID == "" {
req.FirstName = &usr.FirstName
req.LastName = &usr.LastName
req.Email = &usr.Email
req.Timezone = &usr.Timezone
}
data["user"] = usr.Response(ctx)
data["timezones"], err = geonames.ListTimezones(ctx, h.MasterDB)
if err != nil {
return err
}
data["form"] = req
if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(user.UserUpdateRequest{})); ok {
@ -468,7 +478,7 @@ func (h *User) Account(ctx context.Context, w http.ResponseWriter, r *http.Reque
return err
}
acc, err := account.Read(ctx, claims, h.MasterDB, claims.Audience, false)
acc, err := account.ReadByID(ctx, claims, h.MasterDB, claims.Audience)
if err != nil {
return err
}
@ -483,3 +493,352 @@ func (h *User) Account(ctx context.Context, w http.ResponseWriter, r *http.Reque
return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "user-account.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data)
}
// VirtualLogin handles switching the scope of the context to another user.
func (h *User) VirtualLogin(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
ctxValues, err := webcontext.ContextValues(ctx)
if err != nil {
return err
}
claims, err := auth.ClaimsFromContext(ctx)
if err != nil {
return err
}
//
req := new(user_auth.VirtualLoginRequest)
data := make(map[string]interface{})
f := func() (bool, error) {
if r.Method == http.MethodPost {
err := r.ParseForm()
if err != nil {
return false, err
}
decoder := schema.NewDecoder()
decoder.IgnoreUnknownKeys(true)
if err := decoder.Decode(req, r.PostForm); err != nil {
return false, err
}
} else {
if pv, ok := params["user_id"]; ok && pv != "" {
req.UserID = pv
}
}
if qv := r.URL.Query().Get("account_id"); qv != "" {
req.AccountID = qv
} else {
req.AccountID = claims.Audience
}
if req.UserID != "" {
sess := webcontext.ContextSession(ctx)
var expires time.Duration
if sess != nil && sess.Options != nil {
expires = time.Second * time.Duration(sess.Options.MaxAge)
} else {
expires = time.Hour
}
// Perform the account switch.
tkn, err := user_auth.VirtualLogin(ctx, h.MasterDB, h.Authenticator, claims, *req, expires, ctxValues.Now)
if err != nil {
if verr, ok := weberror.NewValidationError(ctx, err); ok {
data["validationErrors"] = verr.(*weberror.Error)
return false, nil
} else {
return false, err
}
}
// Update the access token in the session.
sess = webcontext.SessionUpdateAccessToken(sess, tkn.AccessToken)
// Read the account for a flash message.
usr, err := user.ReadByID(ctx, claims, h.MasterDB, tkn.UserID)
if err != nil {
return false, err
}
webcontext.SessionFlashSuccess(ctx,
"User Switched",
fmt.Sprintf("You are now virtually logged into user %s.",
usr.Response(ctx).Name))
// Write the session to the client.
err = webcontext.ContextSession(ctx).Save(r, w)
if err != nil {
return false, err
}
// Redirect the user to the dashboard with the new credentials.
http.Redirect(w, r, "/", http.StatusFound)
return true, nil
}
return false, nil
}
end, err := f()
if err != nil {
return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
} else if end {
return nil
}
usrAccFilter := "account_id = ?"
usrAccs, err := user_account.Find(ctx, claims, h.MasterDB, user_account.UserAccountFindRequest{
Where: &usrAccFilter,
Args: []interface{}{claims.Audience},
})
if err != nil {
return err
}
var userIDs []interface{}
var userPhs []string
for _, usrAcc := range usrAccs {
if usrAcc.UserID == claims.Subject {
// Skip the current authenticated user.
continue
}
userIDs = append(userIDs, usrAcc.UserID)
userPhs = append(userPhs, "?")
}
if len(userIDs) == 0 {
userIDs = append(userIDs, "")
userPhs = append(userPhs, "?")
}
usrFilter := fmt.Sprintf("id IN (%s)", strings.Join(userPhs, ", "))
users, err := user.Find(ctx, claims, h.MasterDB, user.UserFindRequest{
Where: &usrFilter,
Args: userIDs,
})
if err != nil {
return err
}
data["users"] = users.Response(ctx)
if req.AccountID == "" {
req.AccountID = claims.Audience
}
data["form"] = req
if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(user_auth.VirtualLoginRequest{})); ok {
data["validationDefaults"] = verr.(*weberror.Error)
}
return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "user-virtual-login.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data)
}
// VirtualLogout handles switching the scope back to the user who initiated the virtual login.
func (h *User) VirtualLogout(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
ctxValues, err := webcontext.ContextValues(ctx)
if err != nil {
return err
}
claims, err := auth.ClaimsFromContext(ctx)
if err != nil {
return err
}
sess := webcontext.ContextSession(ctx)
var expires time.Duration
if sess != nil && sess.Options != nil {
expires = time.Second * time.Duration(sess.Options.MaxAge)
} else {
expires = time.Hour
}
tkn, err := user_auth.VirtualLogout(ctx, h.MasterDB, h.Authenticator, claims, expires, ctxValues.Now)
if err != nil {
return err
}
// Update the access token in the session.
sess = webcontext.SessionUpdateAccessToken(sess, tkn.AccessToken)
// Display a success message to verify the user has switched contexts.
if claims.Subject != tkn.UserID && claims.Audience != tkn.AccountID {
usr, err := user.ReadByID(ctx, claims, h.MasterDB, tkn.UserID)
if err != nil {
return err
}
acc, err := account.ReadByID(ctx, claims, h.MasterDB, tkn.AccountID)
if err != nil {
return err
}
webcontext.SessionFlashSuccess(ctx,
"Context Switched",
fmt.Sprintf("You are now virtually logged back into account %s user %s.",
acc.Response(ctx).Name, usr.Response(ctx).Name))
} else if claims.Audience != tkn.AccountID {
acc, err := account.ReadByID(ctx, claims, h.MasterDB, tkn.AccountID)
if err != nil {
return err
}
webcontext.SessionFlashSuccess(ctx,
"Context Switched",
fmt.Sprintf("You are now virtually logged back into account %s.",
acc.Response(ctx).Name))
} else {
usr, err := user.ReadByID(ctx, claims, h.MasterDB, tkn.UserID)
if err != nil {
return err
}
webcontext.SessionFlashSuccess(ctx,
"Context Switched",
fmt.Sprintf("You are now virtually logged back into user %s.",
usr.Response(ctx).Name))
}
// Write the session to the client.
err = webcontext.ContextSession(ctx).Save(r, w)
if err != nil {
return err
}
// Redirect the user to the dashboard with the new credentials.
http.Redirect(w, r, "/", http.StatusFound)
return nil
}
// VirtualLogin handles switching the scope of the context to another user.
func (h *User) SwitchAccount(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
ctxValues, err := webcontext.ContextValues(ctx)
if err != nil {
return err
}
claims, err := auth.ClaimsFromContext(ctx)
if err != nil {
return err
}
//
req := new(user_auth.SwitchAccountRequest)
data := make(map[string]interface{})
f := func() (bool, error) {
if r.Method == http.MethodPost {
err := r.ParseForm()
if err != nil {
return false, err
}
decoder := schema.NewDecoder()
decoder.IgnoreUnknownKeys(true)
if err := decoder.Decode(req, r.PostForm); err != nil {
return false, err
}
} else {
if pv, ok := params["account_id"]; ok && pv != "" {
req.AccountID = pv
} else if qv := r.URL.Query().Get("account_id"); qv != "" {
req.AccountID = qv
}
}
if req.AccountID != "" {
sess := webcontext.ContextSession(ctx)
var expires time.Duration
if sess != nil && sess.Options != nil {
expires = time.Second * time.Duration(sess.Options.MaxAge)
} else {
expires = time.Hour
}
// Perform the account switch.
tkn, err := user_auth.SwitchAccount(ctx, h.MasterDB, h.Authenticator, claims, *req, expires, ctxValues.Now)
if err != nil {
if verr, ok := weberror.NewValidationError(ctx, err); ok {
data["validationErrors"] = verr.(*weberror.Error)
return false, nil
} else {
return false, err
}
}
// Update the access token in the session.
sess = webcontext.SessionUpdateAccessToken(sess, tkn.AccessToken)
// Read the account for a flash message.
acc, err := account.ReadByID(ctx, claims, h.MasterDB, tkn.AccountID)
if err != nil {
return false, err
}
webcontext.SessionFlashSuccess(ctx,
"Account Switched",
fmt.Sprintf("You are now logged into account %s.",
acc.Response(ctx).Name))
// Write the session to the client.
err = webcontext.ContextSession(ctx).Save(r, w)
if err != nil {
return false, err
}
// Redirect the user to the dashboard with the new credentials.
http.Redirect(w, r, "/", http.StatusFound)
return true, nil
}
return false, nil
}
end, err := f()
if err != nil {
return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
} else if end {
return nil
}
accounts, err := account.Find(ctx, claims, h.MasterDB, account.AccountFindRequest{
Order: []string{"name"},
})
if err != nil {
return err
}
data["accounts"] = accounts.Response(ctx)
if req.AccountID == "" {
req.AccountID = claims.Audience
}
data["form"] = req
if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(user_auth.SwitchAccountRequest{})); ok {
data["validationDefaults"] = verr.(*weberror.Error)
}
return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "user-switch-account.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data)
}
// updateContextClaims updates the claims in the context.
func updateContextClaims(ctx context.Context, authenticator *auth.Authenticator, claims auth.Claims) (context.Context, error) {
tkn, err := authenticator.GenerateToken(claims)
if err != nil {
return ctx, err
}
sess := webcontext.ContextSession(ctx)
sess = webcontext.SessionUpdateAccessToken(sess, tkn)
ctx = context.WithValue(ctx, auth.Key, claims)
return ctx, nil
}

View File

@ -698,7 +698,7 @@ func main() {
return nil
}
usr, err := user.Read(ctx, auth.Claims{}, masterDb, claims.Subject, false)
usr, err := user.ReadByID(ctx, auth.Claims{}, masterDb, claims.Subject)
if err != nil {
return nil
}
@ -733,7 +733,7 @@ func main() {
return nil
}
acc, err := account.Read(ctx, auth.Claims{}, masterDb, claims.Audience, false)
acc, err := account.ReadByID(ctx, auth.Claims{}, masterDb, claims.Audience)
if err != nil {
return nil
}
@ -746,6 +746,26 @@ func main() {
return a
},
"ContextCanSwitchAccount": func(ctx context.Context) bool {
claims, err := auth.ClaimsFromContext(ctx)
if err != nil || len(claims.AccountIDs) < 2 {
return false
}
return true
},
"ContextIsVirtualSession": func(ctx context.Context) bool {
claims, err := auth.ClaimsFromContext(ctx)
if err != nil {
return false
}
if claims.RootUserID != "" && claims.RootUserID != claims.Subject {
return true
}
if claims.RootAccountID != "" && claims.RootAccountID != claims.Audience {
return true
}
return false
},
}
imgUrlFormatter := staticUrlFormatter
@ -827,11 +847,8 @@ func main() {
switch statusCode {
case http.StatusUnauthorized:
// Handle expired sessions that are returned from the auth middleware.
if strings.Contains(errors.Cause(er).Error(), "token is expired") {
http.Redirect(w, r, "/user/login", http.StatusFound)
return nil
}
http.Redirect(w, r, "/user/login?redirect="+url.QueryEscape(r.RequestURI), http.StatusFound)
return nil
}
return web.RenderError(ctx, w, r, er, renderer, handlers.TmplLayoutBase, handlers.TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)

Binary file not shown.

After

Width:  |  Height:  |  Size: 40 KiB

View File

@ -0,0 +1,60 @@
{{define "title"}}Switch Account{{end}}
{{define "style"}}
{{end}}
{{ define "partials/page-wrapper" }}
<div class="container" id="page-content">
<!-- Outer Row -->
<div class="row justify-content-center">
<div class="col-xl-10 col-lg-12 col-md-9">
<div class="card o-hidden border-0 shadow-lg my-5">
<div class="card-body p-0">
<!-- Nested Row within Card Body -->
<div class="row">
<div class="col-lg-6 d-none d-lg-block bg-login-image"></div>
<div class="col-lg-6">
<div class="p-5">
{{ template "app-flashes" . }}
<div class="text-center">
<h1 class="h4 text-gray-900 mb-4">Switch Account</h1>
</div>
{{ template "validation-error" . }}
<form class="user" method="post" novalidate>
<div class="form-group">
<select class="form-control form-control-select-box {{ ValidationFieldClass $.validationErrors "AccountID" }}" name="AccountID" placeholder="AccountID" required>
{{ range $i := $.accounts }}
<option value="{{ $i.ID }}" {{ if eq $.form.AccountID $i.ID }}selected="selected"{{ end }}>{{ $i.Name }}</option>
{{ end }}
</select>
{{template "invalid-feedback" dict "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors "fieldName" "AccountID" }}
</div>
<button class="btn btn-primary btn-user btn-block">
Login
</button>
<hr>
</form>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
{{end}}
{{define "js"}}
<script>
$(document).ready(function() {
$(document).find('body').addClass('bg-gradient-primary');
});
</script>
{{end}}

View File

@ -0,0 +1,60 @@
{{define "title"}}Switch User{{end}}
{{define "style"}}
{{end}}
{{ define "partials/page-wrapper" }}
<div class="container" id="page-content">
<!-- Outer Row -->
<div class="row justify-content-center">
<div class="col-xl-10 col-lg-12 col-md-9">
<div class="card o-hidden border-0 shadow-lg my-5">
<div class="card-body p-0">
<!-- Nested Row within Card Body -->
<div class="row">
<div class="col-lg-6 d-none d-lg-block bg-login-image"></div>
<div class="col-lg-6">
<div class="p-5">
{{ template "app-flashes" . }}
<div class="text-center">
<h1 class="h4 text-gray-900 mb-4">Switch User</h1>
</div>
{{ template "validation-error" . }}
<form class="user" method="post" novalidate>
<div class="form-group">
<select class="form-control form-control-select-box {{ ValidationFieldClass $.validationErrors "User" }}" name="UserID" placeholder="UserID" required>
{{ range $i := $.users }}
<option value="{{ $i.ID }}" {{ if eq $.form.UserID $i.ID }}selected="selected"{{ end }}>{{ $i.Name }}</option>
{{ end }}
</select>
{{template "invalid-feedback" dict "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors "fieldName" "UserID" }}
</div>
<button class="btn btn-primary btn-user btn-block">
Login
</button>
<hr>
</form>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
{{end}}
{{define "js"}}
<script>
$(document).ready(function() {
$(document).find('body').addClass('bg-gradient-primary');
});
</script>
{{end}}

View File

@ -177,7 +177,6 @@
<i class="fas fa-cogs fa-sm fa-fw mr-2 text-gray-400"></i>
Account Settings
</a>
<a class="dropdown-item" href="/users">
<i class="fas fa-cogs fa-sm fa-fw mr-2 text-gray-400"></i>
Manage Users
@ -197,8 +196,22 @@
<i class="fas fa-cogs fa-sm fa-fw mr-2 text-gray-400"></i>
Support
</a>
<div class="dropdown-divider"></div>
{{ if ContextCanSwitchAccount $._Ctx }}
<a class="dropdown-item" href="/user/switch-account?modal=1">
<i class="far fa-sign-in fa-sm fa-fw mr-2 text-gray-400"></i>
Switch Account
</a>
{{ end }}
{{ if ContextIsVirtualSession $._Ctx }}
<a class="dropdown-item" href="/user/virtual-logout">
<i class="far fa-sign-out fa-sm fa-fw mr-2 text-gray-400"></i>
Switch Back
</a>
{{ end }}
<a class="dropdown-item" href="/user/logout" data-toggle="modal" data-target="#logoutModal">
<i class="fas fa-sign-out-alt fa-sm fa-fw mr-2 text-gray-400"></i>
Logout

View File

@ -150,7 +150,7 @@ func selectQuery() *sqlbuilder.SelectBuilder {
// Find gets all the accounts from the database based on the request params.
// TODO: Need to figure out why can't parse the args when appending the where
// to the query.
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountFindRequest) ([]*Account, error) {
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountFindRequest) (Accounts, error) {
query := selectQuery()
if req.Where != nil {
@ -170,7 +170,7 @@ func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountF
}
// find internal method for getting all the accounts from the database using a select query.
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*Account, error) {
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (Accounts, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Find")
defer span.Finish()

View File

@ -99,6 +99,21 @@ func (m *AccountResponse) MarshalBinary() ([]byte, error) {
return json.Marshal(m)
}
// Accounts a list of Accounts.
type Accounts []*Account
// Response transforms a list of Accounts to a list of AccountResponses.
func (m *Accounts) Response(ctx context.Context) []*AccountResponse {
var l []*AccountResponse
if m != nil && len(*m) > 0 {
for _, n := range *m {
l = append(l, n.Response(ctx))
}
}
return l
}
// AccountCreateRequest contains information needed to create a new Account.
type AccountCreateRequest struct {
Name string `json:"name" validate:"required,unique" example:"Company Name"`

View File

@ -23,9 +23,11 @@ const Key ctxKey = 1
// Claims represents the authorization claims transmitted via a JWT.
type Claims struct {
AccountIds []string `json:"accounts"`
Roles []string `json:"roles"`
Preferences ClaimPreferences `json:"prefs"`
RootUserID string `json:"root_user_id"`
RootAccountID string `json:"root_account_id"`
AccountIDs []string `json:"accounts"`
Roles []string `json:"roles"`
Preferences ClaimPreferences `json:"prefs"`
jwt.StandardClaims
}
@ -41,14 +43,16 @@ type ClaimPreferences struct {
// NewClaims constructs a Claims value for the identified user. The Claims
// expire within a specified duration of the provided time. Additional fields
// of the Claims can be set after calling NewClaims is desired.
func NewClaims(userId, accountId string, accountIds []string, roles []string, prefs ClaimPreferences, now time.Time, expires time.Duration) Claims {
func NewClaims(userID, accountID string, accountIDs []string, roles []string, prefs ClaimPreferences, now time.Time, expires time.Duration) Claims {
c := Claims{
AccountIds: accountIds,
Roles: roles,
Preferences: prefs,
AccountIDs: accountIDs,
RootAccountID: accountID,
RootUserID: userID,
Roles: roles,
Preferences: prefs,
StandardClaims: jwt.StandardClaims{
Subject: userId,
Audience: accountId,
Subject: userID,
Audience: accountID,
IssuedAt: now.Unix(),
ExpiresAt: now.Add(expires).Unix(),
},

View File

@ -14,39 +14,23 @@ const KeySession ctxKeySession = 1
// Session keys used to store values.
const (
SessionKeyAccessToken = iota
//SessionKeyPreferenceDatetimeFormat
//SessionKeyPreferenceDateFormat
//SessionKeyPreferenceTimeFormat
//SessionKeyTimezone
)
func init() {
//gob.Register(&Session{})
}
// Session represents a user with authentication.
type Session struct {
*sessions.Session
}
// ContextWithSession appends a universal translator to a context.
func ContextWithSession(ctx context.Context, session *sessions.Session) context.Context {
return context.WithValue(ctx, KeySession, session)
}
// ContextSession returns the session from a context.
func ContextSession(ctx context.Context) *Session {
if s, ok := ctx.Value(KeySession).(*Session); ok {
func ContextSession(ctx context.Context) *sessions.Session {
if s, ok := ctx.Value(KeySession).(*sessions.Session); ok {
return s
}
return nil
}
func ContextAccessToken(ctx context.Context) (string, bool) {
return ContextSession(ctx).AccessToken()
}
func (sess *Session) AccessToken() (string, bool) {
sess := ContextSession(ctx)
if sess == nil {
return "", false
}
@ -56,60 +40,19 @@ func (sess *Session) AccessToken() (string, bool) {
return "", false
}
/*
func(sess *Session) PreferenceDatetimeFormat() (string, bool) {
if sess == nil {
return "", false
}
if sv, ok := sess.Values[SessionKeyPreferenceDatetimeFormat].(string); ok {
return sv, true
}
return "", false
}
func(sess *Session) PreferenceDateFormat() (string, bool) {
if sess == nil {
return "", false
}
if sv, ok := sess.Values[SessionKeyPreferenceDateFormat].(string); ok {
return sv, true
}
return "", false
}
func(sess *Session) PreferenceTimeFormat() (string, bool) {
if sess == nil {
return "", false
}
if sv, ok := sess.Values[SessionKeyPreferenceTimeFormat].(string); ok {
return sv, true
}
return "", false
}
func(sess *Session) Timezone() (*time.Location, bool) {
if sess != nil {
if sv, ok := sess.Values[SessionKeyTimezone].(*time.Location); ok {
return sv, true
}
}
return nil, false
}
*/
func SessionInit(session *Session, accessToken string) *Session {
func SessionInit(session *sessions.Session, accessToken string) *sessions.Session {
session.Values[SessionKeyAccessToken] = accessToken
//session.Values[SessionKeyPreferenceDatetimeFormat] = datetimeFormat
//session.Values[SessionKeyPreferenceDateFormat] = dateFormat
//session.Values[SessionKeyPreferenceTimeFormat] = timeFormat
//session.Values[SessionKeyTimezone] = timezone
return session
}
func SessionDestroy(session *Session) *Session {
func SessionUpdateAccessToken(session *sessions.Session, accessToken string) *sessions.Session {
session.Values[SessionKeyAccessToken] = accessToken
return session
}
func SessionDestroy(session *sessions.Session) *sessions.Session {
delete(session.Values, SessionKeyAccessToken)

View File

@ -56,6 +56,21 @@ func (m *Project) Response(ctx context.Context) *ProjectResponse {
return r
}
// Projects a list of Projects.
type Projects []*Project
// Response transforms a list of Projects to a list of ProjectResponses.
func (m *Projects) Response(ctx context.Context) []*ProjectResponse {
var l []*ProjectResponse
if m != nil && len(*m) > 0 {
for _, n := range *m {
l = append(l, n.Response(ctx))
}
}
return l
}
// ProjectCreateRequest contains information needed to create a new Project.
type ProjectCreateRequest struct {
AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`

View File

@ -124,13 +124,13 @@ func findRequestQuery(req ProjectFindRequest) (*sqlbuilder.SelectBuilder, []inte
}
// Find gets all the projects from the database based on the request params.
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectFindRequest) ([]*Project, error) {
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectFindRequest) (Projects, error) {
query, args := findRequestQuery(req)
return find(ctx, claims, dbConn, query, args, req.IncludeArchived)
}
// find internal method for getting all the projects from the database using a select query.
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*Project, error) {
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (Projects, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Find")
defer span.Finish()

View File

@ -78,6 +78,21 @@ func (m *UserResponse) MarshalBinary() ([]byte, error) {
return json.Marshal(m)
}
// Users a list of Users.
type Users []*User
// Response transforms a list of Users to a list of UserResponses.
func (m *Users) Response(ctx context.Context) []*UserResponse {
var l []*UserResponse
if m != nil && len(*m) > 0 {
for _, n := range *m {
l = append(l, n.Response(ctx))
}
}
return l
}
// UserCreateRequest contains information needed to create a new User.
type UserCreateRequest struct {
FirstName string `json:"first_name" validate:"required" example:"Gabi"`

View File

@ -202,13 +202,13 @@ func findRequestQuery(req UserFindRequest) (*sqlbuilder.SelectBuilder, []interfa
}
// Find gets all the users from the database based on the request params.
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindRequest) ([]*User, error) {
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindRequest) (Users, error) {
query, args := findRequestQuery(req)
return find(ctx, claims, dbConn, query, args, req.IncludeArchived)
}
// find internal method for getting all the users from the database using a select query.
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*User, error) {
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (Users, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Find")
defer span.Finish()

View File

@ -66,6 +66,34 @@ func (m *UserAccount) Response(ctx context.Context) *UserAccountResponse {
return r
}
// HasRole checks if the entry has a role.
func (m *UserAccount) HasRole(role UserAccountRole) bool {
if m == nil {
return false
}
for _, r := range m.Roles {
if r == role {
return true
}
}
return false
}
// UserAccounts a list of UserAccounts.
type UserAccounts []*UserAccount
// Response transforms a list of UserAccounts to a list of UserAccountResponses.
func (m *UserAccounts) Response(ctx context.Context) []*UserAccountResponse {
var l []*UserAccountResponse
if m != nil && len(*m) > 0 {
for _, n := range *m {
l = append(l, n.Response(ctx))
}
}
return l
}
// UserAccountCreateRequest defines the information is needed to associate a user to an
// account. Users are global to the application and each users access can be managed
// on an account level. If a current entry exists in the database but is archived,

View File

@ -131,13 +131,13 @@ func findRequestQuery(req UserAccountFindRequest) (*sqlbuilder.SelectBuilder, []
}
// Find gets all the user accounts from the database based on the request params.
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountFindRequest) ([]*UserAccount, error) {
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountFindRequest) (UserAccounts, error) {
query, args := findRequestQuery(req)
return find(ctx, claims, dbConn, query, args, req.IncludeArchived)
}
// Find gets all the user accounts from the database based on the select query
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*UserAccount, error) {
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (UserAccounts, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Find")
defer span.Finish()
@ -180,7 +180,7 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu
}
// Retrieve gets the specified user from the database.
func FindByUserID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string, includedArchived bool) ([]*UserAccount, error) {
func FindByUserID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string, includedArchived bool) (UserAccounts, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.FindByUserID")
defer span.Finish()

View File

@ -3,6 +3,7 @@ package user_auth
import (
"context"
"database/sql"
"geeks-accelerator/oss/saas-starter-kit/internal/user_account"
"strings"
"time"
@ -22,6 +23,9 @@ var (
// ErrAuthenticationFailure occurs when a user attempts to authenticate but
// anything goes wrong.
ErrAuthenticationFailure = errors.New("Authentication failed")
// ErrForbidden occurs when a user tries to do something that is forbidden to them according to our access control policies.
ErrForbidden = errors.New("Attempted action is not allowed")
)
const (
@ -65,22 +69,11 @@ func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, e
return generateToken(ctx, dbConn, tknGen, auth.Claims{}, u.ID, "", expires, now, scopes...)
}
// Authenticate finds a user by their email and verifies their password. On success
// it returns a Token that can be used to authenticate access to the application in
// the future.
func SwitchAccount(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, accountID string, expires time.Duration, now time.Time, scopes ...string) (Token, error) {
// SwitchAccount allows users to switch between multiple accounts, this changes the claim audience.
func SwitchAccount(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, req SwitchAccountRequest, expires time.Duration, now time.Time, scopes ...string) (Token, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.SwitchAccount")
defer span.Finish()
// Defines struct to apply validation for the supplied claims and account ID.
req := struct {
UserID string `json:"user_id" validate:"required,uuid"`
AccountID string `json:"account_id" validate:"required,uuid"`
}{
UserID: claims.Subject,
AccountID: accountID,
}
// Validate the request.
v := webcontext.Validator()
err := v.Struct(req)
@ -88,12 +81,74 @@ func SwitchAccount(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator,
return Token{}, err
}
claims.RootAccountID = req.AccountID
if claims.RootUserID == "" {
claims.RootUserID = claims.Subject
}
// Generate a token for the user ID in supplied in claims as the Subject. Pass
// in the supplied claims as well to enforce ACLs when finding the current
// list of accounts for the user.
return generateToken(ctx, dbConn, tknGen, claims, claims.Subject, req.AccountID, expires, now, scopes...)
}
// VirtualLogin allows users to mock being logged in as other users.
func VirtualLogin(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, req VirtualLoginRequest, expires time.Duration, now time.Time, scopes ...string) (Token, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.VirtualLogin")
defer span.Finish()
// Validate the request.
v := webcontext.Validator()
err := v.Struct(req)
if err != nil {
return Token{}, err
}
// Find all the accounts that the current user has access to.
usrAccs, err := user_account.FindByUserID(ctx, claims, dbConn, claims.Subject, false)
if err != nil {
return Token{}, err
}
// The user must have the role of admin to login any other user.
var hasAccountAdminRole bool
for _, usrAcc := range usrAccs {
if usrAcc.HasRole(user_account.UserAccountRole_Admin) {
if usrAcc.AccountID == req.AccountID {
hasAccountAdminRole = true
break
}
}
}
if !hasAccountAdminRole {
return Token{}, errors.WithMessagef(ErrForbidden, "User %s does not have correct access to account %s ", claims.Subject, req.AccountID)
}
if claims.RootAccountID == "" {
claims.RootAccountID = claims.Audience
}
if claims.RootUserID == "" {
claims.RootUserID = claims.Subject
}
// Generate a token for the user ID in supplied in claims as the Subject. Pass
// in the supplied claims as well to enforce ACLs when finding the current
// list of accounts for the user.
return generateToken(ctx, dbConn, tknGen, claims, req.UserID, req.AccountID, expires, now, scopes...)
}
// VirtualLogout allows switch back to their root user/account.
func VirtualLogout(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, expires time.Duration, now time.Time, scopes ...string) (Token, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.VirtualLogout")
defer span.Finish()
// Generate a token for the user ID in supplied in claims as the Subject. Pass
// in the supplied claims as well to enforce ACLs when finding the current
// list of accounts for the user.
return generateToken(ctx, dbConn, tknGen, claims, claims.RootUserID, claims.RootAccountID, expires, now, scopes...)
}
// generateToken generates claims for the supplied user ID and account ID and then
// returns the token for the generated claims used for authentication.
func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, userID, accountID string, expires time.Duration, now time.Time, scopes ...string) (Token, error) {
@ -250,7 +305,7 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator,
if scopeValid {
roles = append(roles, s)
} else {
err := errors.Errorf("invalid scope '%s'", s)
err := errors.Wrapf(ErrForbidden, "invalid scope '%s'", s)
return Token{}, err
}
}
@ -259,7 +314,7 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator,
}
if len(roles) == 0 {
err := errors.New("no roles defined for user")
err := errors.Wrapf(ErrForbidden, "no roles defined for user")
return Token{}, err
}
}
@ -314,14 +369,24 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator,
claimPref = auth.NewClaimPreferences(tz, preferenceDatetimeFormat, preferenceDateFormat, preferenceTimeFormat)
}
// Ensure the current claims has the root values set.
if (claims.RootAccountID == "" && claims.Audience != "") || (claims.RootUserID == "" && claims.Subject != "") {
claims.RootAccountID = claims.Audience
claims.RootUserID = claims.Subject
}
// JWT claims requires both an audience and a subject. For this application:
// Subject: The ID of the user authenticated.
// Audience: The ID of the account the user is accessing. A list of account IDs
// will also be included to support the user switching between them.
claims = auth.NewClaims(userID, accountID, accountIds, roles, claimPref, now, expires)
newClaims := auth.NewClaims(userID, accountID, accountIds, roles, claimPref, now, expires)
// Copy the original root account/user ID.
newClaims.RootAccountID = claims.RootAccountID
newClaims.RootUserID = claims.RootUserID
// Generate a token for the user with the defined claims.
tknStr, err := tknGen.GenerateToken(claims)
tknStr, err := tknGen.GenerateToken(newClaims)
if err != nil {
return Token{}, errors.Wrap(err, "generating token")
}
@ -329,9 +394,9 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator,
tkn := Token{
AccessToken: tknStr,
TokenType: "Bearer",
claims: claims,
UserID: claims.Subject,
AccountID: claims.Audience,
claims: newClaims,
UserID: newClaims.Subject,
AccountID: newClaims.Audience,
}
if expires.Seconds() > 0 {

View File

@ -2,6 +2,7 @@ package user_auth
import (
"encoding/json"
"fmt"
"os"
"testing"
"time"
@ -63,6 +64,9 @@ func TestAuthenticate(t *testing.T) {
}
t.Logf("\t%s\tCreate user account ok.", tests.Success)
// Add 30 minutes to now to simulate time passing.
now = now.Add(time.Minute * 30)
acc2, err := account.MockAccount(ctx, test.MasterDB, now)
if err != nil {
t.Log("\t\tGot :", err)
@ -81,7 +85,7 @@ func TestAuthenticate(t *testing.T) {
}, now)
// Add 30 minutes to now to simulate time passing.
now = now.Add(time.Minute * 30)
now = now.Add(time.Minute * 5)
// Try to authenticate valid user with invalid password.
_, err = Authenticate(ctx, test.MasterDB, tknGen, usrAcc.User.Email, "xy7", time.Hour, now)
@ -106,17 +110,20 @@ func TestAuthenticate(t *testing.T) {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
}
expectClaims := tkn1.claims
expectClaims.RootUserID = ""
expectClaims.RootAccountID = ""
expectClaims.Subject = usrAcc.UserID
expectClaims.Audience = usrAcc.AccountID
// Hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229
resClaims, _ := json.Marshal(claims1)
expectClaims, _ := json.Marshal(tkn1.claims)
if diff := cmp.Diff(string(resClaims), string(expectClaims)); diff != "" {
if diff := cmpClaims(claims1, expectClaims); diff != "" {
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
}
t.Logf("\t%s\tAuthenticate parse claims from token ok.", tests.Success)
// Try switching to a second account using the first set of claims.
tkn2, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims1, acc2.ID, time.Hour, now)
tkn2, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims1,
SwitchAccountRequest{AccountID: acc2.ID}, time.Hour, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tSwitchAccount user failed.", tests.Failed)
@ -129,11 +136,13 @@ func TestAuthenticate(t *testing.T) {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
}
expectClaims = tkn2.claims
expectClaims.RootUserID = usrAcc.UserID
expectClaims.RootAccountID = acc2.ID
expectClaims.Subject = usrAcc.UserID
expectClaims.Audience = acc2.ID
// Hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229
resClaims, _ = json.Marshal(claims2)
expectClaims, _ = json.Marshal(tkn2.claims)
if diff := cmp.Diff(string(resClaims), string(expectClaims)); diff != "" {
if diff := cmpClaims(claims2, expectClaims); diff != "" {
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
}
t.Logf("\t%s\tSwitchAccount parse claims from token ok.", tests.Success)
@ -256,3 +265,724 @@ func TestUserResetPassword(t *testing.T) {
t.Logf("\t%s\tAuthenticate ok.", tests.Success)
}
}
// TestSwitchAccount validates the behavior around allowing users to switch between their accounts.
func TestSwitchAccount(t *testing.T) {
defer tests.Recover(t)
// Auth tokens are valid for an our and is verified against current time.
// Issue the token one hour ago.
now := time.Now().Add(time.Hour * -1)
ctx := tests.Context()
type authTest struct {
name string
root *user_account.MockUserAccountResponse
switch1Req SwitchAccountRequest
switch1Roles []user_account.UserAccountRole
switch1Scopes []string
switch1Err error
switch2Req SwitchAccountRequest
switch2Roles []user_account.UserAccountRole
switch2Scopes []string
switch2Err error
}
var authTests []authTest
// Test all the combinations there the user has access to all three accounts.
if true {
for _, roles := range [][]user_account.UserAccountRole{
[]user_account.UserAccountRole{user_account.UserAccountRole_Admin, user_account.UserAccountRole_Admin, user_account.UserAccountRole_Admin},
[]user_account.UserAccountRole{user_account.UserAccountRole_User, user_account.UserAccountRole_User, user_account.UserAccountRole_User},
[]user_account.UserAccountRole{user_account.UserAccountRole_Admin, user_account.UserAccountRole_User, user_account.UserAccountRole_Admin},
[]user_account.UserAccountRole{user_account.UserAccountRole_User, user_account.UserAccountRole_Admin, user_account.UserAccountRole_User},
} {
// Create a new user for testing.
usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, roles[0])
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
}
// Create the second account.
now = now.Add(time.Minute)
acc2, err := account.MockAccount(ctx, test.MasterDB, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
}
// Associate the second account with root user.
usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
UserID: usrAcc.UserID,
AccountID: acc2.ID,
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(roles[1])},
}, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tLinking second account to user failed.", tests.Failed)
}
// Create the third account.
now = now.Add(time.Minute)
acc3, err := account.MockAccount(ctx, test.MasterDB, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate third account failed.", tests.Failed)
}
// Associate the third account with root user.
usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
UserID: usrAcc.UserID,
AccountID: acc3.ID,
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(roles[2])},
}, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tLinking third account to user failed.", tests.Failed)
}
authTests = append(authTests, authTest{
name: fmt.Sprintf("Root account role %s -> role %s account 2 -> role %s account 3.",
roles[0], roles[1], roles[2]),
root: usrAcc,
switch1Req: SwitchAccountRequest{AccountID: acc2.ID},
switch1Roles: usrAcc2.Roles,
switch1Err: nil,
switch2Req: SwitchAccountRequest{AccountID: acc3.ID},
switch2Err: nil,
switch2Roles: usrAcc3.Roles,
})
}
}
// Root account 1 -> invalid account 2
if true {
// Create a new user for testing.
usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_Admin)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
}
// Create the second account and don't associate it with the root user.
now = now.Add(time.Minute)
acc2, err := account.MockAccount(ctx, test.MasterDB, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
}
authTests = append(authTests, authTest{
name: "Root account 1 -> invalid account 2.",
root: usrAcc,
switch1Req: SwitchAccountRequest{AccountID: acc2.ID},
switch1Err: ErrAuthenticationFailure,
})
}
// Root account 1 -> valid account 2 with scopes -> valid account 3 with invalid scope.
if true {
// Create a new user for testing.
usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_Admin)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
}
// Create the second account.
now = now.Add(time.Minute)
acc2, err := account.MockAccount(ctx, test.MasterDB, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
}
// Associate the second account with root user.
usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
UserID: usrAcc.UserID,
AccountID: acc2.ID,
Roles: []user_account.UserAccountRole{user_account.UserAccountRole_Admin},
}, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tLinking second account to user failed.", tests.Failed)
}
// Create the third account.
now = now.Add(time.Minute)
acc3, err := account.MockAccount(ctx, test.MasterDB, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate third account failed.", tests.Failed)
}
// Associate the third account with root user.
usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
UserID: usrAcc.UserID,
AccountID: acc3.ID,
Roles: []user_account.UserAccountRole{user_account.UserAccountRole_User},
}, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tLinking third account to user failed.", tests.Failed)
}
authTests = append(authTests, authTest{
name: "Root account 1 -> valid account 2 with scopes -> valid account 3 with invalid scope.",
root: usrAcc,
switch1Req: SwitchAccountRequest{AccountID: acc2.ID},
switch1Roles: usrAcc2.Roles,
switch1Scopes: []string{user_account.UserAccountRole_User.String()},
switch1Err: nil,
switch2Req: SwitchAccountRequest{AccountID: acc3.ID},
switch2Roles: usrAcc3.Roles,
switch2Scopes: []string{user_account.UserAccountRole_Admin.String()},
switch2Err: ErrForbidden,
})
}
// Add 30 minutes to now to simulate time passing.
now = now.Add(time.Minute * 5)
tknGen := &auth.MockTokenGenerator{}
t.Log("Given the need to switch accounts.")
{
for i, authTest := range authTests {
t.Logf("\tTest: %d\tWhen running test: %s", i, authTest.name)
{
// Verify that the user can be authenticated with the created user.
var claims1 auth.Claims
tkn1, err := Authenticate(ctx, test.MasterDB, tknGen, authTest.root.User.Email, authTest.root.User.Password, time.Hour, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tAuthenticate user failed.", tests.Failed)
} else {
// Ensure the token string was correctly generated.
claims1, err = tknGen.ParseClaims(tkn1.AccessToken)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
}
expectClaims := tkn1.claims
expectClaims.RootUserID = ""
expectClaims.RootAccountID = ""
expectClaims.Subject = authTest.root.UserID
expectClaims.Audience = authTest.root.AccountID
expectClaims.Roles = rolesStringSlice(authTest.root.Roles)
if diff := cmpClaims(claims1, expectClaims); diff != "" {
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
}
}
t.Logf("\t%s\tAuthenticate root user with role %v ok.", tests.Success, authTest.root.Roles)
// Try to switch to account 2.
var claims2 auth.Claims
tkn2, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims1, authTest.switch1Req, time.Hour, now, authTest.switch1Scopes...)
if err != authTest.switch1Err {
if errors.Cause(err) != authTest.switch1Err {
t.Log("\t\tExpected :", authTest.switch1Err)
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tSwitchAccount account 1 with role %v failed.", tests.Failed, authTest.switch1Roles)
}
} else {
// Ensure the token string was correctly generated.
claims2, err = tknGen.ParseClaims(tkn2.AccessToken)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
}
expectClaims := tkn2.claims
expectClaims.RootUserID = authTest.root.UserID
expectClaims.RootAccountID = authTest.switch1Req.AccountID
expectClaims.Subject = authTest.root.UserID
expectClaims.Audience = authTest.switch1Req.AccountID
if len(authTest.switch1Scopes) > 0 {
expectClaims.Roles = authTest.switch1Scopes
} else {
expectClaims.Roles = rolesStringSlice(authTest.switch1Roles)
}
if diff := cmpClaims(claims2, expectClaims); diff != "" {
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
}
}
t.Logf("\t%s\tSwitchAccount account 1 with role %v ok.", tests.Success, authTest.switch1Roles)
// If the user can't login, don't need to test any further.
if authTest.switch1Err != nil || authTest.switch2Req.AccountID == "" {
continue
}
// Try to switch to account 3.
tkn3, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims2, authTest.switch2Req, time.Hour, now, authTest.switch2Scopes...)
if err != authTest.switch2Err {
if errors.Cause(err) != authTest.switch2Err {
t.Log("\t\tExpected :", authTest.switch2Err)
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tSwitchAccount account 2 with role %v failed.", tests.Failed, authTest.switch2Roles)
}
} else {
// Ensure the token string was correctly generated.
claims3, err := tknGen.ParseClaims(tkn3.AccessToken)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
}
expectClaims := tkn3.claims
expectClaims.RootUserID = authTest.root.UserID
expectClaims.RootAccountID = authTest.switch2Req.AccountID
expectClaims.Subject = authTest.root.UserID
expectClaims.Audience = authTest.switch2Req.AccountID
if len(authTest.switch2Scopes) > 0 {
expectClaims.Roles = authTest.switch2Scopes
} else {
expectClaims.Roles = rolesStringSlice(authTest.switch2Roles)
}
if diff := cmpClaims(claims3, expectClaims); diff != "" {
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
}
}
t.Logf("\t%s\tSwitchAccount account 2 with role %v ok.", tests.Success, authTest.switch2Roles)
}
}
}
}
// TestVirtualLogin validates the behavior around allowing users to virtual login users.
func TestVirtualLogin(t *testing.T) {
defer tests.Recover(t)
// Auth tokens are valid for an our and is verified against current time.
// Issue the token one hour ago.
now := time.Now().Add(time.Hour * -1)
ctx := tests.Context()
type authTest struct {
name string
root *user_account.MockUserAccountResponse
login1Req VirtualLoginRequest
login1Err error
login1Role user_account.UserAccountRole
login2Req VirtualLoginRequest
login2Err error
login2Role user_account.UserAccountRole
login2Logout bool
}
var authTests []authTest
// Root admin -> role admin -> role admin
if true {
// Create a new user for testing.
usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_Admin)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
}
usr2, err := user.MockUser(ctx, test.MasterDB, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
}
// Associate second user with basic role associated with the same account.
usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
UserID: usr2.ID,
AccountID: usrAcc.AccountID,
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)},
}, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tLinking second user to account failed.", tests.Failed)
}
usr3, err := user.MockUser(ctx, test.MasterDB, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
}
// Associate second user with basic role associated with the same account.
usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
UserID: usr3.ID,
AccountID: usrAcc.AccountID,
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)},
}, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tLinking third user to account failed.", tests.Failed)
}
authTests = append(authTests, authTest{
name: "Root admin -> role admin -> role admin",
root: usrAcc,
login1Req: VirtualLoginRequest{
UserID: usr2.ID,
AccountID: usrAcc.AccountID,
},
login1Role: usrAcc2.Roles[0],
login1Err: nil,
login2Req: VirtualLoginRequest{
UserID: usr3.ID,
AccountID: usrAcc.AccountID,
},
login2Err: nil,
login2Role: usrAcc3.Roles[0],
login2Logout: true,
})
}
// Root admin -> role admin -> role user
if true {
// Create a new user for testing.
usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_Admin)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
}
usr2, err := user.MockUser(ctx, test.MasterDB, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
}
// Associate second user with basic role associated with the same account.
usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
UserID: usr2.ID,
AccountID: usrAcc.AccountID,
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)},
}, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tLinking second user to account failed.", tests.Failed)
}
usr3, err := user.MockUser(ctx, test.MasterDB, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
}
// Associate second user with basic role associated with the same account.
usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
UserID: usr3.ID,
AccountID: usrAcc.AccountID,
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_User)},
}, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tLinking third user to account failed.", tests.Failed)
}
authTests = append(authTests, authTest{
name: "Root admin -> role admin -> role user",
root: usrAcc,
login1Req: VirtualLoginRequest{
UserID: usr2.ID,
AccountID: usrAcc.AccountID,
},
login1Err: nil,
login1Role: usrAcc2.Roles[0],
login2Req: VirtualLoginRequest{
UserID: usr3.ID,
AccountID: usrAcc.AccountID,
},
login2Err: nil,
login2Role: usrAcc3.Roles[0],
login2Logout: true,
})
}
// Root admin -> role user -> role admin
if true {
// Create a new user for testing.
usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_Admin)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
}
usr2, err := user.MockUser(ctx, test.MasterDB, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
}
// Associate second user with basic role associated with the same account.
usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
UserID: usr2.ID,
AccountID: usrAcc.AccountID,
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_User)},
}, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tLinking second user to account failed.", tests.Failed)
}
usr3, err := user.MockUser(ctx, test.MasterDB, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
}
// Associate second user with basic role associated with the same account.
usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
UserID: usr3.ID,
AccountID: usrAcc.AccountID,
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)},
}, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tLinking third user to account failed.", tests.Failed)
}
authTests = append(authTests, authTest{
name: "Root admin -> role user -> role admin",
root: usrAcc,
login1Req: VirtualLoginRequest{
UserID: usr2.ID,
AccountID: usrAcc.AccountID,
},
login1Err: nil,
login1Role: usrAcc2.Roles[0],
login2Req: VirtualLoginRequest{
UserID: usr3.ID,
AccountID: usrAcc.AccountID,
},
login2Err: ErrForbidden,
login2Role: usrAcc3.Roles[0],
login2Logout: true,
})
}
// Root user -> role admin
if true {
// Create a new user for testing.
usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_User)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
}
usr2, err := user.MockUser(ctx, test.MasterDB, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
}
// Associate second user with basic role associated with the same account.
usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
UserID: usr2.ID,
AccountID: usrAcc.AccountID,
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)},
}, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tLinking second user to account failed.", tests.Failed)
}
authTests = append(authTests, authTest{
name: "Root user -> role admin",
root: usrAcc,
login1Req: VirtualLoginRequest{
UserID: usr2.ID,
AccountID: usrAcc.AccountID,
},
login1Err: ErrForbidden,
login1Role: usrAcc2.Roles[0],
login2Logout: true,
})
}
// Root user -> role user
if true {
// Create a new user for testing.
usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_User)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
}
usr2, err := user.MockUser(ctx, test.MasterDB, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
}
// Associate second user with basic role associated with the same account.
usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
UserID: usr2.ID,
AccountID: usrAcc.AccountID,
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_User)},
}, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tLinking second user to account failed.", tests.Failed)
}
authTests = append(authTests, authTest{
name: "Root user -> role admin",
root: usrAcc,
login1Req: VirtualLoginRequest{
UserID: usr2.ID,
AccountID: usrAcc.AccountID,
},
login1Err: ErrForbidden,
login1Role: usrAcc2.Roles[0],
login2Logout: true,
})
}
// Add 30 minutes to now to simulate time passing.
now = now.Add(time.Minute * 5)
tknGen := &auth.MockTokenGenerator{}
t.Log("Given the need to virtual login.")
{
for i, authTest := range authTests {
t.Logf("\tTest: %d\tWhen running test: %s", i, authTest.name)
{
// Verify that the user can be authenticated with the created user.
var claims1 auth.Claims
tkn1, err := Authenticate(ctx, test.MasterDB, tknGen, authTest.root.User.Email, authTest.root.User.Password, time.Hour, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tAuthenticate user failed.", tests.Failed)
} else {
// Ensure the token string was correctly generated.
claims1, err = tknGen.ParseClaims(tkn1.AccessToken)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
}
expectClaims := tkn1.claims
expectClaims.RootUserID = ""
expectClaims.RootAccountID = ""
expectClaims.Subject = authTest.root.UserID
expectClaims.Audience = authTest.root.AccountID
// Hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229
if diff := cmpClaims(claims1, expectClaims); diff != "" {
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
}
}
t.Logf("\t%s\tAuthenticate root user with role %s ok.", tests.Success, authTest.root.Roles[0])
// Try virtual login to user 2.
var claims2 auth.Claims
tkn2, err := VirtualLogin(ctx, test.MasterDB, tknGen, claims1, authTest.login1Req, time.Hour, now)
if err != authTest.login1Err {
if errors.Cause(err) != authTest.login1Err {
t.Log("\t\tExpected :", authTest.login1Err)
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tVirtualLogin user 1 with role %s failed.", tests.Failed, authTest.login1Role)
}
} else {
// Ensure the token string was correctly generated.
claims2, err = tknGen.ParseClaims(tkn2.AccessToken)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
}
expectClaims := tkn2.claims
expectClaims.RootUserID = authTest.root.UserID
expectClaims.RootAccountID = authTest.root.AccountID
expectClaims.Subject = authTest.login1Req.UserID
expectClaims.Audience = authTest.login1Req.AccountID
// Hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229
if diff := cmpClaims(claims2, expectClaims); diff != "" {
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
}
}
t.Logf("\t%s\tVirtualLogin user 1 with role %s ok.", tests.Success, authTest.login1Role)
// If the user can't login, don't need to test any further.
if authTest.login1Err != nil {
continue
}
// Try virtual login to user 3.
tkn3, err := VirtualLogin(ctx, test.MasterDB, tknGen, claims2, authTest.login2Req, time.Hour, now)
if err != authTest.login2Err {
if errors.Cause(err) != authTest.login2Err {
t.Log("\t\tExpected :", authTest.login2Err)
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tVirtualLogin user 2 with role %s failed.", tests.Failed, authTest.login2Role)
}
} else {
// Ensure the token string was correctly generated.
claims3, err := tknGen.ParseClaims(tkn3.AccessToken)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
}
expectClaims := tkn3.claims
expectClaims.RootUserID = authTest.root.UserID
expectClaims.RootAccountID = authTest.root.AccountID
expectClaims.Subject = authTest.login2Req.UserID
expectClaims.Audience = authTest.login2Req.AccountID
// Hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229
if diff := cmpClaims(claims3, expectClaims); diff != "" {
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
}
}
t.Logf("\t%s\tVirtualLogin user 2 with role %s ok.", tests.Success, authTest.login2Role)
if authTest.login2Logout {
tknOut, err := VirtualLogout(ctx, test.MasterDB, tknGen, claims2, time.Hour, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tVirtualLogout user 2 failed.", tests.Failed)
}
// Ensure the token string was correctly generated.
claimsOut, err := tknGen.ParseClaims(tknOut.AccessToken)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
}
expectClaims := tknOut.claims
expectClaims.RootUserID = authTest.root.UserID
expectClaims.RootAccountID = authTest.root.AccountID
expectClaims.Subject = authTest.root.UserID
expectClaims.Audience = authTest.root.AccountID
if diff := cmpClaims(claimsOut, expectClaims); diff != "" {
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
}
t.Logf("\t%s\tVirtualLogout user 2 with role %s ok.", tests.Success, authTest.login2Role)
}
}
}
}
}
// rolesStringSlice converts a list of roles to a string slice.
func rolesStringSlice(roles []user_account.UserAccountRole) []string {
var l []string
for _, r := range roles {
l = append(l, string(r))
}
return l
}
// cmpClaims is a hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229
func cmpClaims(actualClaims, expectedclaims auth.Claims) string {
dat1, _ := json.Marshal(actualClaims)
dat2, _ := json.Marshal(expectedclaims)
return cmp.Diff(string(dat1), string(dat2))
}

View File

@ -35,6 +35,17 @@ type Token struct {
AccountID string `json:"account_id"example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
}
// SwitchAccountRequest defines the information for the current user to switch between their accounts
type SwitchAccountRequest struct {
AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
}
// VirtualLoginRequest defines the information virtual login to a user / account.
type VirtualLoginRequest struct {
UserID string `json:"user_id" validate:"required,uuid" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"`
AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
}
// AuthorizationHeader returns the header authorization value.
func (t Token) AuthorizationHeader() string {
return "Bearer " + t.AccessToken