You've already forked golang-saas-starter-kit
mirror of
https://github.com/raseels-repos/golang-saas-starter-kit.git
synced 2025-06-15 00:15:15 +02:00
Completed virtual user login and switch accounts
This commit is contained in:
@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"geeks-accelerator/oss/saas-starter-kit/internal/user_auth"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
@ -17,6 +16,7 @@ import (
|
|||||||
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror"
|
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror"
|
||||||
"geeks-accelerator/oss/saas-starter-kit/internal/user"
|
"geeks-accelerator/oss/saas-starter-kit/internal/user"
|
||||||
"geeks-accelerator/oss/saas-starter-kit/internal/user_account"
|
"geeks-accelerator/oss/saas-starter-kit/internal/user_account"
|
||||||
|
"geeks-accelerator/oss/saas-starter-kit/internal/user_auth"
|
||||||
"github.com/pborman/uuid"
|
"github.com/pborman/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -19,8 +19,9 @@ import (
|
|||||||
|
|
||||||
// Account represents the Account API method handler set.
|
// Account represents the Account API method handler set.
|
||||||
type Account struct {
|
type Account struct {
|
||||||
MasterDB *sqlx.DB
|
MasterDB *sqlx.DB
|
||||||
Renderer web.Renderer
|
Renderer web.Renderer
|
||||||
|
Authenticator *auth.Authenticator
|
||||||
}
|
}
|
||||||
|
|
||||||
// View handles displaying the current account profile.
|
// View handles displaying the current account profile.
|
||||||
@ -34,7 +35,7 @@ func (h *Account) View(ctx context.Context, w http.ResponseWriter, r *http.Reque
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
acc, err := account.Read(ctx, claims, h.MasterDB, claims.Audience, false)
|
acc, err := account.ReadByID(ctx, claims, h.MasterDB, claims.Audience)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -127,7 +128,11 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sess := webcontext.ContextSession(ctx)
|
var updateClaims bool
|
||||||
|
if req.Timezone != nil && claims.Preferences.Timezone != *req.Timezone {
|
||||||
|
claims.Preferences.Timezone = *req.Timezone
|
||||||
|
updateClaims = true
|
||||||
|
}
|
||||||
|
|
||||||
if preferenceDatetimeFormat != req.PreferenceDatetimeFormat {
|
if preferenceDatetimeFormat != req.PreferenceDatetimeFormat {
|
||||||
err = account_preference.Set(ctx, claims, h.MasterDB, account_preference.AccountPreferenceSetRequest{
|
err = account_preference.Set(ctx, claims, h.MasterDB, account_preference.AccountPreferenceSetRequest{
|
||||||
@ -144,7 +149,10 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sess.Values[webcontext.SessionKeyPreferenceDatetimeFormat] = req.PreferenceDatetimeFormat
|
if claims.Preferences.DatetimeFormat != req.PreferenceDatetimeFormat {
|
||||||
|
claims.Preferences.DatetimeFormat = req.PreferenceDatetimeFormat
|
||||||
|
updateClaims = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if preferenceDateFormat != req.PreferenceDateFormat {
|
if preferenceDateFormat != req.PreferenceDateFormat {
|
||||||
@ -162,7 +170,10 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sess.Values[webcontext.SessionKeyPreferenceDateFormat] = req.PreferenceDateFormat
|
if claims.Preferences.DateFormat != req.PreferenceDateFormat {
|
||||||
|
claims.Preferences.DateFormat = req.PreferenceDateFormat
|
||||||
|
updateClaims = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if preferenceTimeFormat != req.PreferenceTimeFormat {
|
if preferenceTimeFormat != req.PreferenceTimeFormat {
|
||||||
@ -180,7 +191,18 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sess.Values[webcontext.SessionKeyPreferenceTimeFormat] = req.PreferenceTimeFormat
|
if claims.Preferences.TimeFormat != req.PreferenceTimeFormat {
|
||||||
|
claims.Preferences.TimeFormat = req.PreferenceTimeFormat
|
||||||
|
updateClaims = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the access token to include the updated claims.
|
||||||
|
if updateClaims {
|
||||||
|
ctx, err = updateContextClaims(ctx, h.Authenticator, claims)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Display a success message to the user.
|
// Display a success message to the user.
|
||||||
@ -196,7 +218,7 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
acc, err := account.Read(ctx, claims, h.MasterDB, claims.Audience, false)
|
acc, err := account.ReadByID(ctx, claims, h.MasterDB, claims.Audience)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
@ -66,13 +66,21 @@ func APP(shutdown chan os.Signal, log *log.Logger, env webcontext.Env, staticDir
|
|||||||
app.Handle("POST", "/user/update", u.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
|
app.Handle("POST", "/user/update", u.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
|
||||||
app.Handle("GET", "/user/update", u.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
|
app.Handle("GET", "/user/update", u.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
|
||||||
app.Handle("GET", "/user/account", u.Account, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
|
app.Handle("GET", "/user/account", u.Account, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
|
||||||
|
app.Handle("GET", "/user/virtual-login/:user_id", u.VirtualLogin, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin))
|
||||||
|
app.Handle("POST", "/user/virtual-login", u.VirtualLogin, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin))
|
||||||
|
app.Handle("GET", "/user/virtual-login", u.VirtualLogin, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin))
|
||||||
|
app.Handle("GET", "/user/virtual-logout", u.VirtualLogout, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
|
||||||
|
app.Handle("GET", "/user/switch-account/:account_id", u.SwitchAccount, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
|
||||||
|
app.Handle("POST", "/user/switch-account", u.SwitchAccount, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
|
||||||
|
app.Handle("GET", "/user/switch-account", u.SwitchAccount, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
|
||||||
app.Handle("POST", "/user", u.View, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
|
app.Handle("POST", "/user", u.View, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
|
||||||
app.Handle("GET", "/user", u.View, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
|
app.Handle("GET", "/user", u.View, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
|
||||||
|
|
||||||
// Register account management endpoints.
|
// Register account management endpoints.
|
||||||
acc := Account{
|
acc := Account{
|
||||||
MasterDB: masterDB,
|
MasterDB: masterDB,
|
||||||
Renderer: renderer,
|
Renderer: renderer,
|
||||||
|
Authenticator: authenticator,
|
||||||
}
|
}
|
||||||
app.Handle("POST", "/account/update", acc.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin))
|
app.Handle("POST", "/account/update", acc.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin))
|
||||||
app.Handle("GET", "/account/update", acc.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin))
|
app.Handle("GET", "/account/update", acc.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin))
|
||||||
|
@ -12,7 +12,7 @@ import (
|
|||||||
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
|
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
|
||||||
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror"
|
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror"
|
||||||
"geeks-accelerator/oss/saas-starter-kit/internal/signup"
|
"geeks-accelerator/oss/saas-starter-kit/internal/signup"
|
||||||
"geeks-accelerator/oss/saas-starter-kit/internal/user"
|
"geeks-accelerator/oss/saas-starter-kit/internal/user_auth"
|
||||||
"github.com/gorilla/schema"
|
"github.com/gorilla/schema"
|
||||||
"github.com/jmoiron/sqlx"
|
"github.com/jmoiron/sqlx"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
@ -68,7 +68,7 @@ func (h *Signup) Step1(ctx context.Context, w http.ResponseWriter, r *http.Reque
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Authenticated the new user.
|
// Authenticated the new user.
|
||||||
token, err := user.Authenticate(ctx, h.MasterDB, h.Authenticator, req.User.Email, req.User.Password, time.Hour, ctxValues.Now)
|
token, err := user_auth.Authenticate(ctx, h.MasterDB, h.Authenticator, req.User.Email, req.User.Password, time.Hour, ctxValues.Now)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -93,13 +93,6 @@ func (h *Signup) Step1(ctx context.Context, w http.ResponseWriter, r *http.Reque
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
data["geonameCountries"] = geonames.ValidGeonameCountries
|
|
||||||
|
|
||||||
data["countries"], err = geonames.FindCountries(ctx, h.MasterDB, "name", "")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -107,6 +100,13 @@ func (h *Signup) Step1(ctx context.Context, w http.ResponseWriter, r *http.Reque
|
|||||||
return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
|
return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
data["geonameCountries"] = geonames.ValidGeonameCountries
|
||||||
|
|
||||||
|
data["countries"], err = geonames.FindCountries(ctx, h.MasterDB, "name", "")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
data["form"] = req
|
data["form"] = req
|
||||||
|
|
||||||
if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(signup.SignupRequest{})); ok {
|
if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(signup.SignupRequest{})); ok {
|
||||||
|
@ -4,6 +4,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"geeks-accelerator/oss/saas-starter-kit/internal/account"
|
"geeks-accelerator/oss/saas-starter-kit/internal/account"
|
||||||
@ -16,6 +18,7 @@ import (
|
|||||||
project_routes "geeks-accelerator/oss/saas-starter-kit/internal/project-routes"
|
project_routes "geeks-accelerator/oss/saas-starter-kit/internal/project-routes"
|
||||||
"geeks-accelerator/oss/saas-starter-kit/internal/user"
|
"geeks-accelerator/oss/saas-starter-kit/internal/user"
|
||||||
"geeks-accelerator/oss/saas-starter-kit/internal/user_account"
|
"geeks-accelerator/oss/saas-starter-kit/internal/user_account"
|
||||||
|
"geeks-accelerator/oss/saas-starter-kit/internal/user_auth"
|
||||||
"github.com/gorilla/schema"
|
"github.com/gorilla/schema"
|
||||||
"github.com/gorilla/sessions"
|
"github.com/gorilla/sessions"
|
||||||
"github.com/jmoiron/sqlx"
|
"github.com/jmoiron/sqlx"
|
||||||
@ -34,7 +37,7 @@ type User struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type UserLoginRequest struct {
|
type UserLoginRequest struct {
|
||||||
user.AuthenticateRequest
|
user_auth.AuthenticateRequest
|
||||||
RememberMe bool
|
RememberMe bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -68,7 +71,7 @@ func (h *User) Login(ctx context.Context, w http.ResponseWriter, r *http.Request
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Authenticated the user.
|
// Authenticated the user.
|
||||||
token, err := user.Authenticate(ctx, h.MasterDB, h.Authenticator, req.Email, req.Password, sessionTTL, ctxValues.Now)
|
token, err := user_auth.Authenticate(ctx, h.MasterDB, h.Authenticator, req.Email, req.Password, sessionTTL, ctxValues.Now)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch errors.Cause(err) {
|
switch errors.Cause(err) {
|
||||||
case user.ErrForbidden:
|
case user.ErrForbidden:
|
||||||
@ -89,8 +92,16 @@ func (h *User) Login(ctx context.Context, w http.ResponseWriter, r *http.Request
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
redirectUri := "/"
|
||||||
|
if qv := r.URL.Query().Get("redirect"); qv != "" {
|
||||||
|
redirectUri, err = url.QueryUnescape(qv)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Redirect the user to the dashboard.
|
// Redirect the user to the dashboard.
|
||||||
http.Redirect(w, r, "/", http.StatusFound)
|
http.Redirect(w, r, redirectUri, http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -110,7 +121,7 @@ func (h *User) Login(ctx context.Context, w http.ResponseWriter, r *http.Request
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handleSessionToken persists the access token to the session for request authentication.
|
// handleSessionToken persists the access token to the session for request authentication.
|
||||||
func handleSessionToken(ctx context.Context, db *sqlx.DB, w http.ResponseWriter, r *http.Request, token user.Token) error {
|
func handleSessionToken(ctx context.Context, db *sqlx.DB, w http.ResponseWriter, r *http.Request, token user_auth.Token) error {
|
||||||
if token.AccessToken == "" {
|
if token.AccessToken == "" {
|
||||||
return errors.New("accessToken is required.")
|
return errors.New("accessToken is required.")
|
||||||
}
|
}
|
||||||
@ -252,7 +263,7 @@ func (h *User) ResetConfirm(ctx context.Context, w http.ResponseWriter, r *http.
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Authenticated the user. Probably should use the default session TTL from UserLogin.
|
// Authenticated the user. Probably should use the default session TTL from UserLogin.
|
||||||
token, err := user.Authenticate(ctx, h.MasterDB, h.Authenticator, u.Email, req.Password, time.Hour, ctxValues.Now)
|
token, err := user_auth.Authenticate(ctx, h.MasterDB, h.Authenticator, u.Email, req.Password, time.Hour, ctxValues.Now)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch errors.Cause(err) {
|
switch errors.Cause(err) {
|
||||||
case account.ErrForbidden:
|
case account.ErrForbidden:
|
||||||
@ -306,7 +317,7 @@ func (h *User) View(ctx context.Context, w http.ResponseWriter, r *http.Request,
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
usr, err := user.Read(ctx, claims, h.MasterDB, claims.Subject, false)
|
usr, err := user.ReadByID(ctx, claims, h.MasterDB, claims.Subject)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -343,16 +354,15 @@ func (h *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Reques
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
claims, err := auth.ClaimsFromContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
req := new(user.UserUpdateRequest)
|
req := new(user.UserUpdateRequest)
|
||||||
data := make(map[string]interface{})
|
data := make(map[string]interface{})
|
||||||
f := func() (bool, error) {
|
f := func() (bool, error) {
|
||||||
|
|
||||||
claims, err := auth.ClaimsFromContext(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.Method == http.MethodPost {
|
if r.Method == http.MethodPost {
|
||||||
err := r.ParseForm()
|
err := r.ParseForm()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -415,25 +425,6 @@ func (h *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Reques
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
usr, err := user.Read(ctx, claims, h.MasterDB, claims.Subject, false)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.ID == "" {
|
|
||||||
req.FirstName = &usr.FirstName
|
|
||||||
req.LastName = &usr.LastName
|
|
||||||
req.Email = &usr.Email
|
|
||||||
req.Timezone = &usr.Timezone
|
|
||||||
}
|
|
||||||
|
|
||||||
data["user"] = usr.Response(ctx)
|
|
||||||
|
|
||||||
data["timezones"], err = geonames.ListTimezones(ctx, h.MasterDB)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -444,6 +435,25 @@ func (h *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Reques
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
usr, err := user.ReadByID(ctx, claims, h.MasterDB, claims.Subject)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.ID == "" {
|
||||||
|
req.FirstName = &usr.FirstName
|
||||||
|
req.LastName = &usr.LastName
|
||||||
|
req.Email = &usr.Email
|
||||||
|
req.Timezone = &usr.Timezone
|
||||||
|
}
|
||||||
|
|
||||||
|
data["user"] = usr.Response(ctx)
|
||||||
|
|
||||||
|
data["timezones"], err = geonames.ListTimezones(ctx, h.MasterDB)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
data["form"] = req
|
data["form"] = req
|
||||||
|
|
||||||
if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(user.UserUpdateRequest{})); ok {
|
if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(user.UserUpdateRequest{})); ok {
|
||||||
@ -468,7 +478,7 @@ func (h *User) Account(ctx context.Context, w http.ResponseWriter, r *http.Reque
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
acc, err := account.Read(ctx, claims, h.MasterDB, claims.Audience, false)
|
acc, err := account.ReadByID(ctx, claims, h.MasterDB, claims.Audience)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -483,3 +493,352 @@ func (h *User) Account(ctx context.Context, w http.ResponseWriter, r *http.Reque
|
|||||||
|
|
||||||
return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "user-account.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data)
|
return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "user-account.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// VirtualLogin handles switching the scope of the context to another user.
|
||||||
|
func (h *User) VirtualLogin(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||||
|
|
||||||
|
ctxValues, err := webcontext.ContextValues(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, err := auth.ClaimsFromContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
req := new(user_auth.VirtualLoginRequest)
|
||||||
|
data := make(map[string]interface{})
|
||||||
|
f := func() (bool, error) {
|
||||||
|
if r.Method == http.MethodPost {
|
||||||
|
err := r.ParseForm()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
decoder := schema.NewDecoder()
|
||||||
|
decoder.IgnoreUnknownKeys(true)
|
||||||
|
|
||||||
|
if err := decoder.Decode(req, r.PostForm); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if pv, ok := params["user_id"]; ok && pv != "" {
|
||||||
|
req.UserID = pv
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if qv := r.URL.Query().Get("account_id"); qv != "" {
|
||||||
|
req.AccountID = qv
|
||||||
|
} else {
|
||||||
|
req.AccountID = claims.Audience
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.UserID != "" {
|
||||||
|
sess := webcontext.ContextSession(ctx)
|
||||||
|
var expires time.Duration
|
||||||
|
if sess != nil && sess.Options != nil {
|
||||||
|
expires = time.Second * time.Duration(sess.Options.MaxAge)
|
||||||
|
} else {
|
||||||
|
expires = time.Hour
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform the account switch.
|
||||||
|
tkn, err := user_auth.VirtualLogin(ctx, h.MasterDB, h.Authenticator, claims, *req, expires, ctxValues.Now)
|
||||||
|
if err != nil {
|
||||||
|
if verr, ok := weberror.NewValidationError(ctx, err); ok {
|
||||||
|
data["validationErrors"] = verr.(*weberror.Error)
|
||||||
|
return false, nil
|
||||||
|
} else {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the access token in the session.
|
||||||
|
sess = webcontext.SessionUpdateAccessToken(sess, tkn.AccessToken)
|
||||||
|
|
||||||
|
// Read the account for a flash message.
|
||||||
|
usr, err := user.ReadByID(ctx, claims, h.MasterDB, tkn.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
webcontext.SessionFlashSuccess(ctx,
|
||||||
|
"User Switched",
|
||||||
|
fmt.Sprintf("You are now virtually logged into user %s.",
|
||||||
|
usr.Response(ctx).Name))
|
||||||
|
|
||||||
|
// Write the session to the client.
|
||||||
|
err = webcontext.ContextSession(ctx).Save(r, w)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Redirect the user to the dashboard with the new credentials.
|
||||||
|
http.Redirect(w, r, "/", http.StatusFound)
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
end, err := f()
|
||||||
|
if err != nil {
|
||||||
|
return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
|
||||||
|
} else if end {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
usrAccFilter := "account_id = ?"
|
||||||
|
usrAccs, err := user_account.Find(ctx, claims, h.MasterDB, user_account.UserAccountFindRequest{
|
||||||
|
Where: &usrAccFilter,
|
||||||
|
Args: []interface{}{claims.Audience},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var userIDs []interface{}
|
||||||
|
var userPhs []string
|
||||||
|
for _, usrAcc := range usrAccs {
|
||||||
|
if usrAcc.UserID == claims.Subject {
|
||||||
|
// Skip the current authenticated user.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
userIDs = append(userIDs, usrAcc.UserID)
|
||||||
|
userPhs = append(userPhs, "?")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(userIDs) == 0 {
|
||||||
|
userIDs = append(userIDs, "")
|
||||||
|
userPhs = append(userPhs, "?")
|
||||||
|
}
|
||||||
|
|
||||||
|
usrFilter := fmt.Sprintf("id IN (%s)", strings.Join(userPhs, ", "))
|
||||||
|
users, err := user.Find(ctx, claims, h.MasterDB, user.UserFindRequest{
|
||||||
|
Where: &usrFilter,
|
||||||
|
Args: userIDs,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
data["users"] = users.Response(ctx)
|
||||||
|
|
||||||
|
if req.AccountID == "" {
|
||||||
|
req.AccountID = claims.Audience
|
||||||
|
}
|
||||||
|
|
||||||
|
data["form"] = req
|
||||||
|
|
||||||
|
if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(user_auth.VirtualLoginRequest{})); ok {
|
||||||
|
data["validationDefaults"] = verr.(*weberror.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "user-virtual-login.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// VirtualLogout handles switching the scope back to the user who initiated the virtual login.
|
||||||
|
func (h *User) VirtualLogout(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||||
|
|
||||||
|
ctxValues, err := webcontext.ContextValues(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, err := auth.ClaimsFromContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
sess := webcontext.ContextSession(ctx)
|
||||||
|
|
||||||
|
var expires time.Duration
|
||||||
|
if sess != nil && sess.Options != nil {
|
||||||
|
expires = time.Second * time.Duration(sess.Options.MaxAge)
|
||||||
|
} else {
|
||||||
|
expires = time.Hour
|
||||||
|
}
|
||||||
|
|
||||||
|
tkn, err := user_auth.VirtualLogout(ctx, h.MasterDB, h.Authenticator, claims, expires, ctxValues.Now)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the access token in the session.
|
||||||
|
sess = webcontext.SessionUpdateAccessToken(sess, tkn.AccessToken)
|
||||||
|
|
||||||
|
// Display a success message to verify the user has switched contexts.
|
||||||
|
if claims.Subject != tkn.UserID && claims.Audience != tkn.AccountID {
|
||||||
|
usr, err := user.ReadByID(ctx, claims, h.MasterDB, tkn.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
acc, err := account.ReadByID(ctx, claims, h.MasterDB, tkn.AccountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
webcontext.SessionFlashSuccess(ctx,
|
||||||
|
"Context Switched",
|
||||||
|
fmt.Sprintf("You are now virtually logged back into account %s user %s.",
|
||||||
|
acc.Response(ctx).Name, usr.Response(ctx).Name))
|
||||||
|
} else if claims.Audience != tkn.AccountID {
|
||||||
|
acc, err := account.ReadByID(ctx, claims, h.MasterDB, tkn.AccountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
webcontext.SessionFlashSuccess(ctx,
|
||||||
|
"Context Switched",
|
||||||
|
fmt.Sprintf("You are now virtually logged back into account %s.",
|
||||||
|
acc.Response(ctx).Name))
|
||||||
|
} else {
|
||||||
|
usr, err := user.ReadByID(ctx, claims, h.MasterDB, tkn.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
webcontext.SessionFlashSuccess(ctx,
|
||||||
|
"Context Switched",
|
||||||
|
fmt.Sprintf("You are now virtually logged back into user %s.",
|
||||||
|
usr.Response(ctx).Name))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write the session to the client.
|
||||||
|
err = webcontext.ContextSession(ctx).Save(r, w)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Redirect the user to the dashboard with the new credentials.
|
||||||
|
http.Redirect(w, r, "/", http.StatusFound)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VirtualLogin handles switching the scope of the context to another user.
|
||||||
|
func (h *User) SwitchAccount(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||||
|
|
||||||
|
ctxValues, err := webcontext.ContextValues(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, err := auth.ClaimsFromContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
req := new(user_auth.SwitchAccountRequest)
|
||||||
|
data := make(map[string]interface{})
|
||||||
|
f := func() (bool, error) {
|
||||||
|
|
||||||
|
if r.Method == http.MethodPost {
|
||||||
|
err := r.ParseForm()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
decoder := schema.NewDecoder()
|
||||||
|
decoder.IgnoreUnknownKeys(true)
|
||||||
|
|
||||||
|
if err := decoder.Decode(req, r.PostForm); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if pv, ok := params["account_id"]; ok && pv != "" {
|
||||||
|
req.AccountID = pv
|
||||||
|
} else if qv := r.URL.Query().Get("account_id"); qv != "" {
|
||||||
|
req.AccountID = qv
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.AccountID != "" {
|
||||||
|
sess := webcontext.ContextSession(ctx)
|
||||||
|
var expires time.Duration
|
||||||
|
if sess != nil && sess.Options != nil {
|
||||||
|
expires = time.Second * time.Duration(sess.Options.MaxAge)
|
||||||
|
} else {
|
||||||
|
expires = time.Hour
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform the account switch.
|
||||||
|
tkn, err := user_auth.SwitchAccount(ctx, h.MasterDB, h.Authenticator, claims, *req, expires, ctxValues.Now)
|
||||||
|
if err != nil {
|
||||||
|
if verr, ok := weberror.NewValidationError(ctx, err); ok {
|
||||||
|
data["validationErrors"] = verr.(*weberror.Error)
|
||||||
|
return false, nil
|
||||||
|
} else {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the access token in the session.
|
||||||
|
sess = webcontext.SessionUpdateAccessToken(sess, tkn.AccessToken)
|
||||||
|
|
||||||
|
// Read the account for a flash message.
|
||||||
|
acc, err := account.ReadByID(ctx, claims, h.MasterDB, tkn.AccountID)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
webcontext.SessionFlashSuccess(ctx,
|
||||||
|
"Account Switched",
|
||||||
|
fmt.Sprintf("You are now logged into account %s.",
|
||||||
|
acc.Response(ctx).Name))
|
||||||
|
|
||||||
|
// Write the session to the client.
|
||||||
|
err = webcontext.ContextSession(ctx).Save(r, w)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Redirect the user to the dashboard with the new credentials.
|
||||||
|
http.Redirect(w, r, "/", http.StatusFound)
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
end, err := f()
|
||||||
|
if err != nil {
|
||||||
|
return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
|
||||||
|
} else if end {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
accounts, err := account.Find(ctx, claims, h.MasterDB, account.AccountFindRequest{
|
||||||
|
Order: []string{"name"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
data["accounts"] = accounts.Response(ctx)
|
||||||
|
|
||||||
|
if req.AccountID == "" {
|
||||||
|
req.AccountID = claims.Audience
|
||||||
|
}
|
||||||
|
|
||||||
|
data["form"] = req
|
||||||
|
|
||||||
|
if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(user_auth.SwitchAccountRequest{})); ok {
|
||||||
|
data["validationDefaults"] = verr.(*weberror.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "user-switch-account.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateContextClaims updates the claims in the context.
|
||||||
|
func updateContextClaims(ctx context.Context, authenticator *auth.Authenticator, claims auth.Claims) (context.Context, error) {
|
||||||
|
tkn, err := authenticator.GenerateToken(claims)
|
||||||
|
if err != nil {
|
||||||
|
return ctx, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sess := webcontext.ContextSession(ctx)
|
||||||
|
sess = webcontext.SessionUpdateAccessToken(sess, tkn)
|
||||||
|
|
||||||
|
ctx = context.WithValue(ctx, auth.Key, claims)
|
||||||
|
|
||||||
|
return ctx, nil
|
||||||
|
}
|
||||||
|
@ -698,7 +698,7 @@ func main() {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
usr, err := user.Read(ctx, auth.Claims{}, masterDb, claims.Subject, false)
|
usr, err := user.ReadByID(ctx, auth.Claims{}, masterDb, claims.Subject)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -733,7 +733,7 @@ func main() {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
acc, err := account.Read(ctx, auth.Claims{}, masterDb, claims.Audience, false)
|
acc, err := account.ReadByID(ctx, auth.Claims{}, masterDb, claims.Audience)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -746,6 +746,26 @@ func main() {
|
|||||||
|
|
||||||
return a
|
return a
|
||||||
},
|
},
|
||||||
|
"ContextCanSwitchAccount": func(ctx context.Context) bool {
|
||||||
|
claims, err := auth.ClaimsFromContext(ctx)
|
||||||
|
if err != nil || len(claims.AccountIDs) < 2 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
"ContextIsVirtualSession": func(ctx context.Context) bool {
|
||||||
|
claims, err := auth.ClaimsFromContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if claims.RootUserID != "" && claims.RootUserID != claims.Subject {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if claims.RootAccountID != "" && claims.RootAccountID != claims.Audience {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
imgUrlFormatter := staticUrlFormatter
|
imgUrlFormatter := staticUrlFormatter
|
||||||
@ -827,11 +847,8 @@ func main() {
|
|||||||
|
|
||||||
switch statusCode {
|
switch statusCode {
|
||||||
case http.StatusUnauthorized:
|
case http.StatusUnauthorized:
|
||||||
// Handle expired sessions that are returned from the auth middleware.
|
http.Redirect(w, r, "/user/login?redirect="+url.QueryEscape(r.RequestURI), http.StatusFound)
|
||||||
if strings.Contains(errors.Cause(er).Error(), "token is expired") {
|
return nil
|
||||||
http.Redirect(w, r, "/user/login", http.StatusFound)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return web.RenderError(ctx, w, r, er, renderer, handlers.TmplLayoutBase, handlers.TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
|
return web.RenderError(ctx, w, r, er, renderer, handlers.TmplLayoutBase, handlers.TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
|
||||||
|
BIN
cmd/web-app/static/assets/images/user-default.jpg
Normal file
BIN
cmd/web-app/static/assets/images/user-default.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 40 KiB |
60
cmd/web-app/templates/content/user-switch-account.gohtml
Normal file
60
cmd/web-app/templates/content/user-switch-account.gohtml
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
{{define "title"}}Switch Account{{end}}
|
||||||
|
{{define "style"}}
|
||||||
|
|
||||||
|
{{end}}
|
||||||
|
{{ define "partials/page-wrapper" }}
|
||||||
|
<div class="container" id="page-content">
|
||||||
|
|
||||||
|
<!-- Outer Row -->
|
||||||
|
<div class="row justify-content-center">
|
||||||
|
|
||||||
|
<div class="col-xl-10 col-lg-12 col-md-9">
|
||||||
|
|
||||||
|
<div class="card o-hidden border-0 shadow-lg my-5">
|
||||||
|
<div class="card-body p-0">
|
||||||
|
<!-- Nested Row within Card Body -->
|
||||||
|
<div class="row">
|
||||||
|
<div class="col-lg-6 d-none d-lg-block bg-login-image"></div>
|
||||||
|
<div class="col-lg-6">
|
||||||
|
<div class="p-5">
|
||||||
|
{{ template "app-flashes" . }}
|
||||||
|
|
||||||
|
<div class="text-center">
|
||||||
|
<h1 class="h4 text-gray-900 mb-4">Switch Account</h1>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{{ template "validation-error" . }}
|
||||||
|
|
||||||
|
<form class="user" method="post" novalidate>
|
||||||
|
<div class="form-group">
|
||||||
|
<select class="form-control form-control-select-box {{ ValidationFieldClass $.validationErrors "AccountID" }}" name="AccountID" placeholder="AccountID" required>
|
||||||
|
{{ range $i := $.accounts }}
|
||||||
|
<option value="{{ $i.ID }}" {{ if eq $.form.AccountID $i.ID }}selected="selected"{{ end }}>{{ $i.Name }}</option>
|
||||||
|
{{ end }}
|
||||||
|
</select>
|
||||||
|
{{template "invalid-feedback" dict "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors "fieldName" "AccountID" }}
|
||||||
|
</div>
|
||||||
|
<button class="btn btn-primary btn-user btn-block">
|
||||||
|
Login
|
||||||
|
</button>
|
||||||
|
<hr>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
{{end}}
|
||||||
|
{{define "js"}}
|
||||||
|
<script>
|
||||||
|
$(document).ready(function() {
|
||||||
|
$(document).find('body').addClass('bg-gradient-primary');
|
||||||
|
});
|
||||||
|
</script>
|
||||||
|
{{end}}
|
60
cmd/web-app/templates/content/user-virtual-login.gohtml
Normal file
60
cmd/web-app/templates/content/user-virtual-login.gohtml
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
{{define "title"}}Switch User{{end}}
|
||||||
|
{{define "style"}}
|
||||||
|
|
||||||
|
{{end}}
|
||||||
|
{{ define "partials/page-wrapper" }}
|
||||||
|
<div class="container" id="page-content">
|
||||||
|
|
||||||
|
<!-- Outer Row -->
|
||||||
|
<div class="row justify-content-center">
|
||||||
|
|
||||||
|
<div class="col-xl-10 col-lg-12 col-md-9">
|
||||||
|
|
||||||
|
<div class="card o-hidden border-0 shadow-lg my-5">
|
||||||
|
<div class="card-body p-0">
|
||||||
|
<!-- Nested Row within Card Body -->
|
||||||
|
<div class="row">
|
||||||
|
<div class="col-lg-6 d-none d-lg-block bg-login-image"></div>
|
||||||
|
<div class="col-lg-6">
|
||||||
|
<div class="p-5">
|
||||||
|
{{ template "app-flashes" . }}
|
||||||
|
|
||||||
|
<div class="text-center">
|
||||||
|
<h1 class="h4 text-gray-900 mb-4">Switch User</h1>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{{ template "validation-error" . }}
|
||||||
|
|
||||||
|
<form class="user" method="post" novalidate>
|
||||||
|
<div class="form-group">
|
||||||
|
<select class="form-control form-control-select-box {{ ValidationFieldClass $.validationErrors "User" }}" name="UserID" placeholder="UserID" required>
|
||||||
|
{{ range $i := $.users }}
|
||||||
|
<option value="{{ $i.ID }}" {{ if eq $.form.UserID $i.ID }}selected="selected"{{ end }}>{{ $i.Name }}</option>
|
||||||
|
{{ end }}
|
||||||
|
</select>
|
||||||
|
{{template "invalid-feedback" dict "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors "fieldName" "UserID" }}
|
||||||
|
</div>
|
||||||
|
<button class="btn btn-primary btn-user btn-block">
|
||||||
|
Login
|
||||||
|
</button>
|
||||||
|
<hr>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
{{end}}
|
||||||
|
{{define "js"}}
|
||||||
|
<script>
|
||||||
|
$(document).ready(function() {
|
||||||
|
$(document).find('body').addClass('bg-gradient-primary');
|
||||||
|
});
|
||||||
|
</script>
|
||||||
|
{{end}}
|
@ -177,7 +177,6 @@
|
|||||||
<i class="fas fa-cogs fa-sm fa-fw mr-2 text-gray-400"></i>
|
<i class="fas fa-cogs fa-sm fa-fw mr-2 text-gray-400"></i>
|
||||||
Account Settings
|
Account Settings
|
||||||
</a>
|
</a>
|
||||||
|
|
||||||
<a class="dropdown-item" href="/users">
|
<a class="dropdown-item" href="/users">
|
||||||
<i class="fas fa-cogs fa-sm fa-fw mr-2 text-gray-400"></i>
|
<i class="fas fa-cogs fa-sm fa-fw mr-2 text-gray-400"></i>
|
||||||
Manage Users
|
Manage Users
|
||||||
@ -197,8 +196,22 @@
|
|||||||
<i class="fas fa-cogs fa-sm fa-fw mr-2 text-gray-400"></i>
|
<i class="fas fa-cogs fa-sm fa-fw mr-2 text-gray-400"></i>
|
||||||
Support
|
Support
|
||||||
</a>
|
</a>
|
||||||
|
|
||||||
<div class="dropdown-divider"></div>
|
<div class="dropdown-divider"></div>
|
||||||
|
|
||||||
|
{{ if ContextCanSwitchAccount $._Ctx }}
|
||||||
|
<a class="dropdown-item" href="/user/switch-account?modal=1">
|
||||||
|
<i class="far fa-sign-in fa-sm fa-fw mr-2 text-gray-400"></i>
|
||||||
|
Switch Account
|
||||||
|
</a>
|
||||||
|
{{ end }}
|
||||||
|
|
||||||
|
{{ if ContextIsVirtualSession $._Ctx }}
|
||||||
|
<a class="dropdown-item" href="/user/virtual-logout">
|
||||||
|
<i class="far fa-sign-out fa-sm fa-fw mr-2 text-gray-400"></i>
|
||||||
|
Switch Back
|
||||||
|
</a>
|
||||||
|
{{ end }}
|
||||||
|
|
||||||
<a class="dropdown-item" href="/user/logout" data-toggle="modal" data-target="#logoutModal">
|
<a class="dropdown-item" href="/user/logout" data-toggle="modal" data-target="#logoutModal">
|
||||||
<i class="fas fa-sign-out-alt fa-sm fa-fw mr-2 text-gray-400"></i>
|
<i class="fas fa-sign-out-alt fa-sm fa-fw mr-2 text-gray-400"></i>
|
||||||
Logout
|
Logout
|
||||||
|
@ -150,7 +150,7 @@ func selectQuery() *sqlbuilder.SelectBuilder {
|
|||||||
// Find gets all the accounts from the database based on the request params.
|
// Find gets all the accounts from the database based on the request params.
|
||||||
// TODO: Need to figure out why can't parse the args when appending the where
|
// TODO: Need to figure out why can't parse the args when appending the where
|
||||||
// to the query.
|
// to the query.
|
||||||
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountFindRequest) ([]*Account, error) {
|
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountFindRequest) (Accounts, error) {
|
||||||
query := selectQuery()
|
query := selectQuery()
|
||||||
|
|
||||||
if req.Where != nil {
|
if req.Where != nil {
|
||||||
@ -170,7 +170,7 @@ func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountF
|
|||||||
}
|
}
|
||||||
|
|
||||||
// find internal method for getting all the accounts from the database using a select query.
|
// find internal method for getting all the accounts from the database using a select query.
|
||||||
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*Account, error) {
|
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (Accounts, error) {
|
||||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Find")
|
span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Find")
|
||||||
defer span.Finish()
|
defer span.Finish()
|
||||||
|
|
||||||
|
@ -99,6 +99,21 @@ func (m *AccountResponse) MarshalBinary() ([]byte, error) {
|
|||||||
return json.Marshal(m)
|
return json.Marshal(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Accounts a list of Accounts.
|
||||||
|
type Accounts []*Account
|
||||||
|
|
||||||
|
// Response transforms a list of Accounts to a list of AccountResponses.
|
||||||
|
func (m *Accounts) Response(ctx context.Context) []*AccountResponse {
|
||||||
|
var l []*AccountResponse
|
||||||
|
if m != nil && len(*m) > 0 {
|
||||||
|
for _, n := range *m {
|
||||||
|
l = append(l, n.Response(ctx))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
// AccountCreateRequest contains information needed to create a new Account.
|
// AccountCreateRequest contains information needed to create a new Account.
|
||||||
type AccountCreateRequest struct {
|
type AccountCreateRequest struct {
|
||||||
Name string `json:"name" validate:"required,unique" example:"Company Name"`
|
Name string `json:"name" validate:"required,unique" example:"Company Name"`
|
||||||
|
@ -23,9 +23,11 @@ const Key ctxKey = 1
|
|||||||
|
|
||||||
// Claims represents the authorization claims transmitted via a JWT.
|
// Claims represents the authorization claims transmitted via a JWT.
|
||||||
type Claims struct {
|
type Claims struct {
|
||||||
AccountIds []string `json:"accounts"`
|
RootUserID string `json:"root_user_id"`
|
||||||
Roles []string `json:"roles"`
|
RootAccountID string `json:"root_account_id"`
|
||||||
Preferences ClaimPreferences `json:"prefs"`
|
AccountIDs []string `json:"accounts"`
|
||||||
|
Roles []string `json:"roles"`
|
||||||
|
Preferences ClaimPreferences `json:"prefs"`
|
||||||
jwt.StandardClaims
|
jwt.StandardClaims
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -41,14 +43,16 @@ type ClaimPreferences struct {
|
|||||||
// NewClaims constructs a Claims value for the identified user. The Claims
|
// NewClaims constructs a Claims value for the identified user. The Claims
|
||||||
// expire within a specified duration of the provided time. Additional fields
|
// expire within a specified duration of the provided time. Additional fields
|
||||||
// of the Claims can be set after calling NewClaims is desired.
|
// of the Claims can be set after calling NewClaims is desired.
|
||||||
func NewClaims(userId, accountId string, accountIds []string, roles []string, prefs ClaimPreferences, now time.Time, expires time.Duration) Claims {
|
func NewClaims(userID, accountID string, accountIDs []string, roles []string, prefs ClaimPreferences, now time.Time, expires time.Duration) Claims {
|
||||||
c := Claims{
|
c := Claims{
|
||||||
AccountIds: accountIds,
|
AccountIDs: accountIDs,
|
||||||
Roles: roles,
|
RootAccountID: accountID,
|
||||||
Preferences: prefs,
|
RootUserID: userID,
|
||||||
|
Roles: roles,
|
||||||
|
Preferences: prefs,
|
||||||
StandardClaims: jwt.StandardClaims{
|
StandardClaims: jwt.StandardClaims{
|
||||||
Subject: userId,
|
Subject: userID,
|
||||||
Audience: accountId,
|
Audience: accountID,
|
||||||
IssuedAt: now.Unix(),
|
IssuedAt: now.Unix(),
|
||||||
ExpiresAt: now.Add(expires).Unix(),
|
ExpiresAt: now.Add(expires).Unix(),
|
||||||
},
|
},
|
||||||
|
@ -14,39 +14,23 @@ const KeySession ctxKeySession = 1
|
|||||||
// Session keys used to store values.
|
// Session keys used to store values.
|
||||||
const (
|
const (
|
||||||
SessionKeyAccessToken = iota
|
SessionKeyAccessToken = iota
|
||||||
//SessionKeyPreferenceDatetimeFormat
|
|
||||||
//SessionKeyPreferenceDateFormat
|
|
||||||
//SessionKeyPreferenceTimeFormat
|
|
||||||
//SessionKeyTimezone
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
|
||||||
//gob.Register(&Session{})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Session represents a user with authentication.
|
|
||||||
type Session struct {
|
|
||||||
*sessions.Session
|
|
||||||
}
|
|
||||||
|
|
||||||
// ContextWithSession appends a universal translator to a context.
|
// ContextWithSession appends a universal translator to a context.
|
||||||
func ContextWithSession(ctx context.Context, session *sessions.Session) context.Context {
|
func ContextWithSession(ctx context.Context, session *sessions.Session) context.Context {
|
||||||
return context.WithValue(ctx, KeySession, session)
|
return context.WithValue(ctx, KeySession, session)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ContextSession returns the session from a context.
|
// ContextSession returns the session from a context.
|
||||||
func ContextSession(ctx context.Context) *Session {
|
func ContextSession(ctx context.Context) *sessions.Session {
|
||||||
if s, ok := ctx.Value(KeySession).(*Session); ok {
|
if s, ok := ctx.Value(KeySession).(*sessions.Session); ok {
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ContextAccessToken(ctx context.Context) (string, bool) {
|
func ContextAccessToken(ctx context.Context) (string, bool) {
|
||||||
return ContextSession(ctx).AccessToken()
|
sess := ContextSession(ctx)
|
||||||
}
|
|
||||||
|
|
||||||
func (sess *Session) AccessToken() (string, bool) {
|
|
||||||
if sess == nil {
|
if sess == nil {
|
||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
@ -56,60 +40,19 @@ func (sess *Session) AccessToken() (string, bool) {
|
|||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
func SessionInit(session *sessions.Session, accessToken string) *sessions.Session {
|
||||||
func(sess *Session) PreferenceDatetimeFormat() (string, bool) {
|
|
||||||
if sess == nil {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
if sv, ok := sess.Values[SessionKeyPreferenceDatetimeFormat].(string); ok {
|
|
||||||
return sv, true
|
|
||||||
}
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
|
|
||||||
func(sess *Session) PreferenceDateFormat() (string, bool) {
|
|
||||||
if sess == nil {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
if sv, ok := sess.Values[SessionKeyPreferenceDateFormat].(string); ok {
|
|
||||||
return sv, true
|
|
||||||
}
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
|
|
||||||
func(sess *Session) PreferenceTimeFormat() (string, bool) {
|
|
||||||
if sess == nil {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
if sv, ok := sess.Values[SessionKeyPreferenceTimeFormat].(string); ok {
|
|
||||||
return sv, true
|
|
||||||
}
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
|
|
||||||
func(sess *Session) Timezone() (*time.Location, bool) {
|
|
||||||
if sess != nil {
|
|
||||||
if sv, ok := sess.Values[SessionKeyTimezone].(*time.Location); ok {
|
|
||||||
return sv, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
func SessionInit(session *Session, accessToken string) *Session {
|
|
||||||
|
|
||||||
session.Values[SessionKeyAccessToken] = accessToken
|
session.Values[SessionKeyAccessToken] = accessToken
|
||||||
//session.Values[SessionKeyPreferenceDatetimeFormat] = datetimeFormat
|
|
||||||
//session.Values[SessionKeyPreferenceDateFormat] = dateFormat
|
|
||||||
//session.Values[SessionKeyPreferenceTimeFormat] = timeFormat
|
|
||||||
//session.Values[SessionKeyTimezone] = timezone
|
|
||||||
|
|
||||||
return session
|
return session
|
||||||
}
|
}
|
||||||
|
|
||||||
func SessionDestroy(session *Session) *Session {
|
func SessionUpdateAccessToken(session *sessions.Session, accessToken string) *sessions.Session {
|
||||||
|
session.Values[SessionKeyAccessToken] = accessToken
|
||||||
|
return session
|
||||||
|
}
|
||||||
|
|
||||||
|
func SessionDestroy(session *sessions.Session) *sessions.Session {
|
||||||
|
|
||||||
delete(session.Values, SessionKeyAccessToken)
|
delete(session.Values, SessionKeyAccessToken)
|
||||||
|
|
||||||
|
@ -56,6 +56,21 @@ func (m *Project) Response(ctx context.Context) *ProjectResponse {
|
|||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Projects a list of Projects.
|
||||||
|
type Projects []*Project
|
||||||
|
|
||||||
|
// Response transforms a list of Projects to a list of ProjectResponses.
|
||||||
|
func (m *Projects) Response(ctx context.Context) []*ProjectResponse {
|
||||||
|
var l []*ProjectResponse
|
||||||
|
if m != nil && len(*m) > 0 {
|
||||||
|
for _, n := range *m {
|
||||||
|
l = append(l, n.Response(ctx))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
// ProjectCreateRequest contains information needed to create a new Project.
|
// ProjectCreateRequest contains information needed to create a new Project.
|
||||||
type ProjectCreateRequest struct {
|
type ProjectCreateRequest struct {
|
||||||
AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
|
AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
|
||||||
|
@ -124,13 +124,13 @@ func findRequestQuery(req ProjectFindRequest) (*sqlbuilder.SelectBuilder, []inte
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Find gets all the projects from the database based on the request params.
|
// Find gets all the projects from the database based on the request params.
|
||||||
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectFindRequest) ([]*Project, error) {
|
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectFindRequest) (Projects, error) {
|
||||||
query, args := findRequestQuery(req)
|
query, args := findRequestQuery(req)
|
||||||
return find(ctx, claims, dbConn, query, args, req.IncludeArchived)
|
return find(ctx, claims, dbConn, query, args, req.IncludeArchived)
|
||||||
}
|
}
|
||||||
|
|
||||||
// find internal method for getting all the projects from the database using a select query.
|
// find internal method for getting all the projects from the database using a select query.
|
||||||
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*Project, error) {
|
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (Projects, error) {
|
||||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Find")
|
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Find")
|
||||||
defer span.Finish()
|
defer span.Finish()
|
||||||
|
|
||||||
|
@ -78,6 +78,21 @@ func (m *UserResponse) MarshalBinary() ([]byte, error) {
|
|||||||
return json.Marshal(m)
|
return json.Marshal(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Users a list of Users.
|
||||||
|
type Users []*User
|
||||||
|
|
||||||
|
// Response transforms a list of Users to a list of UserResponses.
|
||||||
|
func (m *Users) Response(ctx context.Context) []*UserResponse {
|
||||||
|
var l []*UserResponse
|
||||||
|
if m != nil && len(*m) > 0 {
|
||||||
|
for _, n := range *m {
|
||||||
|
l = append(l, n.Response(ctx))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
// UserCreateRequest contains information needed to create a new User.
|
// UserCreateRequest contains information needed to create a new User.
|
||||||
type UserCreateRequest struct {
|
type UserCreateRequest struct {
|
||||||
FirstName string `json:"first_name" validate:"required" example:"Gabi"`
|
FirstName string `json:"first_name" validate:"required" example:"Gabi"`
|
||||||
|
@ -202,13 +202,13 @@ func findRequestQuery(req UserFindRequest) (*sqlbuilder.SelectBuilder, []interfa
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Find gets all the users from the database based on the request params.
|
// Find gets all the users from the database based on the request params.
|
||||||
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindRequest) ([]*User, error) {
|
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindRequest) (Users, error) {
|
||||||
query, args := findRequestQuery(req)
|
query, args := findRequestQuery(req)
|
||||||
return find(ctx, claims, dbConn, query, args, req.IncludeArchived)
|
return find(ctx, claims, dbConn, query, args, req.IncludeArchived)
|
||||||
}
|
}
|
||||||
|
|
||||||
// find internal method for getting all the users from the database using a select query.
|
// find internal method for getting all the users from the database using a select query.
|
||||||
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*User, error) {
|
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (Users, error) {
|
||||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Find")
|
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Find")
|
||||||
defer span.Finish()
|
defer span.Finish()
|
||||||
|
|
||||||
|
@ -66,6 +66,34 @@ func (m *UserAccount) Response(ctx context.Context) *UserAccountResponse {
|
|||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HasRole checks if the entry has a role.
|
||||||
|
func (m *UserAccount) HasRole(role UserAccountRole) bool {
|
||||||
|
if m == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, r := range m.Roles {
|
||||||
|
if r == role {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserAccounts a list of UserAccounts.
|
||||||
|
type UserAccounts []*UserAccount
|
||||||
|
|
||||||
|
// Response transforms a list of UserAccounts to a list of UserAccountResponses.
|
||||||
|
func (m *UserAccounts) Response(ctx context.Context) []*UserAccountResponse {
|
||||||
|
var l []*UserAccountResponse
|
||||||
|
if m != nil && len(*m) > 0 {
|
||||||
|
for _, n := range *m {
|
||||||
|
l = append(l, n.Response(ctx))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
// UserAccountCreateRequest defines the information is needed to associate a user to an
|
// UserAccountCreateRequest defines the information is needed to associate a user to an
|
||||||
// account. Users are global to the application and each users access can be managed
|
// account. Users are global to the application and each users access can be managed
|
||||||
// on an account level. If a current entry exists in the database but is archived,
|
// on an account level. If a current entry exists in the database but is archived,
|
||||||
|
@ -131,13 +131,13 @@ func findRequestQuery(req UserAccountFindRequest) (*sqlbuilder.SelectBuilder, []
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Find gets all the user accounts from the database based on the request params.
|
// Find gets all the user accounts from the database based on the request params.
|
||||||
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountFindRequest) ([]*UserAccount, error) {
|
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountFindRequest) (UserAccounts, error) {
|
||||||
query, args := findRequestQuery(req)
|
query, args := findRequestQuery(req)
|
||||||
return find(ctx, claims, dbConn, query, args, req.IncludeArchived)
|
return find(ctx, claims, dbConn, query, args, req.IncludeArchived)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find gets all the user accounts from the database based on the select query
|
// Find gets all the user accounts from the database based on the select query
|
||||||
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*UserAccount, error) {
|
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (UserAccounts, error) {
|
||||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Find")
|
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Find")
|
||||||
defer span.Finish()
|
defer span.Finish()
|
||||||
|
|
||||||
@ -180,7 +180,7 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Retrieve gets the specified user from the database.
|
// Retrieve gets the specified user from the database.
|
||||||
func FindByUserID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string, includedArchived bool) ([]*UserAccount, error) {
|
func FindByUserID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string, includedArchived bool) (UserAccounts, error) {
|
||||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.FindByUserID")
|
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.FindByUserID")
|
||||||
defer span.Finish()
|
defer span.Finish()
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@ package user_auth
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"geeks-accelerator/oss/saas-starter-kit/internal/user_account"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -22,6 +23,9 @@ var (
|
|||||||
// ErrAuthenticationFailure occurs when a user attempts to authenticate but
|
// ErrAuthenticationFailure occurs when a user attempts to authenticate but
|
||||||
// anything goes wrong.
|
// anything goes wrong.
|
||||||
ErrAuthenticationFailure = errors.New("Authentication failed")
|
ErrAuthenticationFailure = errors.New("Authentication failed")
|
||||||
|
|
||||||
|
// ErrForbidden occurs when a user tries to do something that is forbidden to them according to our access control policies.
|
||||||
|
ErrForbidden = errors.New("Attempted action is not allowed")
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -65,22 +69,11 @@ func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, e
|
|||||||
return generateToken(ctx, dbConn, tknGen, auth.Claims{}, u.ID, "", expires, now, scopes...)
|
return generateToken(ctx, dbConn, tknGen, auth.Claims{}, u.ID, "", expires, now, scopes...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Authenticate finds a user by their email and verifies their password. On success
|
// SwitchAccount allows users to switch between multiple accounts, this changes the claim audience.
|
||||||
// it returns a Token that can be used to authenticate access to the application in
|
func SwitchAccount(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, req SwitchAccountRequest, expires time.Duration, now time.Time, scopes ...string) (Token, error) {
|
||||||
// the future.
|
|
||||||
func SwitchAccount(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, accountID string, expires time.Duration, now time.Time, scopes ...string) (Token, error) {
|
|
||||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.SwitchAccount")
|
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.SwitchAccount")
|
||||||
defer span.Finish()
|
defer span.Finish()
|
||||||
|
|
||||||
// Defines struct to apply validation for the supplied claims and account ID.
|
|
||||||
req := struct {
|
|
||||||
UserID string `json:"user_id" validate:"required,uuid"`
|
|
||||||
AccountID string `json:"account_id" validate:"required,uuid"`
|
|
||||||
}{
|
|
||||||
UserID: claims.Subject,
|
|
||||||
AccountID: accountID,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate the request.
|
// Validate the request.
|
||||||
v := webcontext.Validator()
|
v := webcontext.Validator()
|
||||||
err := v.Struct(req)
|
err := v.Struct(req)
|
||||||
@ -88,12 +81,74 @@ func SwitchAccount(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator,
|
|||||||
return Token{}, err
|
return Token{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
claims.RootAccountID = req.AccountID
|
||||||
|
|
||||||
|
if claims.RootUserID == "" {
|
||||||
|
claims.RootUserID = claims.Subject
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a token for the user ID in supplied in claims as the Subject. Pass
|
||||||
|
// in the supplied claims as well to enforce ACLs when finding the current
|
||||||
|
// list of accounts for the user.
|
||||||
|
return generateToken(ctx, dbConn, tknGen, claims, claims.Subject, req.AccountID, expires, now, scopes...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// VirtualLogin allows users to mock being logged in as other users.
|
||||||
|
func VirtualLogin(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, req VirtualLoginRequest, expires time.Duration, now time.Time, scopes ...string) (Token, error) {
|
||||||
|
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.VirtualLogin")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
// Validate the request.
|
||||||
|
v := webcontext.Validator()
|
||||||
|
err := v.Struct(req)
|
||||||
|
if err != nil {
|
||||||
|
return Token{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find all the accounts that the current user has access to.
|
||||||
|
usrAccs, err := user_account.FindByUserID(ctx, claims, dbConn, claims.Subject, false)
|
||||||
|
if err != nil {
|
||||||
|
return Token{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// The user must have the role of admin to login any other user.
|
||||||
|
var hasAccountAdminRole bool
|
||||||
|
for _, usrAcc := range usrAccs {
|
||||||
|
if usrAcc.HasRole(user_account.UserAccountRole_Admin) {
|
||||||
|
if usrAcc.AccountID == req.AccountID {
|
||||||
|
hasAccountAdminRole = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasAccountAdminRole {
|
||||||
|
return Token{}, errors.WithMessagef(ErrForbidden, "User %s does not have correct access to account %s ", claims.Subject, req.AccountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if claims.RootAccountID == "" {
|
||||||
|
claims.RootAccountID = claims.Audience
|
||||||
|
}
|
||||||
|
if claims.RootUserID == "" {
|
||||||
|
claims.RootUserID = claims.Subject
|
||||||
|
}
|
||||||
|
|
||||||
// Generate a token for the user ID in supplied in claims as the Subject. Pass
|
// Generate a token for the user ID in supplied in claims as the Subject. Pass
|
||||||
// in the supplied claims as well to enforce ACLs when finding the current
|
// in the supplied claims as well to enforce ACLs when finding the current
|
||||||
// list of accounts for the user.
|
// list of accounts for the user.
|
||||||
return generateToken(ctx, dbConn, tknGen, claims, req.UserID, req.AccountID, expires, now, scopes...)
|
return generateToken(ctx, dbConn, tknGen, claims, req.UserID, req.AccountID, expires, now, scopes...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// VirtualLogout allows switch back to their root user/account.
|
||||||
|
func VirtualLogout(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, expires time.Duration, now time.Time, scopes ...string) (Token, error) {
|
||||||
|
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.VirtualLogout")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
// Generate a token for the user ID in supplied in claims as the Subject. Pass
|
||||||
|
// in the supplied claims as well to enforce ACLs when finding the current
|
||||||
|
// list of accounts for the user.
|
||||||
|
return generateToken(ctx, dbConn, tknGen, claims, claims.RootUserID, claims.RootAccountID, expires, now, scopes...)
|
||||||
|
}
|
||||||
|
|
||||||
// generateToken generates claims for the supplied user ID and account ID and then
|
// generateToken generates claims for the supplied user ID and account ID and then
|
||||||
// returns the token for the generated claims used for authentication.
|
// returns the token for the generated claims used for authentication.
|
||||||
func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, userID, accountID string, expires time.Duration, now time.Time, scopes ...string) (Token, error) {
|
func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, userID, accountID string, expires time.Duration, now time.Time, scopes ...string) (Token, error) {
|
||||||
@ -250,7 +305,7 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator,
|
|||||||
if scopeValid {
|
if scopeValid {
|
||||||
roles = append(roles, s)
|
roles = append(roles, s)
|
||||||
} else {
|
} else {
|
||||||
err := errors.Errorf("invalid scope '%s'", s)
|
err := errors.Wrapf(ErrForbidden, "invalid scope '%s'", s)
|
||||||
return Token{}, err
|
return Token{}, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -259,7 +314,7 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(roles) == 0 {
|
if len(roles) == 0 {
|
||||||
err := errors.New("no roles defined for user")
|
err := errors.Wrapf(ErrForbidden, "no roles defined for user")
|
||||||
return Token{}, err
|
return Token{}, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -314,14 +369,24 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator,
|
|||||||
claimPref = auth.NewClaimPreferences(tz, preferenceDatetimeFormat, preferenceDateFormat, preferenceTimeFormat)
|
claimPref = auth.NewClaimPreferences(tz, preferenceDatetimeFormat, preferenceDateFormat, preferenceTimeFormat)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ensure the current claims has the root values set.
|
||||||
|
if (claims.RootAccountID == "" && claims.Audience != "") || (claims.RootUserID == "" && claims.Subject != "") {
|
||||||
|
claims.RootAccountID = claims.Audience
|
||||||
|
claims.RootUserID = claims.Subject
|
||||||
|
}
|
||||||
|
|
||||||
// JWT claims requires both an audience and a subject. For this application:
|
// JWT claims requires both an audience and a subject. For this application:
|
||||||
// Subject: The ID of the user authenticated.
|
// Subject: The ID of the user authenticated.
|
||||||
// Audience: The ID of the account the user is accessing. A list of account IDs
|
// Audience: The ID of the account the user is accessing. A list of account IDs
|
||||||
// will also be included to support the user switching between them.
|
// will also be included to support the user switching between them.
|
||||||
claims = auth.NewClaims(userID, accountID, accountIds, roles, claimPref, now, expires)
|
newClaims := auth.NewClaims(userID, accountID, accountIds, roles, claimPref, now, expires)
|
||||||
|
|
||||||
|
// Copy the original root account/user ID.
|
||||||
|
newClaims.RootAccountID = claims.RootAccountID
|
||||||
|
newClaims.RootUserID = claims.RootUserID
|
||||||
|
|
||||||
// Generate a token for the user with the defined claims.
|
// Generate a token for the user with the defined claims.
|
||||||
tknStr, err := tknGen.GenerateToken(claims)
|
tknStr, err := tknGen.GenerateToken(newClaims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Token{}, errors.Wrap(err, "generating token")
|
return Token{}, errors.Wrap(err, "generating token")
|
||||||
}
|
}
|
||||||
@ -329,9 +394,9 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator,
|
|||||||
tkn := Token{
|
tkn := Token{
|
||||||
AccessToken: tknStr,
|
AccessToken: tknStr,
|
||||||
TokenType: "Bearer",
|
TokenType: "Bearer",
|
||||||
claims: claims,
|
claims: newClaims,
|
||||||
UserID: claims.Subject,
|
UserID: newClaims.Subject,
|
||||||
AccountID: claims.Audience,
|
AccountID: newClaims.Audience,
|
||||||
}
|
}
|
||||||
|
|
||||||
if expires.Seconds() > 0 {
|
if expires.Seconds() > 0 {
|
||||||
|
@ -2,6 +2,7 @@ package user_auth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@ -63,6 +64,9 @@ func TestAuthenticate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
t.Logf("\t%s\tCreate user account ok.", tests.Success)
|
t.Logf("\t%s\tCreate user account ok.", tests.Success)
|
||||||
|
|
||||||
|
// Add 30 minutes to now to simulate time passing.
|
||||||
|
now = now.Add(time.Minute * 30)
|
||||||
|
|
||||||
acc2, err := account.MockAccount(ctx, test.MasterDB, now)
|
acc2, err := account.MockAccount(ctx, test.MasterDB, now)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Log("\t\tGot :", err)
|
t.Log("\t\tGot :", err)
|
||||||
@ -81,7 +85,7 @@ func TestAuthenticate(t *testing.T) {
|
|||||||
}, now)
|
}, now)
|
||||||
|
|
||||||
// Add 30 minutes to now to simulate time passing.
|
// Add 30 minutes to now to simulate time passing.
|
||||||
now = now.Add(time.Minute * 30)
|
now = now.Add(time.Minute * 5)
|
||||||
|
|
||||||
// Try to authenticate valid user with invalid password.
|
// Try to authenticate valid user with invalid password.
|
||||||
_, err = Authenticate(ctx, test.MasterDB, tknGen, usrAcc.User.Email, "xy7", time.Hour, now)
|
_, err = Authenticate(ctx, test.MasterDB, tknGen, usrAcc.User.Email, "xy7", time.Hour, now)
|
||||||
@ -106,17 +110,20 @@ func TestAuthenticate(t *testing.T) {
|
|||||||
t.Log("\t\tGot :", err)
|
t.Log("\t\tGot :", err)
|
||||||
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
|
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
|
||||||
}
|
}
|
||||||
|
expectClaims := tkn1.claims
|
||||||
|
expectClaims.RootUserID = ""
|
||||||
|
expectClaims.RootAccountID = ""
|
||||||
|
expectClaims.Subject = usrAcc.UserID
|
||||||
|
expectClaims.Audience = usrAcc.AccountID
|
||||||
|
|
||||||
// Hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229
|
if diff := cmpClaims(claims1, expectClaims); diff != "" {
|
||||||
resClaims, _ := json.Marshal(claims1)
|
|
||||||
expectClaims, _ := json.Marshal(tkn1.claims)
|
|
||||||
if diff := cmp.Diff(string(resClaims), string(expectClaims)); diff != "" {
|
|
||||||
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
|
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
|
||||||
}
|
}
|
||||||
t.Logf("\t%s\tAuthenticate parse claims from token ok.", tests.Success)
|
t.Logf("\t%s\tAuthenticate parse claims from token ok.", tests.Success)
|
||||||
|
|
||||||
// Try switching to a second account using the first set of claims.
|
// Try switching to a second account using the first set of claims.
|
||||||
tkn2, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims1, acc2.ID, time.Hour, now)
|
tkn2, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims1,
|
||||||
|
SwitchAccountRequest{AccountID: acc2.ID}, time.Hour, now)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Log("\t\tGot :", err)
|
t.Log("\t\tGot :", err)
|
||||||
t.Fatalf("\t%s\tSwitchAccount user failed.", tests.Failed)
|
t.Fatalf("\t%s\tSwitchAccount user failed.", tests.Failed)
|
||||||
@ -129,11 +136,13 @@ func TestAuthenticate(t *testing.T) {
|
|||||||
t.Log("\t\tGot :", err)
|
t.Log("\t\tGot :", err)
|
||||||
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
|
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
|
||||||
}
|
}
|
||||||
|
expectClaims = tkn2.claims
|
||||||
|
expectClaims.RootUserID = usrAcc.UserID
|
||||||
|
expectClaims.RootAccountID = acc2.ID
|
||||||
|
expectClaims.Subject = usrAcc.UserID
|
||||||
|
expectClaims.Audience = acc2.ID
|
||||||
|
|
||||||
// Hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229
|
if diff := cmpClaims(claims2, expectClaims); diff != "" {
|
||||||
resClaims, _ = json.Marshal(claims2)
|
|
||||||
expectClaims, _ = json.Marshal(tkn2.claims)
|
|
||||||
if diff := cmp.Diff(string(resClaims), string(expectClaims)); diff != "" {
|
|
||||||
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
|
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
|
||||||
}
|
}
|
||||||
t.Logf("\t%s\tSwitchAccount parse claims from token ok.", tests.Success)
|
t.Logf("\t%s\tSwitchAccount parse claims from token ok.", tests.Success)
|
||||||
@ -256,3 +265,724 @@ func TestUserResetPassword(t *testing.T) {
|
|||||||
t.Logf("\t%s\tAuthenticate ok.", tests.Success)
|
t.Logf("\t%s\tAuthenticate ok.", tests.Success)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestSwitchAccount validates the behavior around allowing users to switch between their accounts.
|
||||||
|
func TestSwitchAccount(t *testing.T) {
|
||||||
|
defer tests.Recover(t)
|
||||||
|
|
||||||
|
// Auth tokens are valid for an our and is verified against current time.
|
||||||
|
// Issue the token one hour ago.
|
||||||
|
now := time.Now().Add(time.Hour * -1)
|
||||||
|
|
||||||
|
ctx := tests.Context()
|
||||||
|
|
||||||
|
type authTest struct {
|
||||||
|
name string
|
||||||
|
root *user_account.MockUserAccountResponse
|
||||||
|
switch1Req SwitchAccountRequest
|
||||||
|
switch1Roles []user_account.UserAccountRole
|
||||||
|
switch1Scopes []string
|
||||||
|
switch1Err error
|
||||||
|
switch2Req SwitchAccountRequest
|
||||||
|
switch2Roles []user_account.UserAccountRole
|
||||||
|
switch2Scopes []string
|
||||||
|
switch2Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
var authTests []authTest
|
||||||
|
|
||||||
|
// Test all the combinations there the user has access to all three accounts.
|
||||||
|
if true {
|
||||||
|
for _, roles := range [][]user_account.UserAccountRole{
|
||||||
|
[]user_account.UserAccountRole{user_account.UserAccountRole_Admin, user_account.UserAccountRole_Admin, user_account.UserAccountRole_Admin},
|
||||||
|
[]user_account.UserAccountRole{user_account.UserAccountRole_User, user_account.UserAccountRole_User, user_account.UserAccountRole_User},
|
||||||
|
[]user_account.UserAccountRole{user_account.UserAccountRole_Admin, user_account.UserAccountRole_User, user_account.UserAccountRole_Admin},
|
||||||
|
[]user_account.UserAccountRole{user_account.UserAccountRole_User, user_account.UserAccountRole_Admin, user_account.UserAccountRole_User},
|
||||||
|
} {
|
||||||
|
// Create a new user for testing.
|
||||||
|
usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, roles[0])
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the second account.
|
||||||
|
now = now.Add(time.Minute)
|
||||||
|
acc2, err := account.MockAccount(ctx, test.MasterDB, now)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Associate the second account with root user.
|
||||||
|
usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
|
||||||
|
UserID: usrAcc.UserID,
|
||||||
|
AccountID: acc2.ID,
|
||||||
|
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(roles[1])},
|
||||||
|
}, now)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tLinking second account to user failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the third account.
|
||||||
|
now = now.Add(time.Minute)
|
||||||
|
acc3, err := account.MockAccount(ctx, test.MasterDB, now)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tCreate third account failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Associate the third account with root user.
|
||||||
|
usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
|
||||||
|
UserID: usrAcc.UserID,
|
||||||
|
AccountID: acc3.ID,
|
||||||
|
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(roles[2])},
|
||||||
|
}, now)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tLinking third account to user failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
authTests = append(authTests, authTest{
|
||||||
|
name: fmt.Sprintf("Root account role %s -> role %s account 2 -> role %s account 3.",
|
||||||
|
roles[0], roles[1], roles[2]),
|
||||||
|
root: usrAcc,
|
||||||
|
switch1Req: SwitchAccountRequest{AccountID: acc2.ID},
|
||||||
|
switch1Roles: usrAcc2.Roles,
|
||||||
|
switch1Err: nil,
|
||||||
|
switch2Req: SwitchAccountRequest{AccountID: acc3.ID},
|
||||||
|
switch2Err: nil,
|
||||||
|
switch2Roles: usrAcc3.Roles,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Root account 1 -> invalid account 2
|
||||||
|
if true {
|
||||||
|
// Create a new user for testing.
|
||||||
|
usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_Admin)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the second account and don't associate it with the root user.
|
||||||
|
now = now.Add(time.Minute)
|
||||||
|
acc2, err := account.MockAccount(ctx, test.MasterDB, now)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
authTests = append(authTests, authTest{
|
||||||
|
name: "Root account 1 -> invalid account 2.",
|
||||||
|
root: usrAcc,
|
||||||
|
switch1Req: SwitchAccountRequest{AccountID: acc2.ID},
|
||||||
|
switch1Err: ErrAuthenticationFailure,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Root account 1 -> valid account 2 with scopes -> valid account 3 with invalid scope.
|
||||||
|
if true {
|
||||||
|
// Create a new user for testing.
|
||||||
|
usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_Admin)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the second account.
|
||||||
|
now = now.Add(time.Minute)
|
||||||
|
acc2, err := account.MockAccount(ctx, test.MasterDB, now)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Associate the second account with root user.
|
||||||
|
usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
|
||||||
|
UserID: usrAcc.UserID,
|
||||||
|
AccountID: acc2.ID,
|
||||||
|
Roles: []user_account.UserAccountRole{user_account.UserAccountRole_Admin},
|
||||||
|
}, now)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tLinking second account to user failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the third account.
|
||||||
|
now = now.Add(time.Minute)
|
||||||
|
acc3, err := account.MockAccount(ctx, test.MasterDB, now)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tCreate third account failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Associate the third account with root user.
|
||||||
|
usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
|
||||||
|
UserID: usrAcc.UserID,
|
||||||
|
AccountID: acc3.ID,
|
||||||
|
Roles: []user_account.UserAccountRole{user_account.UserAccountRole_User},
|
||||||
|
}, now)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tLinking third account to user failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
authTests = append(authTests, authTest{
|
||||||
|
name: "Root account 1 -> valid account 2 with scopes -> valid account 3 with invalid scope.",
|
||||||
|
root: usrAcc,
|
||||||
|
switch1Req: SwitchAccountRequest{AccountID: acc2.ID},
|
||||||
|
switch1Roles: usrAcc2.Roles,
|
||||||
|
switch1Scopes: []string{user_account.UserAccountRole_User.String()},
|
||||||
|
switch1Err: nil,
|
||||||
|
switch2Req: SwitchAccountRequest{AccountID: acc3.ID},
|
||||||
|
switch2Roles: usrAcc3.Roles,
|
||||||
|
switch2Scopes: []string{user_account.UserAccountRole_Admin.String()},
|
||||||
|
switch2Err: ErrForbidden,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add 30 minutes to now to simulate time passing.
|
||||||
|
now = now.Add(time.Minute * 5)
|
||||||
|
|
||||||
|
tknGen := &auth.MockTokenGenerator{}
|
||||||
|
|
||||||
|
t.Log("Given the need to switch accounts.")
|
||||||
|
{
|
||||||
|
for i, authTest := range authTests {
|
||||||
|
t.Logf("\tTest: %d\tWhen running test: %s", i, authTest.name)
|
||||||
|
{
|
||||||
|
// Verify that the user can be authenticated with the created user.
|
||||||
|
var claims1 auth.Claims
|
||||||
|
tkn1, err := Authenticate(ctx, test.MasterDB, tknGen, authTest.root.User.Email, authTest.root.User.Password, time.Hour, now)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tAuthenticate user failed.", tests.Failed)
|
||||||
|
} else {
|
||||||
|
// Ensure the token string was correctly generated.
|
||||||
|
claims1, err = tknGen.ParseClaims(tkn1.AccessToken)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
expectClaims := tkn1.claims
|
||||||
|
expectClaims.RootUserID = ""
|
||||||
|
expectClaims.RootAccountID = ""
|
||||||
|
expectClaims.Subject = authTest.root.UserID
|
||||||
|
expectClaims.Audience = authTest.root.AccountID
|
||||||
|
expectClaims.Roles = rolesStringSlice(authTest.root.Roles)
|
||||||
|
|
||||||
|
if diff := cmpClaims(claims1, expectClaims); diff != "" {
|
||||||
|
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Logf("\t%s\tAuthenticate root user with role %v ok.", tests.Success, authTest.root.Roles)
|
||||||
|
|
||||||
|
// Try to switch to account 2.
|
||||||
|
var claims2 auth.Claims
|
||||||
|
tkn2, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims1, authTest.switch1Req, time.Hour, now, authTest.switch1Scopes...)
|
||||||
|
if err != authTest.switch1Err {
|
||||||
|
if errors.Cause(err) != authTest.switch1Err {
|
||||||
|
t.Log("\t\tExpected :", authTest.switch1Err)
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tSwitchAccount account 1 with role %v failed.", tests.Failed, authTest.switch1Roles)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Ensure the token string was correctly generated.
|
||||||
|
claims2, err = tknGen.ParseClaims(tkn2.AccessToken)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
expectClaims := tkn2.claims
|
||||||
|
expectClaims.RootUserID = authTest.root.UserID
|
||||||
|
expectClaims.RootAccountID = authTest.switch1Req.AccountID
|
||||||
|
expectClaims.Subject = authTest.root.UserID
|
||||||
|
expectClaims.Audience = authTest.switch1Req.AccountID
|
||||||
|
|
||||||
|
if len(authTest.switch1Scopes) > 0 {
|
||||||
|
expectClaims.Roles = authTest.switch1Scopes
|
||||||
|
} else {
|
||||||
|
expectClaims.Roles = rolesStringSlice(authTest.switch1Roles)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmpClaims(claims2, expectClaims); diff != "" {
|
||||||
|
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Logf("\t%s\tSwitchAccount account 1 with role %v ok.", tests.Success, authTest.switch1Roles)
|
||||||
|
|
||||||
|
// If the user can't login, don't need to test any further.
|
||||||
|
if authTest.switch1Err != nil || authTest.switch2Req.AccountID == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to switch to account 3.
|
||||||
|
tkn3, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims2, authTest.switch2Req, time.Hour, now, authTest.switch2Scopes...)
|
||||||
|
if err != authTest.switch2Err {
|
||||||
|
if errors.Cause(err) != authTest.switch2Err {
|
||||||
|
t.Log("\t\tExpected :", authTest.switch2Err)
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tSwitchAccount account 2 with role %v failed.", tests.Failed, authTest.switch2Roles)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Ensure the token string was correctly generated.
|
||||||
|
claims3, err := tknGen.ParseClaims(tkn3.AccessToken)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
expectClaims := tkn3.claims
|
||||||
|
expectClaims.RootUserID = authTest.root.UserID
|
||||||
|
expectClaims.RootAccountID = authTest.switch2Req.AccountID
|
||||||
|
expectClaims.Subject = authTest.root.UserID
|
||||||
|
expectClaims.Audience = authTest.switch2Req.AccountID
|
||||||
|
|
||||||
|
if len(authTest.switch2Scopes) > 0 {
|
||||||
|
expectClaims.Roles = authTest.switch2Scopes
|
||||||
|
} else {
|
||||||
|
expectClaims.Roles = rolesStringSlice(authTest.switch2Roles)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmpClaims(claims3, expectClaims); diff != "" {
|
||||||
|
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Logf("\t%s\tSwitchAccount account 2 with role %v ok.", tests.Success, authTest.switch2Roles)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestVirtualLogin validates the behavior around allowing users to virtual login users.
|
||||||
|
func TestVirtualLogin(t *testing.T) {
|
||||||
|
defer tests.Recover(t)
|
||||||
|
|
||||||
|
// Auth tokens are valid for an our and is verified against current time.
|
||||||
|
// Issue the token one hour ago.
|
||||||
|
now := time.Now().Add(time.Hour * -1)
|
||||||
|
|
||||||
|
ctx := tests.Context()
|
||||||
|
|
||||||
|
type authTest struct {
|
||||||
|
name string
|
||||||
|
root *user_account.MockUserAccountResponse
|
||||||
|
login1Req VirtualLoginRequest
|
||||||
|
login1Err error
|
||||||
|
login1Role user_account.UserAccountRole
|
||||||
|
login2Req VirtualLoginRequest
|
||||||
|
login2Err error
|
||||||
|
login2Role user_account.UserAccountRole
|
||||||
|
login2Logout bool
|
||||||
|
}
|
||||||
|
|
||||||
|
var authTests []authTest
|
||||||
|
|
||||||
|
// Root admin -> role admin -> role admin
|
||||||
|
if true {
|
||||||
|
// Create a new user for testing.
|
||||||
|
usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_Admin)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
usr2, err := user.MockUser(ctx, test.MasterDB, now)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Associate second user with basic role associated with the same account.
|
||||||
|
usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
|
||||||
|
UserID: usr2.ID,
|
||||||
|
AccountID: usrAcc.AccountID,
|
||||||
|
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)},
|
||||||
|
}, now)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tLinking second user to account failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
usr3, err := user.MockUser(ctx, test.MasterDB, now)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Associate second user with basic role associated with the same account.
|
||||||
|
usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
|
||||||
|
UserID: usr3.ID,
|
||||||
|
AccountID: usrAcc.AccountID,
|
||||||
|
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)},
|
||||||
|
}, now)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tLinking third user to account failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
authTests = append(authTests, authTest{
|
||||||
|
name: "Root admin -> role admin -> role admin",
|
||||||
|
root: usrAcc,
|
||||||
|
login1Req: VirtualLoginRequest{
|
||||||
|
UserID: usr2.ID,
|
||||||
|
AccountID: usrAcc.AccountID,
|
||||||
|
},
|
||||||
|
login1Role: usrAcc2.Roles[0],
|
||||||
|
login1Err: nil,
|
||||||
|
login2Req: VirtualLoginRequest{
|
||||||
|
UserID: usr3.ID,
|
||||||
|
AccountID: usrAcc.AccountID,
|
||||||
|
},
|
||||||
|
login2Err: nil,
|
||||||
|
login2Role: usrAcc3.Roles[0],
|
||||||
|
login2Logout: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Root admin -> role admin -> role user
|
||||||
|
if true {
|
||||||
|
// Create a new user for testing.
|
||||||
|
usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_Admin)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
usr2, err := user.MockUser(ctx, test.MasterDB, now)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Associate second user with basic role associated with the same account.
|
||||||
|
usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
|
||||||
|
UserID: usr2.ID,
|
||||||
|
AccountID: usrAcc.AccountID,
|
||||||
|
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)},
|
||||||
|
}, now)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tLinking second user to account failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
usr3, err := user.MockUser(ctx, test.MasterDB, now)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Associate second user with basic role associated with the same account.
|
||||||
|
usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
|
||||||
|
UserID: usr3.ID,
|
||||||
|
AccountID: usrAcc.AccountID,
|
||||||
|
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_User)},
|
||||||
|
}, now)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tLinking third user to account failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
authTests = append(authTests, authTest{
|
||||||
|
name: "Root admin -> role admin -> role user",
|
||||||
|
root: usrAcc,
|
||||||
|
login1Req: VirtualLoginRequest{
|
||||||
|
UserID: usr2.ID,
|
||||||
|
AccountID: usrAcc.AccountID,
|
||||||
|
},
|
||||||
|
login1Err: nil,
|
||||||
|
login1Role: usrAcc2.Roles[0],
|
||||||
|
login2Req: VirtualLoginRequest{
|
||||||
|
UserID: usr3.ID,
|
||||||
|
AccountID: usrAcc.AccountID,
|
||||||
|
},
|
||||||
|
login2Err: nil,
|
||||||
|
login2Role: usrAcc3.Roles[0],
|
||||||
|
login2Logout: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Root admin -> role user -> role admin
|
||||||
|
if true {
|
||||||
|
// Create a new user for testing.
|
||||||
|
usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_Admin)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
usr2, err := user.MockUser(ctx, test.MasterDB, now)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Associate second user with basic role associated with the same account.
|
||||||
|
usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
|
||||||
|
UserID: usr2.ID,
|
||||||
|
AccountID: usrAcc.AccountID,
|
||||||
|
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_User)},
|
||||||
|
}, now)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tLinking second user to account failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
usr3, err := user.MockUser(ctx, test.MasterDB, now)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Associate second user with basic role associated with the same account.
|
||||||
|
usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
|
||||||
|
UserID: usr3.ID,
|
||||||
|
AccountID: usrAcc.AccountID,
|
||||||
|
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)},
|
||||||
|
}, now)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tLinking third user to account failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
authTests = append(authTests, authTest{
|
||||||
|
name: "Root admin -> role user -> role admin",
|
||||||
|
root: usrAcc,
|
||||||
|
login1Req: VirtualLoginRequest{
|
||||||
|
UserID: usr2.ID,
|
||||||
|
AccountID: usrAcc.AccountID,
|
||||||
|
},
|
||||||
|
login1Err: nil,
|
||||||
|
login1Role: usrAcc2.Roles[0],
|
||||||
|
login2Req: VirtualLoginRequest{
|
||||||
|
UserID: usr3.ID,
|
||||||
|
AccountID: usrAcc.AccountID,
|
||||||
|
},
|
||||||
|
login2Err: ErrForbidden,
|
||||||
|
login2Role: usrAcc3.Roles[0],
|
||||||
|
login2Logout: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Root user -> role admin
|
||||||
|
if true {
|
||||||
|
// Create a new user for testing.
|
||||||
|
usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_User)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
usr2, err := user.MockUser(ctx, test.MasterDB, now)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Associate second user with basic role associated with the same account.
|
||||||
|
usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
|
||||||
|
UserID: usr2.ID,
|
||||||
|
AccountID: usrAcc.AccountID,
|
||||||
|
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)},
|
||||||
|
}, now)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tLinking second user to account failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
authTests = append(authTests, authTest{
|
||||||
|
name: "Root user -> role admin",
|
||||||
|
root: usrAcc,
|
||||||
|
login1Req: VirtualLoginRequest{
|
||||||
|
UserID: usr2.ID,
|
||||||
|
AccountID: usrAcc.AccountID,
|
||||||
|
},
|
||||||
|
login1Err: ErrForbidden,
|
||||||
|
login1Role: usrAcc2.Roles[0],
|
||||||
|
login2Logout: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Root user -> role user
|
||||||
|
if true {
|
||||||
|
// Create a new user for testing.
|
||||||
|
usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_User)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
usr2, err := user.MockUser(ctx, test.MasterDB, now)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Associate second user with basic role associated with the same account.
|
||||||
|
usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
|
||||||
|
UserID: usr2.ID,
|
||||||
|
AccountID: usrAcc.AccountID,
|
||||||
|
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_User)},
|
||||||
|
}, now)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tLinking second user to account failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
authTests = append(authTests, authTest{
|
||||||
|
name: "Root user -> role admin",
|
||||||
|
root: usrAcc,
|
||||||
|
login1Req: VirtualLoginRequest{
|
||||||
|
UserID: usr2.ID,
|
||||||
|
AccountID: usrAcc.AccountID,
|
||||||
|
},
|
||||||
|
login1Err: ErrForbidden,
|
||||||
|
login1Role: usrAcc2.Roles[0],
|
||||||
|
login2Logout: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add 30 minutes to now to simulate time passing.
|
||||||
|
now = now.Add(time.Minute * 5)
|
||||||
|
|
||||||
|
tknGen := &auth.MockTokenGenerator{}
|
||||||
|
|
||||||
|
t.Log("Given the need to virtual login.")
|
||||||
|
{
|
||||||
|
for i, authTest := range authTests {
|
||||||
|
t.Logf("\tTest: %d\tWhen running test: %s", i, authTest.name)
|
||||||
|
{
|
||||||
|
// Verify that the user can be authenticated with the created user.
|
||||||
|
var claims1 auth.Claims
|
||||||
|
tkn1, err := Authenticate(ctx, test.MasterDB, tknGen, authTest.root.User.Email, authTest.root.User.Password, time.Hour, now)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tAuthenticate user failed.", tests.Failed)
|
||||||
|
} else {
|
||||||
|
// Ensure the token string was correctly generated.
|
||||||
|
claims1, err = tknGen.ParseClaims(tkn1.AccessToken)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
expectClaims := tkn1.claims
|
||||||
|
expectClaims.RootUserID = ""
|
||||||
|
expectClaims.RootAccountID = ""
|
||||||
|
expectClaims.Subject = authTest.root.UserID
|
||||||
|
expectClaims.Audience = authTest.root.AccountID
|
||||||
|
|
||||||
|
// Hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229
|
||||||
|
if diff := cmpClaims(claims1, expectClaims); diff != "" {
|
||||||
|
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Logf("\t%s\tAuthenticate root user with role %s ok.", tests.Success, authTest.root.Roles[0])
|
||||||
|
|
||||||
|
// Try virtual login to user 2.
|
||||||
|
var claims2 auth.Claims
|
||||||
|
tkn2, err := VirtualLogin(ctx, test.MasterDB, tknGen, claims1, authTest.login1Req, time.Hour, now)
|
||||||
|
if err != authTest.login1Err {
|
||||||
|
if errors.Cause(err) != authTest.login1Err {
|
||||||
|
t.Log("\t\tExpected :", authTest.login1Err)
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tVirtualLogin user 1 with role %s failed.", tests.Failed, authTest.login1Role)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Ensure the token string was correctly generated.
|
||||||
|
claims2, err = tknGen.ParseClaims(tkn2.AccessToken)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
expectClaims := tkn2.claims
|
||||||
|
expectClaims.RootUserID = authTest.root.UserID
|
||||||
|
expectClaims.RootAccountID = authTest.root.AccountID
|
||||||
|
expectClaims.Subject = authTest.login1Req.UserID
|
||||||
|
expectClaims.Audience = authTest.login1Req.AccountID
|
||||||
|
|
||||||
|
// Hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229
|
||||||
|
if diff := cmpClaims(claims2, expectClaims); diff != "" {
|
||||||
|
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Logf("\t%s\tVirtualLogin user 1 with role %s ok.", tests.Success, authTest.login1Role)
|
||||||
|
|
||||||
|
// If the user can't login, don't need to test any further.
|
||||||
|
if authTest.login1Err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try virtual login to user 3.
|
||||||
|
tkn3, err := VirtualLogin(ctx, test.MasterDB, tknGen, claims2, authTest.login2Req, time.Hour, now)
|
||||||
|
if err != authTest.login2Err {
|
||||||
|
if errors.Cause(err) != authTest.login2Err {
|
||||||
|
t.Log("\t\tExpected :", authTest.login2Err)
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tVirtualLogin user 2 with role %s failed.", tests.Failed, authTest.login2Role)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Ensure the token string was correctly generated.
|
||||||
|
claims3, err := tknGen.ParseClaims(tkn3.AccessToken)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
expectClaims := tkn3.claims
|
||||||
|
expectClaims.RootUserID = authTest.root.UserID
|
||||||
|
expectClaims.RootAccountID = authTest.root.AccountID
|
||||||
|
expectClaims.Subject = authTest.login2Req.UserID
|
||||||
|
expectClaims.Audience = authTest.login2Req.AccountID
|
||||||
|
|
||||||
|
// Hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229
|
||||||
|
if diff := cmpClaims(claims3, expectClaims); diff != "" {
|
||||||
|
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Logf("\t%s\tVirtualLogin user 2 with role %s ok.", tests.Success, authTest.login2Role)
|
||||||
|
|
||||||
|
if authTest.login2Logout {
|
||||||
|
tknOut, err := VirtualLogout(ctx, test.MasterDB, tknGen, claims2, time.Hour, now)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tVirtualLogout user 2 failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure the token string was correctly generated.
|
||||||
|
claimsOut, err := tknGen.ParseClaims(tknOut.AccessToken)
|
||||||
|
if err != nil {
|
||||||
|
t.Log("\t\tGot :", err)
|
||||||
|
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
|
||||||
|
}
|
||||||
|
expectClaims := tknOut.claims
|
||||||
|
expectClaims.RootUserID = authTest.root.UserID
|
||||||
|
expectClaims.RootAccountID = authTest.root.AccountID
|
||||||
|
expectClaims.Subject = authTest.root.UserID
|
||||||
|
expectClaims.Audience = authTest.root.AccountID
|
||||||
|
|
||||||
|
if diff := cmpClaims(claimsOut, expectClaims); diff != "" {
|
||||||
|
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
|
||||||
|
}
|
||||||
|
t.Logf("\t%s\tVirtualLogout user 2 with role %s ok.", tests.Success, authTest.login2Role)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// rolesStringSlice converts a list of roles to a string slice.
|
||||||
|
func rolesStringSlice(roles []user_account.UserAccountRole) []string {
|
||||||
|
var l []string
|
||||||
|
for _, r := range roles {
|
||||||
|
l = append(l, string(r))
|
||||||
|
}
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
// cmpClaims is a hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229
|
||||||
|
func cmpClaims(actualClaims, expectedclaims auth.Claims) string {
|
||||||
|
dat1, _ := json.Marshal(actualClaims)
|
||||||
|
dat2, _ := json.Marshal(expectedclaims)
|
||||||
|
return cmp.Diff(string(dat1), string(dat2))
|
||||||
|
}
|
||||||
|
@ -35,6 +35,17 @@ type Token struct {
|
|||||||
AccountID string `json:"account_id"example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
|
AccountID string `json:"account_id"example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SwitchAccountRequest defines the information for the current user to switch between their accounts
|
||||||
|
type SwitchAccountRequest struct {
|
||||||
|
AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// VirtualLoginRequest defines the information virtual login to a user / account.
|
||||||
|
type VirtualLoginRequest struct {
|
||||||
|
UserID string `json:"user_id" validate:"required,uuid" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"`
|
||||||
|
AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
|
||||||
|
}
|
||||||
|
|
||||||
// AuthorizationHeader returns the header authorization value.
|
// AuthorizationHeader returns the header authorization value.
|
||||||
func (t Token) AuthorizationHeader() string {
|
func (t Token) AuthorizationHeader() string {
|
||||||
return "Bearer " + t.AccessToken
|
return "Bearer " + t.AccessToken
|
||||||
|
Reference in New Issue
Block a user