1
0
mirror of https://github.com/raseels-repos/golang-saas-starter-kit.git synced 2025-06-17 00:17:59 +02:00

moved auth from user package and added timezone to context values

This commit is contained in:
Lee Brown
2019-08-04 14:48:43 -08:00
parent fad0801379
commit bb9820ffcc
62 changed files with 3740 additions and 1008 deletions

View File

@ -2,15 +2,14 @@ package handlers
import ( import (
"context" "context"
"fmt"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror"
"net/http" "net/http"
"strconv" "strconv"
"geeks-accelerator/oss/saas-starter-kit/internal/account" "geeks-accelerator/oss/saas-starter-kit/internal/account"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/pkg/errors" "github.com/pkg/errors"
"gopkg.in/go-playground/validator.v9" "gopkg.in/go-playground/validator.v9"
@ -42,25 +41,25 @@ func (a *Account) Read(ctx context.Context, w http.ResponseWriter, r *http.Reque
return errors.New("claims missing from context") return errors.New("claims missing from context")
} }
// Handle included-archived query value if set. // Handle include-archived query value if set.
var includeArchived bool var includeArchived bool
if v := r.URL.Query().Get("included-archived"); v != "" { if v := r.URL.Query().Get("include-archived"); v != "" {
b, err := strconv.ParseBool(v) b, err := strconv.ParseBool(v)
if err != nil { if err != nil {
err = errors.WithMessagef(err, "unable to parse %s as boolean for included-archived param", v) err = errors.WithMessagef(err, "unable to parse %s as boolean for include-archived param", v)
return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusBadRequest)) return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusBadRequest))
} }
includeArchived = b includeArchived = b
} }
res, err := account.Read(ctx, claims, a.MasterDB, params["id"], includeArchived) res, err := account.Read(ctx, claims, a.MasterDB, account.AccountReadRequest{
ID: params["id"],
IncludeArchived: includeArchived,
})
if err != nil { if err != nil {
cause := errors.Cause(err) cause := errors.Cause(err)
switch cause { switch cause {
case account.ErrNotFound: case account.ErrNotFound:
fmt.Println("HERE!!!!! account.ErrNotFound")
return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusNotFound)) return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusNotFound))
default: default:
return errors.Wrapf(err, "ID: %s", params["id"]) return errors.Wrapf(err, "ID: %s", params["id"])

View File

@ -2,14 +2,14 @@ package handlers
import ( import (
"context" "context"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror"
"geeks-accelerator/oss/saas-starter-kit/internal/project" "geeks-accelerator/oss/saas-starter-kit/internal/project"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -35,7 +35,7 @@ type Project struct {
// @Param order query string false "Order columns separated by comma, example: created_at desc" // @Param order query string false "Order columns separated by comma, example: created_at desc"
// @Param limit query integer false "Limit, example: 10" // @Param limit query integer false "Limit, example: 10"
// @Param offset query integer false "Offset, example: 20" // @Param offset query integer false "Offset, example: 20"
// @Param included-archived query boolean false "Included Archived, example: false" // @Param include-archived query boolean false "Included Archived, example: false"
// @Success 200 {array} project.ProjectResponse // @Success 200 {array} project.ProjectResponse
// @Failure 400 {object} web.ErrorResponse // @Failure 400 {object} web.ErrorResponse
// @Failure 403 {object} web.ErrorResponse // @Failure 403 {object} web.ErrorResponse
@ -92,13 +92,13 @@ func (p *Project) Find(ctx context.Context, w http.ResponseWriter, r *http.Reque
} }
// Handle include-archive query value if set. // Handle include-archive query value if set.
if v := r.URL.Query().Get("included-archived"); v != "" { if v := r.URL.Query().Get("include-archived"); v != "" {
b, err := strconv.ParseBool(v) b, err := strconv.ParseBool(v)
if err != nil { if err != nil {
err = errors.WithMessagef(err, "unable to parse %s as boolean for included-archived param", v) err = errors.WithMessagef(err, "unable to parse %s as boolean for include-archived param", v)
return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusBadRequest)) return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusBadRequest))
} }
req.IncludedArchived = b req.IncludeArchived = b
} }
//if err := web.Decode(r, &req); err != nil { //if err := web.Decode(r, &req); err != nil {
@ -140,18 +140,21 @@ func (p *Project) Read(ctx context.Context, w http.ResponseWriter, r *http.Reque
return errors.New("claims missing from context") return errors.New("claims missing from context")
} }
// Handle included-archived query value if set. // Handle include-archived query value if set.
var includeArchived bool var includeArchived bool
if v := r.URL.Query().Get("included-archived"); v != "" { if v := r.URL.Query().Get("include-archived"); v != "" {
b, err := strconv.ParseBool(v) b, err := strconv.ParseBool(v)
if err != nil { if err != nil {
err = errors.WithMessagef(err, "unable to parse %s as boolean for included-archived param", v) err = errors.WithMessagef(err, "unable to parse %s as boolean for include-archived param", v)
return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusBadRequest)) return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusBadRequest))
} }
includeArchived = b includeArchived = b
} }
res, err := project.Read(ctx, claims, p.MasterDB, params["id"], includeArchived) res, err := project.Read(ctx, claims, p.MasterDB, project.ProjectReadRequest{
ID: params["id"],
IncludeArchived: includeArchived,
})
if err != nil { if err != nil {
cause := errors.Cause(err) cause := errors.Cause(err)
switch cause { switch cause {
@ -337,7 +340,8 @@ func (p *Project) Delete(ctx context.Context, w http.ResponseWriter, r *http.Req
return err return err
} }
err = project.Delete(ctx, claims, p.MasterDB, params["id"]) err = project.Delete(ctx, claims, p.MasterDB,
project.ProjectDeleteRequest{ID: params["id"]})
if err != nil { if err != nil {
cause := errors.Cause(err) cause := errors.Cause(err)
switch cause { switch cause {

View File

@ -1,7 +1,6 @@
package handlers package handlers
import ( import (
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"log" "log"
"net/http" "net/http"
"os" "os"
@ -10,6 +9,7 @@ import (
saasSwagger "geeks-accelerator/oss/saas-starter-kit/internal/mid/saas-swagger" saasSwagger "geeks-accelerator/oss/saas-starter-kit/internal/mid/saas-swagger"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
_ "geeks-accelerator/oss/saas-starter-kit/internal/signup" _ "geeks-accelerator/oss/saas-starter-kit/internal/signup"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"gopkg.in/DataDog/dd-trace-go.v1/contrib/go-redis/redis" "gopkg.in/DataDog/dd-trace-go.v1/contrib/go-redis/redis"
@ -20,7 +20,7 @@ func API(shutdown chan os.Signal, log *log.Logger, env webcontext.Env, masterDB
// Define base middlewares applied to all requests. // Define base middlewares applied to all requests.
middlewares := []web.Middleware{ middlewares := []web.Middleware{
mid.Trace(), mid.Logger(log), mid.Errors(log), mid.Metrics(), mid.Panics(), mid.Trace(), mid.Logger(log), mid.Errors(log, nil), mid.Metrics(), mid.Panics(),
} }
// Append any global middlewares if they were included. // Append any global middlewares if they were included.
@ -62,7 +62,7 @@ func API(shutdown chan os.Signal, log *log.Logger, env webcontext.Env, masterDB
} }
app.Handle("GET", "/v1/user_accounts", ua.Find, mid.AuthenticateHeader(authenticator)) app.Handle("GET", "/v1/user_accounts", ua.Find, mid.AuthenticateHeader(authenticator))
app.Handle("POST", "/v1/user_accounts", ua.Create, mid.AuthenticateHeader(authenticator), mid.HasRole(auth.RoleAdmin)) app.Handle("POST", "/v1/user_accounts", ua.Create, mid.AuthenticateHeader(authenticator), mid.HasRole(auth.RoleAdmin))
app.Handle("GET", "/v1/user_accounts/:id", ua.Read, mid.AuthenticateHeader(authenticator)) app.Handle("GET", "/v1/user_accounts/:user_id/:account_id", ua.Read, mid.AuthenticateHeader(authenticator))
app.Handle("PATCH", "/v1/user_accounts", ua.Update, mid.AuthenticateHeader(authenticator)) app.Handle("PATCH", "/v1/user_accounts", ua.Update, mid.AuthenticateHeader(authenticator))
app.Handle("PATCH", "/v1/user_accounts/archive", ua.Archive, mid.AuthenticateHeader(authenticator), mid.HasRole(auth.RoleAdmin)) app.Handle("PATCH", "/v1/user_accounts/archive", ua.Archive, mid.AuthenticateHeader(authenticator), mid.HasRole(auth.RoleAdmin))
app.Handle("DELETE", "/v1/user_accounts", ua.Delete, mid.AuthenticateHeader(authenticator), mid.HasRole(auth.RoleAdmin)) app.Handle("DELETE", "/v1/user_accounts", ua.Delete, mid.AuthenticateHeader(authenticator), mid.HasRole(auth.RoleAdmin))

View File

@ -2,13 +2,13 @@ package handlers
import ( import (
"context" "context"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror"
"net/http" "net/http"
"geeks-accelerator/oss/saas-starter-kit/internal/account" "geeks-accelerator/oss/saas-starter-kit/internal/account"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror"
"geeks-accelerator/oss/saas-starter-kit/internal/signup" "geeks-accelerator/oss/saas-starter-kit/internal/signup"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/pkg/errors" "github.com/pkg/errors"

View File

@ -2,8 +2,6 @@ package handlers
import ( import (
"context" "context"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
@ -11,7 +9,10 @@ import (
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror"
"geeks-accelerator/oss/saas-starter-kit/internal/user" "geeks-accelerator/oss/saas-starter-kit/internal/user"
"geeks-accelerator/oss/saas-starter-kit/internal/user_auth"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/pkg/errors" "github.com/pkg/errors"
"gopkg.in/go-playground/validator.v9" "gopkg.in/go-playground/validator.v9"
@ -23,7 +24,7 @@ var sessionTtl = time.Hour * 24
// User represents the User API method handler set. // User represents the User API method handler set.
type User struct { type User struct {
MasterDB *sqlx.DB MasterDB *sqlx.DB
TokenGenerator user.TokenGenerator TokenGenerator user_auth.TokenGenerator
// ADD OTHER STATE LIKE THE LOGGER AND CONFIG HERE. // ADD OTHER STATE LIKE THE LOGGER AND CONFIG HERE.
} }
@ -40,7 +41,7 @@ type User struct {
// @Param order query string false "Order columns separated by comma, example: created_at desc" // @Param order query string false "Order columns separated by comma, example: created_at desc"
// @Param limit query integer false "Limit, example: 10" // @Param limit query integer false "Limit, example: 10"
// @Param offset query integer false "Offset, example: 20" // @Param offset query integer false "Offset, example: 20"
// @Param included-archived query boolean false "Included Archived, example: false" // @Param include-archived query boolean false "Included Archived, example: false"
// @Success 200 {array} user.UserResponse // @Success 200 {array} user.UserResponse
// @Failure 400 {object} web.ErrorResponse // @Failure 400 {object} web.ErrorResponse
// @Failure 500 {object} web.ErrorResponse // @Failure 500 {object} web.ErrorResponse
@ -95,14 +96,14 @@ func (u *User) Find(ctx context.Context, w http.ResponseWriter, r *http.Request,
req.Limit = &ul req.Limit = &ul
} }
// Handle included-archived query value if set. // Handle include-archived query value if set.
if v := r.URL.Query().Get("included-archived"); v != "" { if v := r.URL.Query().Get("include-archived"); v != "" {
b, err := strconv.ParseBool(v) b, err := strconv.ParseBool(v)
if err != nil { if err != nil {
err = errors.WithMessagef(err, "unable to parse %s as boolean for included-archived param", v) err = errors.WithMessagef(err, "unable to parse %s as boolean for include-archived param", v)
return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusBadRequest)) return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusBadRequest))
} }
req.IncludedArchived = b req.IncludeArchived = b
} }
//if err := web.Decode(r, &req); err != nil { //if err := web.Decode(r, &req); err != nil {
@ -144,18 +145,21 @@ func (u *User) Read(ctx context.Context, w http.ResponseWriter, r *http.Request,
return errors.New("claims missing from context") return errors.New("claims missing from context")
} }
// Handle included-archived query value if set. // Handle include-archived query value if set.
var includeArchived bool var includeArchived bool
if v := r.URL.Query().Get("included-archived"); v != "" { if v := r.URL.Query().Get("include-archived"); v != "" {
b, err := strconv.ParseBool(v) b, err := strconv.ParseBool(v)
if err != nil { if err != nil {
err = errors.WithMessagef(err, "unable to parse %s as boolean for included-archived param", v) err = errors.WithMessagef(err, "unable to parse %s as boolean for include-archived param", v)
return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusBadRequest)) return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusBadRequest))
} }
includeArchived = b includeArchived = b
} }
res, err := user.Read(ctx, claims, u.MasterDB, params["id"], includeArchived) res, err := user.Read(ctx, claims, u.MasterDB, user.UserReadRequest{
ID: params["id"],
IncludeArchived: includeArchived,
})
if err != nil { if err != nil {
cause := errors.Cause(err) cause := errors.Cause(err)
switch cause { switch cause {
@ -394,7 +398,8 @@ func (u *User) Delete(ctx context.Context, w http.ResponseWriter, r *http.Reques
return err return err
} }
err = user.Delete(ctx, claims, u.MasterDB, params["id"]) err = user.Delete(ctx, claims, u.MasterDB,
user.UserDeleteRequest{ID: params["id"]})
if err != nil { if err != nil {
cause := errors.Cause(err) cause := errors.Cause(err)
switch cause { switch cause {
@ -437,11 +442,11 @@ func (u *User) SwitchAccount(ctx context.Context, w http.ResponseWriter, r *http
return err return err
} }
tkn, err := user.SwitchAccount(ctx, u.MasterDB, u.TokenGenerator, claims, params["account_id"], sessionTtl, v.Now) tkn, err := user_auth.SwitchAccount(ctx, u.MasterDB, u.TokenGenerator, claims, params["account_id"], sessionTtl, v.Now)
if err != nil { if err != nil {
cause := errors.Cause(err) cause := errors.Cause(err)
switch cause { switch cause {
case user.ErrAuthenticationFailure: case user_auth.ErrAuthenticationFailure:
return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusUnauthorized)) return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusUnauthorized))
default: default:
_, ok := cause.(validator.ValidationErrors) _, ok := cause.(validator.ValidationErrors)
@ -484,11 +489,11 @@ func (u *User) Token(ctx context.Context, w http.ResponseWriter, r *http.Request
// Optional to include scope. // Optional to include scope.
scope := r.URL.Query().Get("scope") scope := r.URL.Query().Get("scope")
tkn, err := user.Authenticate(ctx, u.MasterDB, u.TokenGenerator, email, pass, sessionTtl, v.Now, scope) tkn, err := user_auth.Authenticate(ctx, u.MasterDB, u.TokenGenerator, email, pass, sessionTtl, v.Now, scope)
if err != nil { if err != nil {
cause := errors.Cause(err) cause := errors.Cause(err)
switch cause { switch cause {
case user.ErrAuthenticationFailure: case user_auth.ErrAuthenticationFailure:
return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusUnauthorized)) return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusUnauthorized))
default: default:
_, ok := cause.(validator.ValidationErrors) _, ok := cause.(validator.ValidationErrors)
@ -500,5 +505,30 @@ func (u *User) Token(ctx context.Context, w http.ResponseWriter, r *http.Request
} }
} }
accountID := r.URL.Query().Get("account_id")
if accountID != "" && accountID != tkn.AccountID {
claims, err := u.TokenGenerator.ParseClaims(tkn.AccessToken)
if err != nil {
return err
}
tkn, err = user_auth.SwitchAccount(ctx, u.MasterDB, u.TokenGenerator, claims, accountID, sessionTtl, v.Now)
if err != nil {
cause := errors.Cause(err)
switch cause {
case user_auth.ErrAuthenticationFailure:
return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusUnauthorized))
default:
_, ok := cause.(validator.ValidationErrors)
if ok {
return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusBadRequest))
}
return errors.Wrap(err, "switch account")
}
}
}
return web.RespondJson(ctx, w, tkn, http.StatusOK) return web.RespondJson(ctx, w, tkn, http.StatusOK)
} }

View File

@ -2,6 +2,10 @@ package handlers
import ( import (
"context" "context"
"net/http"
"strconv"
"strings"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
@ -10,9 +14,6 @@ import (
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/pkg/errors" "github.com/pkg/errors"
"gopkg.in/go-playground/validator.v9" "gopkg.in/go-playground/validator.v9"
"net/http"
"strconv"
"strings"
) )
// UserAccount represents the UserAccount API method handler set. // UserAccount represents the UserAccount API method handler set.
@ -34,7 +35,7 @@ type UserAccount struct {
// @Param order query string false "Order columns separated by comma, example: created_at desc" // @Param order query string false "Order columns separated by comma, example: created_at desc"
// @Param limit query integer false "Limit, example: 10" // @Param limit query integer false "Limit, example: 10"
// @Param offset query integer false "Offset, example: 20" // @Param offset query integer false "Offset, example: 20"
// @Param included-archived query boolean false "Included Archived, example: false" // @Param include-archived query boolean false "Included Archived, example: false"
// @Success 200 {array} user_account.UserAccountResponse // @Success 200 {array} user_account.UserAccountResponse
// @Failure 400 {object} web.ErrorResponse // @Failure 400 {object} web.ErrorResponse
// @Failure 403 {object} web.ErrorResponse // @Failure 403 {object} web.ErrorResponse
@ -91,13 +92,13 @@ func (u *UserAccount) Find(ctx context.Context, w http.ResponseWriter, r *http.R
} }
// Handle order query value if set. // Handle order query value if set.
if v := r.URL.Query().Get("included-archived"); v != "" { if v := r.URL.Query().Get("include-archived"); v != "" {
b, err := strconv.ParseBool(v) b, err := strconv.ParseBool(v)
if err != nil { if err != nil {
err = errors.WithMessagef(err, "unable to parse %s as boolean for included-archived param", v) err = errors.WithMessagef(err, "unable to parse %s as boolean for include-archived param", v)
return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusBadRequest)) return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusBadRequest))
} }
req.IncludedArchived = b req.IncludeArchived = b
} }
//if err := web.Decode(r, &req); err != nil { //if err := web.Decode(r, &req); err != nil {
@ -132,25 +133,29 @@ func (u *UserAccount) Find(ctx context.Context, w http.ResponseWriter, r *http.R
// @Failure 400 {object} web.ErrorResponse // @Failure 400 {object} web.ErrorResponse
// @Failure 404 {object} web.ErrorResponse // @Failure 404 {object} web.ErrorResponse
// @Failure 500 {object} web.ErrorResponse // @Failure 500 {object} web.ErrorResponse
// @Router /user_accounts/{id} [get] // @Router /user_accounts/{user_id}/{account_id} [get]
func (u *UserAccount) Read(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { func (u *UserAccount) Read(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
claims, ok := ctx.Value(auth.Key).(auth.Claims) claims, ok := ctx.Value(auth.Key).(auth.Claims)
if !ok { if !ok {
return errors.New("claims missing from context") return errors.New("claims missing from context")
} }
// Handle included-archived query value if set. // Handle include-archived query value if set.
var includeArchived bool var includeArchived bool
if v := r.URL.Query().Get("included-archived"); v != "" { if v := r.URL.Query().Get("include-archived"); v != "" {
b, err := strconv.ParseBool(v) b, err := strconv.ParseBool(v)
if err != nil { if err != nil {
err = errors.WithMessagef(err, "unable to parse %s as boolean for included-archived param", v) err = errors.WithMessagef(err, "unable to parse %s as boolean for include-archived param", v)
return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusBadRequest)) return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusBadRequest))
} }
includeArchived = b includeArchived = b
} }
res, err := user_account.Read(ctx, claims, u.MasterDB, params["id"], includeArchived) res, err := user_account.Read(ctx, claims, u.MasterDB, user_account.UserAccountReadRequest{
UserID: params["user_id"],
AccountID: params["account_id"],
IncludeArchived: includeArchived,
})
if err != nil { if err != nil {
cause := errors.Cause(err) cause := errors.Cause(err)
switch cause { switch cause {

View File

@ -4,7 +4,6 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror"
"net/http" "net/http"
"testing" "testing"
@ -13,6 +12,7 @@ import (
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/tests" "geeks-accelerator/oss/saas-starter-kit/internal/platform/tests"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror"
"github.com/pborman/uuid" "github.com/pborman/uuid"
) )
@ -152,7 +152,10 @@ func TestAccountCRUDAdmin(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: fmt.Sprintf("account %s not found: Entity not found", randID), StatusCode: expectedStatus,
Error: http.StatusText(expectedStatus),
Details: fmt.Sprintf("account %s not found: Entity not found", randID),
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {
@ -190,7 +193,10 @@ func TestAccountCRUDAdmin(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: fmt.Sprintf("account %s not found: Entity not found", tr.ForbiddenAccount.ID), StatusCode: expectedStatus,
Error: http.StatusText(expectedStatus),
Details: fmt.Sprintf("account %s not found: Entity not found", tr.ForbiddenAccount.ID),
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {
@ -366,7 +372,10 @@ func TestAccountCRUDUser(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: fmt.Sprintf("account %s not found: Entity not found", randID), StatusCode: expectedStatus,
Error: http.StatusText(expectedStatus),
Details: fmt.Sprintf("account %s not found: Entity not found", randID),
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {
@ -404,7 +413,10 @@ func TestAccountCRUDUser(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: fmt.Sprintf("account %s not found: Entity not found", tr.ForbiddenAccount.ID), StatusCode: expectedStatus,
Error: http.StatusText(expectedStatus),
Details: fmt.Sprintf("account %s not found: Entity not found", tr.ForbiddenAccount.ID),
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {
@ -445,7 +457,8 @@ func TestAccountCRUDUser(t *testing.T) {
t.Fatalf("\t%s\tDecode response body failed.", tests.Failed) t.Fatalf("\t%s\tDecode response body failed.", tests.Failed)
} }
expected := mid.ErrorForbidden(ctx).(*weberror.Error).Display(ctx) expected := mid.ErrorForbidden(ctx).(*weberror.Error).Response(ctx, false)
expected.StackTrace = actual.StackTrace
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {
t.Fatalf("\t%s\tReceived expected error.", tests.Failed) t.Fatalf("\t%s\tReceived expected error.", tests.Failed)
} }
@ -495,7 +508,8 @@ func TestAccountUpdate(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: "Field validation error", StatusCode: http.StatusBadRequest,
Error: "Field validation error",
Fields: []weberror.FieldError{ Fields: []weberror.FieldError{
//{Field: "status", Error: "Key: 'AccountUpdateRequest.status' Error:Field validation for 'status' failed on the 'oneof' tag"}, //{Field: "status", Error: "Key: 'AccountUpdateRequest.status' Error:Field validation for 'status' failed on the 'oneof' tag"},
{ {
@ -506,6 +520,8 @@ func TestAccountUpdate(t *testing.T) {
Display: "status must be one of [active pending disabled]", Display: "status must be one of [active pending disabled]",
}, },
}, },
Details: actual.Details,
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {

View File

@ -164,7 +164,10 @@ func TestProjectCRUDAdmin(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: fmt.Sprintf("project %s not found: Entity not found", randID), StatusCode: expectedStatus,
Error: http.StatusText(expectedStatus),
Details: fmt.Sprintf("project %s not found: Entity not found", randID),
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {
@ -203,7 +206,10 @@ func TestProjectCRUDAdmin(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: fmt.Sprintf("project %s not found: Entity not found", forbiddenProject.ID), StatusCode: expectedStatus,
Error: http.StatusText(expectedStatus),
Details: fmt.Sprintf("project %s not found: Entity not found", forbiddenProject.ID),
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {
@ -347,7 +353,8 @@ func TestProjectCRUDUser(t *testing.T) {
t.Fatalf("\t%s\tDecode response body failed.", tests.Failed) t.Fatalf("\t%s\tDecode response body failed.", tests.Failed)
} }
expected := mid.ErrorForbidden(ctx).(*weberror.Error).Display(ctx) expected := mid.ErrorForbidden(ctx).(*weberror.Error).Response(ctx, false)
expected.StackTrace = actual.StackTrace
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {
t.Fatalf("\t%s\tReceived expected error.", tests.Failed) t.Fatalf("\t%s\tReceived expected error.", tests.Failed)
@ -422,7 +429,10 @@ func TestProjectCRUDUser(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: fmt.Sprintf("project %s not found: Entity not found", randID), StatusCode: expectedStatus,
Error: http.StatusText(expectedStatus),
Details: fmt.Sprintf("project %s not found: Entity not found", randID),
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {
@ -461,7 +471,10 @@ func TestProjectCRUDUser(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: fmt.Sprintf("project %s not found: Entity not found", forbiddenProject.ID), StatusCode: expectedStatus,
Error: http.StatusText(expectedStatus),
Details: fmt.Sprintf("project %s not found: Entity not found", forbiddenProject.ID),
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {
@ -502,7 +515,8 @@ func TestProjectCRUDUser(t *testing.T) {
t.Fatalf("\t%s\tDecode response body failed.", tests.Failed) t.Fatalf("\t%s\tDecode response body failed.", tests.Failed)
} }
expected := mid.ErrorForbidden(ctx).(*weberror.Error).Display(ctx) expected := mid.ErrorForbidden(ctx).(*weberror.Error).Response(ctx, false)
expected.StackTrace = actual.StackTrace
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {
t.Fatalf("\t%s\tReceived expected error.", tests.Failed) t.Fatalf("\t%s\tReceived expected error.", tests.Failed)
@ -540,7 +554,8 @@ func TestProjectCRUDUser(t *testing.T) {
t.Fatalf("\t%s\tDecode response body failed.", tests.Failed) t.Fatalf("\t%s\tDecode response body failed.", tests.Failed)
} }
expected := mid.ErrorForbidden(ctx).(*weberror.Error).Display(ctx) expected := mid.ErrorForbidden(ctx).(*weberror.Error).Response(ctx, false)
expected.StackTrace = actual.StackTrace
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {
t.Fatalf("\t%s\tReceived expected error.", tests.Failed) t.Fatalf("\t%s\tReceived expected error.", tests.Failed)
@ -576,7 +591,8 @@ func TestProjectCRUDUser(t *testing.T) {
t.Fatalf("\t%s\tDecode response body failed.", tests.Failed) t.Fatalf("\t%s\tDecode response body failed.", tests.Failed)
} }
expected := mid.ErrorForbidden(ctx).(*weberror.Error).Display(ctx) expected := mid.ErrorForbidden(ctx).(*weberror.Error).Response(ctx, false)
expected.StackTrace = actual.StackTrace
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {
t.Fatalf("\t%s\tReceived expected error.", tests.Failed) t.Fatalf("\t%s\tReceived expected error.", tests.Failed)
@ -626,7 +642,9 @@ func TestProjectCreate(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: "Field validation error", StatusCode: expectedStatus,
Details: actual.Details,
Error: "Field validation error",
Fields: []weberror.FieldError{ Fields: []weberror.FieldError{
//{Field: "status", Error: "Key: 'ProjectCreateRequest.status' Error:Field validation for 'status' failed on the 'oneof' tag"}, //{Field: "status", Error: "Key: 'ProjectCreateRequest.status' Error:Field validation for 'status' failed on the 'oneof' tag"},
{ {
@ -637,6 +655,7 @@ func TestProjectCreate(t *testing.T) {
Display: "status must be one of [active disabled]", Display: "status must be one of [active disabled]",
}, },
}, },
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {
@ -688,7 +707,9 @@ func TestProjectUpdate(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: "Field validation error", StatusCode: expectedStatus,
Details: actual.Details,
Error: "Field validation error",
Fields: []weberror.FieldError{ Fields: []weberror.FieldError{
//{Field: "status", Error: "Key: 'ProjectUpdateRequest.status' Error:Field validation for 'status' failed on the 'oneof' tag"}, //{Field: "status", Error: "Key: 'ProjectUpdateRequest.status' Error:Field validation for 'status' failed on the 'oneof' tag"},
{ {
@ -699,6 +720,7 @@ func TestProjectUpdate(t *testing.T) {
Display: "status must be one of [active disabled]", Display: "status must be one of [active disabled]",
}, },
}, },
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {
@ -752,7 +774,9 @@ func TestProjectArchive(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: "Field validation error", StatusCode: expectedStatus,
Details: actual.Details,
Error: "Field validation error",
Fields: []weberror.FieldError{ Fields: []weberror.FieldError{
//{Field: "id", Error: "Key: 'ProjectArchiveRequest.id' Error:Field validation for 'id' failed on the 'uuid' tag"}, //{Field: "id", Error: "Key: 'ProjectArchiveRequest.id' Error:Field validation for 'id' failed on the 'uuid' tag"},
{ {
@ -763,6 +787,7 @@ func TestProjectArchive(t *testing.T) {
Display: "id must be a valid UUID", Display: "id must be a valid UUID",
}, },
}, },
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {
@ -802,7 +827,10 @@ func TestProjectArchive(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: project.ErrForbidden.Error(), StatusCode: expectedStatus,
Error: http.StatusText(expectedStatus),
Details: project.ErrForbidden.Error(),
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {
@ -854,7 +882,9 @@ func TestProjectDelete(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: "Field validation error", StatusCode: expectedStatus,
Details: actual.Details,
Error: "Field validation error",
Fields: []weberror.FieldError{ Fields: []weberror.FieldError{
//{Field: "id", Error: "Key: 'id' Error:Field validation for 'id' failed on the 'uuid' tag"}, //{Field: "id", Error: "Key: 'id' Error:Field validation for 'id' failed on the 'uuid' tag"},
{ {
@ -865,6 +895,7 @@ func TestProjectDelete(t *testing.T) {
Display: "id must be a valid UUID", Display: "id must be a valid UUID",
}, },
}, },
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {
@ -902,7 +933,10 @@ func TestProjectDelete(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: project.ErrForbidden.Error(), StatusCode: expectedStatus,
Error: http.StatusText(expectedStatus),
Details: project.ErrForbidden.Error(),
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {

View File

@ -14,14 +14,14 @@ import (
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror"
"geeks-accelerator/oss/saas-starter-kit/internal/signup" "geeks-accelerator/oss/saas-starter-kit/internal/signup"
"geeks-accelerator/oss/saas-starter-kit/internal/user" "geeks-accelerator/oss/saas-starter-kit/internal/user_auth"
"github.com/pborman/uuid" "github.com/pborman/uuid"
) )
type mockSignup struct { type mockSignup struct {
account *account.Account account *account.Account
user mockUser user mockUser
token user.Token token user_auth.Token
claims auth.Claims claims auth.Claims
context context.Context context context.Context
} }
@ -56,7 +56,7 @@ func newMockSignup() mockSignup {
} }
expires := time.Now().UTC().Sub(s.User.CreatedAt) + time.Hour expires := time.Now().UTC().Sub(s.User.CreatedAt) + time.Hour
tkn, err := user.Authenticate(tests.Context(), test.MasterDB, authenticator, req.User.Email, req.User.Password, expires, now) tkn, err := user_auth.Authenticate(tests.Context(), test.MasterDB, authenticator, req.User.Email, req.User.Password, expires, now)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -94,7 +94,7 @@ func TestSignup(t *testing.T) {
http.MethodPost, http.MethodPost,
"/v1/signup", "/v1/signup",
req, req,
user.Token{}, user_auth.Token{},
auth.Claims{}, auth.Claims{},
expectedStatus, expectedStatus,
nil, nil,
@ -116,12 +116,14 @@ func TestSignup(t *testing.T) {
expectedMap := map[string]interface{}{ expectedMap := map[string]interface{}{
"user": map[string]interface{}{ "user": map[string]interface{}{
"id": actual.User.ID, "id": actual.User.ID,
"name": req.User.FirstName + " " + req.User.LastName,
"first_name": req.User.FirstName, "first_name": req.User.FirstName,
"last_name": req.User.LastName, "last_name": req.User.LastName,
"email": req.User.Email, "email": req.User.Email,
"timezone": actual.User.Timezone, "timezone": actual.User.Timezone,
"created_at": web.NewTimeResponse(ctx, actual.User.CreatedAt.Value), "created_at": web.NewTimeResponse(ctx, actual.User.CreatedAt.Value),
"updated_at": web.NewTimeResponse(ctx, actual.User.UpdatedAt.Value), "updated_at": web.NewTimeResponse(ctx, actual.User.UpdatedAt.Value),
"gravatar": web.NewGravatarResponse(ctx, actual.User.Email),
}, },
"account": map[string]interface{}{ "account": map[string]interface{}{
"updated_at": web.NewTimeResponse(ctx, actual.Account.UpdatedAt.Value), "updated_at": web.NewTimeResponse(ctx, actual.Account.UpdatedAt.Value),
@ -170,7 +172,7 @@ func TestSignup(t *testing.T) {
http.MethodPost, http.MethodPost,
"/v1/signup", "/v1/signup",
nil, nil,
user.Token{}, user_auth.Token{},
auth.Claims{}, auth.Claims{},
expectedStatus, expectedStatus,
nil, nil,
@ -190,7 +192,10 @@ func TestSignup(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: "decode request body failed", StatusCode: expectedStatus,
Error: "decode request body failed",
Details: "EOF",
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {
@ -211,7 +216,7 @@ func TestSignup(t *testing.T) {
http.MethodPost, http.MethodPost,
"/v1/signup", "/v1/signup",
req, req,
user.Token{}, user_auth.Token{},
auth.Claims{}, auth.Claims{},
expectedStatus, expectedStatus,
nil, nil,
@ -231,7 +236,9 @@ func TestSignup(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: "Field validation error", StatusCode: expectedStatus,
Details: actual.Details,
Error: "Field validation error",
Fields: []weberror.FieldError{ Fields: []weberror.FieldError{
//{Field: "name", Error: "Key: 'SignupRequest.account.name' Error:Field validation for 'name' failed on the 'required' tag"}, //{Field: "name", Error: "Key: 'SignupRequest.account.name' Error:Field validation for 'name' failed on the 'required' tag"},
//{Field: "email", Error: "Key: 'SignupRequest.user.email' Error:Field validation for 'email' failed on the 'required' tag"}, //{Field: "email", Error: "Key: 'SignupRequest.user.email' Error:Field validation for 'email' failed on the 'required' tag"},
@ -251,6 +258,7 @@ func TestSignup(t *testing.T) {
Display: "email is a required field", Display: "email is a required field",
}, },
}, },
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {

View File

@ -5,7 +5,6 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
@ -20,10 +19,12 @@ import (
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/tests" "geeks-accelerator/oss/saas-starter-kit/internal/platform/tests"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror"
"geeks-accelerator/oss/saas-starter-kit/internal/signup" "geeks-accelerator/oss/saas-starter-kit/internal/signup"
"geeks-accelerator/oss/saas-starter-kit/internal/user" "geeks-accelerator/oss/saas-starter-kit/internal/user"
"geeks-accelerator/oss/saas-starter-kit/internal/user_account" "geeks-accelerator/oss/saas-starter-kit/internal/user_account"
"geeks-accelerator/oss/saas-starter-kit/internal/user_auth"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/iancoleman/strcase" "github.com/iancoleman/strcase"
"github.com/pborman/uuid" "github.com/pborman/uuid"
@ -37,7 +38,7 @@ var authenticator *auth.Authenticator
// Information about the users we have created for testing. // Information about the users we have created for testing.
type roleTest struct { type roleTest struct {
Role string Role string
Token user.Token Token user_auth.Token
Claims auth.Claims Claims auth.Claims
User mockUser User mockUser
Account *account.Account Account *account.Account
@ -50,7 +51,7 @@ type requestTest struct {
method string method string
url string url string
request interface{} request interface{}
token user.Token token user_auth.Token
claims auth.Claims claims auth.Claims
statusCode int statusCode int
error interface{} error interface{}
@ -94,7 +95,7 @@ func testMain(m *testing.M) int {
} }
expires := time.Now().UTC().Sub(signup1.User.CreatedAt) + time.Hour expires := time.Now().UTC().Sub(signup1.User.CreatedAt) + time.Hour
adminTkn, err := user.Authenticate(tests.Context(), test.MasterDB, authenticator, signupReq1.User.Email, signupReq1.User.Password, expires, now) adminTkn, err := user_auth.Authenticate(tests.Context(), test.MasterDB, authenticator, signupReq1.User.Email, signupReq1.User.Password, expires, now)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -145,7 +146,7 @@ func testMain(m *testing.M) int {
panic(err) panic(err)
} }
userTkn, err := user.Authenticate(tests.Context(), test.MasterDB, authenticator, usr.Email, userReq.Password, expires, now) userTkn, err := user_auth.Authenticate(tests.Context(), test.MasterDB, authenticator, usr.Email, userReq.Password, expires, now)
if err != nil { if err != nil {
panic(err) panic(err)
} }

View File

@ -91,7 +91,7 @@ func TestUserAccountCRUDAdmin(t *testing.T) {
expectedMap := map[string]interface{}{ expectedMap := map[string]interface{}{
"updated_at": web.NewTimeResponse(ctx, actual.UpdatedAt.Value), "updated_at": web.NewTimeResponse(ctx, actual.UpdatedAt.Value),
"id": actual.ID, //"id": actual.ID,
"account_id": req.AccountID, "account_id": req.AccountID,
"user_id": req.UserID, "user_id": req.UserID,
"status": web.NewEnumResponse(ctx, "active", user_account.UserAccountStatus_Values), "status": web.NewEnumResponse(ctx, "active", user_account.UserAccountStatus_Values),
@ -122,7 +122,7 @@ func TestUserAccountCRUDAdmin(t *testing.T) {
rt := requestTest{ rt := requestTest{
fmt.Sprintf("Read %d w/role %s", expectedStatus, tr.Role), fmt.Sprintf("Read %d w/role %s", expectedStatus, tr.Role),
http.MethodGet, http.MethodGet,
fmt.Sprintf("/v1/user_accounts/%s", created.ID), fmt.Sprintf("/v1/user_accounts/%s/%s", created.UserID, created.AccountID),
nil, nil,
tr.Token, tr.Token,
tr.Claims, tr.Claims,
@ -157,7 +157,7 @@ func TestUserAccountCRUDAdmin(t *testing.T) {
rt := requestTest{ rt := requestTest{
fmt.Sprintf("Read %d w/role %s using random ID", expectedStatus, tr.Role), fmt.Sprintf("Read %d w/role %s using random ID", expectedStatus, tr.Role),
http.MethodGet, http.MethodGet,
fmt.Sprintf("/v1/user_accounts/%s", randID), fmt.Sprintf("/v1/user_accounts/%s/%s", randID, randID),
nil, nil,
tr.Token, tr.Token,
tr.Claims, tr.Claims,
@ -179,7 +179,10 @@ func TestUserAccountCRUDAdmin(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: fmt.Sprintf("user account %s not found: Entity not found", randID), StatusCode: expectedStatus,
Error: http.StatusText(expectedStatus),
Details: fmt.Sprintf("entry for user %s account %s not found: Entity not found", randID, randID),
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {
@ -196,7 +199,7 @@ func TestUserAccountCRUDAdmin(t *testing.T) {
rt := requestTest{ rt := requestTest{
fmt.Sprintf("Read %d w/role %s using forbidden ID", expectedStatus, tr.Role), fmt.Sprintf("Read %d w/role %s using forbidden ID", expectedStatus, tr.Role),
http.MethodGet, http.MethodGet,
fmt.Sprintf("/v1/user_accounts/%s", forbiddenUserAccount.ID), fmt.Sprintf("/v1/user_accounts/%s/%s", forbiddenUserAccount.UserID, forbiddenUserAccount.AccountID),
nil, nil,
tr.Token, tr.Token,
tr.Claims, tr.Claims,
@ -218,7 +221,10 @@ func TestUserAccountCRUDAdmin(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: fmt.Sprintf("user account %s not found: Entity not found", forbiddenUserAccount.ID), StatusCode: expectedStatus,
Error: http.StatusText(expectedStatus),
Details: fmt.Sprintf("entry for user %s account %s not found: Entity not found", forbiddenUserAccount.UserID, forbiddenUserAccount.AccountID),
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {
@ -370,7 +376,8 @@ func TestUserAccountCRUDUser(t *testing.T) {
t.Fatalf("\t%s\tDecode response body failed.", tests.Failed) t.Fatalf("\t%s\tDecode response body failed.", tests.Failed)
} }
expected := mid.ErrorForbidden(ctx).(*weberror.Error).Display(ctx) expected := mid.ErrorForbidden(ctx).(*weberror.Error).Response(ctx, false)
expected.StackTrace = actual.StackTrace
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {
t.Fatalf("\t%s\tReceived expected error.", tests.Failed) t.Fatalf("\t%s\tReceived expected error.", tests.Failed)
@ -388,7 +395,7 @@ func TestUserAccountCRUDUser(t *testing.T) {
rt := requestTest{ rt := requestTest{
fmt.Sprintf("Read %d w/role %s", expectedStatus, tr.Role), fmt.Sprintf("Read %d w/role %s", expectedStatus, tr.Role),
http.MethodGet, http.MethodGet,
fmt.Sprintf("/v1/user_accounts/%s", created.ID), fmt.Sprintf("/v1/user_accounts/%s/%s", created.UserID, created.AccountID),
nil, nil,
tr.Token, tr.Token,
tr.Claims, tr.Claims,
@ -423,7 +430,7 @@ func TestUserAccountCRUDUser(t *testing.T) {
rt := requestTest{ rt := requestTest{
fmt.Sprintf("Read %d w/role %s using random ID", expectedStatus, tr.Role), fmt.Sprintf("Read %d w/role %s using random ID", expectedStatus, tr.Role),
http.MethodGet, http.MethodGet,
fmt.Sprintf("/v1/user_accounts/%s", randID), fmt.Sprintf("/v1/user_accounts/%s/%s", randID, randID),
nil, nil,
tr.Token, tr.Token,
tr.Claims, tr.Claims,
@ -445,7 +452,10 @@ func TestUserAccountCRUDUser(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: fmt.Sprintf("user account %s not found: Entity not found", randID), StatusCode: expectedStatus,
Error: http.StatusText(expectedStatus),
Details: fmt.Sprintf("entry for user %s account %s not found: Entity not found", randID, randID),
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {
@ -462,7 +472,7 @@ func TestUserAccountCRUDUser(t *testing.T) {
rt := requestTest{ rt := requestTest{
fmt.Sprintf("Read %d w/role %s using forbidden ID", expectedStatus, tr.Role), fmt.Sprintf("Read %d w/role %s using forbidden ID", expectedStatus, tr.Role),
http.MethodGet, http.MethodGet,
fmt.Sprintf("/v1/user_accounts/%s", forbiddenUserAccount.ID), fmt.Sprintf("/v1/user_accounts/%s/%s", forbiddenUserAccount.UserID, forbiddenUserAccount.AccountID),
nil, nil,
tr.Token, tr.Token,
tr.Claims, tr.Claims,
@ -484,7 +494,10 @@ func TestUserAccountCRUDUser(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: fmt.Sprintf("user account %s not found: Entity not found", forbiddenUserAccount.ID), StatusCode: expectedStatus,
Error: http.StatusText(expectedStatus),
Details: fmt.Sprintf("entry for user %s account %s not found: Entity not found", forbiddenUserAccount.UserID, forbiddenUserAccount.AccountID),
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {
@ -527,7 +540,10 @@ func TestUserAccountCRUDUser(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: account.ErrForbidden.Error(), StatusCode: expectedStatus,
Error: http.StatusText(expectedStatus),
Details: account.ErrForbidden.Error(),
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {
@ -567,7 +583,8 @@ func TestUserAccountCRUDUser(t *testing.T) {
t.Fatalf("\t%s\tDecode response body failed.", tests.Failed) t.Fatalf("\t%s\tDecode response body failed.", tests.Failed)
} }
expected := mid.ErrorForbidden(ctx).(*weberror.Error).Display(ctx) expected := mid.ErrorForbidden(ctx).(*weberror.Error).Response(ctx, false)
expected.StackTrace = actual.StackTrace
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {
t.Fatalf("\t%s\tReceived expected error.", tests.Failed) t.Fatalf("\t%s\tReceived expected error.", tests.Failed)
@ -606,7 +623,8 @@ func TestUserAccountCRUDUser(t *testing.T) {
t.Fatalf("\t%s\tDecode response body failed.", tests.Failed) t.Fatalf("\t%s\tDecode response body failed.", tests.Failed)
} }
expected := mid.ErrorForbidden(ctx).(*weberror.Error).Display(ctx) expected := mid.ErrorForbidden(ctx).(*weberror.Error).Response(ctx, false)
expected.StackTrace = actual.StackTrace
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {
t.Fatalf("\t%s\tReceived expected error.", tests.Failed) t.Fatalf("\t%s\tReceived expected error.", tests.Failed)
@ -659,7 +677,8 @@ func TestUserAccountCreate(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: "Field validation error", StatusCode: expectedStatus,
Error: "Field validation error",
Fields: []weberror.FieldError{ Fields: []weberror.FieldError{
//{Field: "status", Error: "Key: 'UserAccountCreateRequest.status' Error:Field validation for 'status' failed on the 'oneof' tag"}, //{Field: "status", Error: "Key: 'UserAccountCreateRequest.status' Error:Field validation for 'status' failed on the 'oneof' tag"},
{ {
@ -670,6 +689,8 @@ func TestUserAccountCreate(t *testing.T) {
Display: "status must be one of [active invited disabled]", Display: "status must be one of [active invited disabled]",
}, },
}, },
Details: actual.Details,
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {
@ -722,7 +743,8 @@ func TestUserAccountUpdate(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: "Field validation error", StatusCode: expectedStatus,
Error: "Field validation error",
Fields: []weberror.FieldError{ Fields: []weberror.FieldError{
//{Field: "status", Error: "Key: 'UserAccountUpdateRequest.status' Error:Field validation for 'status' failed on the 'oneof' tag"}, //{Field: "status", Error: "Key: 'UserAccountUpdateRequest.status' Error:Field validation for 'status' failed on the 'oneof' tag"},
{ {
@ -733,6 +755,8 @@ func TestUserAccountUpdate(t *testing.T) {
Display: "status must be one of [active invited disabled]", Display: "status must be one of [active invited disabled]",
}, },
}, },
Details: actual.Details,
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {
@ -783,7 +807,8 @@ func TestUserAccountArchive(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: "Field validation error", StatusCode: expectedStatus,
Error: "Field validation error",
Fields: []weberror.FieldError{ Fields: []weberror.FieldError{
//{Field: "user_id", Error: "Key: 'UserAccountArchiveRequest.user_id' Error:Field validation for 'user_id' failed on the 'uuid' tag"}, //{Field: "user_id", Error: "Key: 'UserAccountArchiveRequest.user_id' Error:Field validation for 'user_id' failed on the 'uuid' tag"},
//{Field: "account_id", Error: "Key: 'UserAccountArchiveRequest.account_id' Error:Field validation for 'account_id' failed on the 'uuid' tag"}, //{Field: "account_id", Error: "Key: 'UserAccountArchiveRequest.account_id' Error:Field validation for 'account_id' failed on the 'uuid' tag"},
@ -802,6 +827,8 @@ func TestUserAccountArchive(t *testing.T) {
Display: "account_id must be a valid UUID", Display: "account_id must be a valid UUID",
}, },
}, },
Details: actual.Details,
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {
@ -843,7 +870,10 @@ func TestUserAccountArchive(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: user_account.ErrForbidden.Error(), StatusCode: expectedStatus,
Error: http.StatusText(expectedStatus),
Details: user_account.ErrForbidden.Error(),
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {
@ -896,7 +926,8 @@ func TestUserAccountDelete(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: "Field validation error", StatusCode: expectedStatus,
Error: "Field validation error",
Fields: []weberror.FieldError{ Fields: []weberror.FieldError{
//{Field: "user_id", Error: "Key: 'UserAccountDeleteRequest.user_id' Error:Field validation for 'user_id' failed on the 'uuid' tag"}, //{Field: "user_id", Error: "Key: 'UserAccountDeleteRequest.user_id' Error:Field validation for 'user_id' failed on the 'uuid' tag"},
//{Field: "account_id", Error: "Key: 'UserAccountDeleteRequest.account_id' Error:Field validation for 'account_id' failed on the 'uuid' tag"}, //{Field: "account_id", Error: "Key: 'UserAccountDeleteRequest.account_id' Error:Field validation for 'account_id' failed on the 'uuid' tag"},
@ -915,6 +946,8 @@ func TestUserAccountDelete(t *testing.T) {
Display: "account_id must be a valid UUID", Display: "account_id must be a valid UUID",
}, },
}, },
Details: actual.Details,
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {
@ -956,7 +989,10 @@ func TestUserAccountDelete(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: user_account.ErrForbidden.Error(), StatusCode: expectedStatus,
Error: http.StatusText(expectedStatus),
Details: user_account.ErrForbidden.Error(),
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"geeks-accelerator/oss/saas-starter-kit/internal/user_auth"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -105,6 +106,8 @@ func TestUserCRUDAdmin(t *testing.T) {
"created_at": web.NewTimeResponse(ctx, actual.CreatedAt.Value), "created_at": web.NewTimeResponse(ctx, actual.CreatedAt.Value),
"first_name": req.FirstName, "first_name": req.FirstName,
"last_name": req.LastName, "last_name": req.LastName,
"name": req.FirstName + " " + req.LastName,
"gravatar": web.NewGravatarResponse(ctx, actual.Email),
} }
var expected user.UserResponse var expected user.UserResponse
@ -197,7 +200,10 @@ func TestUserCRUDAdmin(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: fmt.Sprintf("user %s not found: Entity not found", randID), StatusCode: expectedStatus,
Error: http.StatusText(expectedStatus),
Details: fmt.Sprintf("user %s not found: Entity not found", randID),
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, actual, expected); diff { if diff := cmpDiff(t, actual, expected); diff {
@ -235,7 +241,10 @@ func TestUserCRUDAdmin(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: fmt.Sprintf("user %s not found: Entity not found", tr.ForbiddenUser.ID), StatusCode: expectedStatus,
Error: http.StatusText(expectedStatus),
Details: fmt.Sprintf("user %s not found: Entity not found", tr.ForbiddenUser.ID),
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, actual, expected); diff { if diff := cmpDiff(t, actual, expected); diff {
@ -419,6 +428,8 @@ func TestUserCRUDAdmin(t *testing.T) {
"token_type": actual["token_type"], "token_type": actual["token_type"],
"expiry": actual["expiry"], "expiry": actual["expiry"],
"ttl": actual["ttl"], "ttl": actual["ttl"],
"user_id": tr.User.ID,
"account_id": newAccount.ID,
} }
if diff := cmpDiff(t, actual, expected); diff { if diff := cmpDiff(t, actual, expected); diff {
@ -481,7 +492,8 @@ func TestUserCRUDUser(t *testing.T) {
t.Fatalf("\t%s\tDecode response body failed.", tests.Failed) t.Fatalf("\t%s\tDecode response body failed.", tests.Failed)
} }
expected := mid.ErrorForbidden(ctx).(*weberror.Error).Display(ctx) expected := mid.ErrorForbidden(ctx).(*weberror.Error).Response(ctx, false)
expected.StackTrace = actual.StackTrace
if diff := cmpDiff(t, actual, expected); diff { if diff := cmpDiff(t, actual, expected); diff {
t.Fatalf("\t%s\tReceived expected error.", tests.Failed) t.Fatalf("\t%s\tReceived expected error.", tests.Failed)
@ -556,7 +568,10 @@ func TestUserCRUDUser(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: fmt.Sprintf("user %s not found: Entity not found", randID), StatusCode: expectedStatus,
Error: http.StatusText(expectedStatus),
Details: fmt.Sprintf("user %s not found: Entity not found", randID),
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, actual, expected); diff { if diff := cmpDiff(t, actual, expected); diff {
@ -594,7 +609,10 @@ func TestUserCRUDUser(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: fmt.Sprintf("user %s not found: Entity not found", tr.ForbiddenUser.ID), StatusCode: expectedStatus,
Error: http.StatusText(expectedStatus),
Details: fmt.Sprintf("user %s not found: Entity not found", tr.ForbiddenUser.ID),
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, actual, expected); diff { if diff := cmpDiff(t, actual, expected); diff {
@ -636,7 +654,10 @@ func TestUserCRUDUser(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: user.ErrForbidden.Error(), StatusCode: expectedStatus,
Error: http.StatusText(expectedStatus),
Details: user.ErrForbidden.Error(),
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, actual, expected); diff { if diff := cmpDiff(t, actual, expected); diff {
@ -679,7 +700,10 @@ func TestUserCRUDUser(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: user.ErrForbidden.Error(), StatusCode: expectedStatus,
Error: http.StatusText(expectedStatus),
Details: user.ErrForbidden.Error(),
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, actual, expected); diff { if diff := cmpDiff(t, actual, expected); diff {
@ -718,7 +742,8 @@ func TestUserCRUDUser(t *testing.T) {
t.Fatalf("\t%s\tDecode response body failed.", tests.Failed) t.Fatalf("\t%s\tDecode response body failed.", tests.Failed)
} }
expected := mid.ErrorForbidden(ctx).(*weberror.Error).Display(ctx) expected := mid.ErrorForbidden(ctx).(*weberror.Error).Response(ctx, false)
expected.StackTrace = actual.StackTrace
if diff := cmpDiff(t, actual, expected); diff { if diff := cmpDiff(t, actual, expected); diff {
t.Fatalf("\t%s\tReceived expected error.", tests.Failed) t.Fatalf("\t%s\tReceived expected error.", tests.Failed)
@ -754,7 +779,8 @@ func TestUserCRUDUser(t *testing.T) {
t.Fatalf("\t%s\tDecode response body failed.", tests.Failed) t.Fatalf("\t%s\tDecode response body failed.", tests.Failed)
} }
expected := mid.ErrorForbidden(ctx).(*weberror.Error).Display(ctx) expected := mid.ErrorForbidden(ctx).(*weberror.Error).Response(ctx, false)
expected.StackTrace = actual.StackTrace
if diff := cmpDiff(t, actual, expected); diff { if diff := cmpDiff(t, actual, expected); diff {
t.Fatalf("\t%s\tReceived expected error.", tests.Failed) t.Fatalf("\t%s\tReceived expected error.", tests.Failed)
@ -806,6 +832,8 @@ func TestUserCRUDUser(t *testing.T) {
"token_type": actual["token_type"], "token_type": actual["token_type"],
"expiry": actual["expiry"], "expiry": actual["expiry"],
"ttl": actual["ttl"], "ttl": actual["ttl"],
"user_id": tr.User.ID,
"account_id": newAccount.ID,
} }
if diff := cmpDiff(t, actual, expected); diff { if diff := cmpDiff(t, actual, expected); diff {
@ -870,7 +898,8 @@ func TestUserCreate(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: "Field validation error", StatusCode: expectedStatus,
Error: "Field validation error",
Fields: []weberror.FieldError{ Fields: []weberror.FieldError{
//{Field: "email", Error: "Key: 'UserCreateRequest.email' Error:Field validation for 'email' failed on the 'email' tag"}, //{Field: "email", Error: "Key: 'UserCreateRequest.email' Error:Field validation for 'email' failed on the 'email' tag"},
{ {
@ -881,6 +910,8 @@ func TestUserCreate(t *testing.T) {
Display: "email must be a valid email address", Display: "email must be a valid email address",
}, },
}, },
Details: actual.Details,
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, actual, expected); diff { if diff := cmpDiff(t, actual, expected); diff {
@ -932,7 +963,8 @@ func TestUserUpdate(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: "Field validation error", StatusCode: expectedStatus,
Error: "Field validation error",
Fields: []weberror.FieldError{ Fields: []weberror.FieldError{
//{Field: "email", Error: "Key: 'UserUpdateRequest.email' Error:Field validation for 'email' failed on the 'email' tag"}, //{Field: "email", Error: "Key: 'UserUpdateRequest.email' Error:Field validation for 'email' failed on the 'email' tag"},
{ {
@ -943,6 +975,8 @@ func TestUserUpdate(t *testing.T) {
Display: "email must be a valid email address", Display: "email must be a valid email address",
}, },
}, },
Details: actual.Details,
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, actual, expected); diff { if diff := cmpDiff(t, actual, expected); diff {
@ -1000,7 +1034,8 @@ func TestUserUpdatePassword(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: "Field validation error", StatusCode: expectedStatus,
Error: "Field validation error",
Fields: []weberror.FieldError{ Fields: []weberror.FieldError{
//{Field: "password_confirm", Error: "Key: 'UserUpdatePasswordRequest.password_confirm' Error:Field validation for 'password_confirm' failed on the 'eqfield' tag"}, //{Field: "password_confirm", Error: "Key: 'UserUpdatePasswordRequest.password_confirm' Error:Field validation for 'password_confirm' failed on the 'eqfield' tag"},
{ {
@ -1011,6 +1046,8 @@ func TestUserUpdatePassword(t *testing.T) {
Display: "password_confirm must be equal to Password", Display: "password_confirm must be equal to Password",
}, },
}, },
Details: actual.Details,
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, actual, expected); diff { if diff := cmpDiff(t, actual, expected); diff {
@ -1062,7 +1099,8 @@ func TestUserArchive(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: "Field validation error", StatusCode: expectedStatus,
Error: "Field validation error",
Fields: []weberror.FieldError{ Fields: []weberror.FieldError{
//{Field: "id", Error: "Key: 'UserArchiveRequest.id' Error:Field validation for 'id' failed on the 'uuid' tag"}, //{Field: "id", Error: "Key: 'UserArchiveRequest.id' Error:Field validation for 'id' failed on the 'uuid' tag"},
{ {
@ -1073,6 +1111,8 @@ func TestUserArchive(t *testing.T) {
Display: "id must be a valid UUID", Display: "id must be a valid UUID",
}, },
}, },
Details: actual.Details,
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, actual, expected); diff { if diff := cmpDiff(t, actual, expected); diff {
@ -1112,7 +1152,10 @@ func TestUserArchive(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: user.ErrForbidden.Error(), StatusCode: expectedStatus,
Error: http.StatusText(expectedStatus),
Details: user.ErrForbidden.Error(),
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, actual, expected); diff { if diff := cmpDiff(t, actual, expected); diff {
@ -1162,7 +1205,8 @@ func TestUserDelete(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: "Field validation error", StatusCode: expectedStatus,
Error: "Field validation error",
Fields: []weberror.FieldError{ Fields: []weberror.FieldError{
//{Field: "id", Error: "Key: 'id' Error:Field validation for 'id' failed on the 'uuid' tag"}, //{Field: "id", Error: "Key: 'id' Error:Field validation for 'id' failed on the 'uuid' tag"},
{ {
@ -1173,6 +1217,8 @@ func TestUserDelete(t *testing.T) {
Display: "id must be a valid UUID", Display: "id must be a valid UUID",
}, },
}, },
Details: actual.Details,
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, actual, expected); diff { if diff := cmpDiff(t, actual, expected); diff {
@ -1210,7 +1256,10 @@ func TestUserDelete(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: user.ErrForbidden.Error(), StatusCode: expectedStatus,
Error: http.StatusText(expectedStatus),
Details: user.ErrForbidden.Error(),
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, actual, expected); diff { if diff := cmpDiff(t, actual, expected); diff {
@ -1260,7 +1309,8 @@ func TestUserSwitchAccount(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: "Field validation error", StatusCode: expectedStatus,
Error: "Field validation error",
Fields: []weberror.FieldError{ Fields: []weberror.FieldError{
{ {
Field: "account_id", Field: "account_id",
@ -1270,6 +1320,8 @@ func TestUserSwitchAccount(t *testing.T) {
Display: "account_id must be a valid UUID", Display: "account_id must be a valid UUID",
}, },
}, },
Details: actual.Details,
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, expected, actual); diff { if diff := cmpDiff(t, expected, actual); diff {
@ -1307,7 +1359,10 @@ func TestUserSwitchAccount(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: user.ErrAuthenticationFailure.Error(), StatusCode: expectedStatus,
Error: http.StatusText(expectedStatus),
Details: user_auth.ErrAuthenticationFailure.Error(),
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, actual, expected); diff { if diff := cmpDiff(t, actual, expected); diff {
@ -1330,7 +1385,7 @@ func TestUserToken(t *testing.T) {
http.MethodPost, http.MethodPost,
"/v1/oauth/token", "/v1/oauth/token",
nil, nil,
user.Token{}, user_auth.Token{},
auth.Claims{}, auth.Claims{},
expectedStatus, expectedStatus,
nil, nil,
@ -1350,7 +1405,10 @@ func TestUserToken(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: "must provide email and password in Basic auth", StatusCode: expectedStatus,
Error: http.StatusText(expectedStatus),
Details: "must provide email and password in Basic auth",
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, actual, expected); diff { if diff := cmpDiff(t, actual, expected); diff {
@ -1368,7 +1426,7 @@ func TestUserToken(t *testing.T) {
http.MethodPost, http.MethodPost,
"/v1/oauth/token", "/v1/oauth/token",
nil, nil,
user.Token{}, user_auth.Token{},
auth.Claims{}, auth.Claims{},
expectedStatus, expectedStatus,
nil, nil,
@ -1397,7 +1455,10 @@ func TestUserToken(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: user.ErrAuthenticationFailure.Error(), StatusCode: expectedStatus,
Error: http.StatusText(expectedStatus),
Details: user_auth.ErrAuthenticationFailure.Error(),
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, actual, expected); diff { if diff := cmpDiff(t, actual, expected); diff {
@ -1416,7 +1477,7 @@ func TestUserToken(t *testing.T) {
http.MethodPost, http.MethodPost,
"/v1/oauth/token", "/v1/oauth/token",
nil, nil,
user.Token{}, user_auth.Token{},
auth.Claims{}, auth.Claims{},
expectedStatus, expectedStatus,
nil, nil,
@ -1445,7 +1506,10 @@ func TestUserToken(t *testing.T) {
} }
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
Error: user.ErrAuthenticationFailure.Error(), StatusCode: expectedStatus,
Error: http.StatusText(expectedStatus),
Details: user_auth.ErrAuthenticationFailure.Error(),
StackTrace: actual.StackTrace,
} }
if diff := cmpDiff(t, actual, expected); diff { if diff := cmpDiff(t, actual, expected); diff {
@ -1463,9 +1527,9 @@ func TestUserToken(t *testing.T) {
rt := requestTest{ rt := requestTest{
fmt.Sprintf("Token %d w/role %s using valid credentials", expectedStatus, tr.Role), fmt.Sprintf("Token %d w/role %s using valid credentials", expectedStatus, tr.Role),
http.MethodPost, http.MethodPost,
"/v1/oauth/token", "/v1/oauth/token?account_id=" + tr.Account.ID,
nil, nil,
user.Token{}, user_auth.Token{},
auth.Claims{}, auth.Claims{},
expectedStatus, expectedStatus,
nil, nil,
@ -1499,6 +1563,8 @@ func TestUserToken(t *testing.T) {
"token_type": actual["token_type"], "token_type": actual["token_type"],
"expiry": actual["expiry"], "expiry": actual["expiry"],
"ttl": actual["ttl"], "ttl": actual["ttl"],
"user_id": tr.User.ID,
"account_id": tr.Account.ID,
} }
if diff := cmpDiff(t, actual, expected); diff { if diff := cmpDiff(t, actual, expected); diff {

View File

@ -0,0 +1,261 @@
package handlers
import (
"context"
"net/http"
"time"
"geeks-accelerator/oss/saas-starter-kit/internal/account"
"geeks-accelerator/oss/saas-starter-kit/internal/account/account_preference"
"geeks-accelerator/oss/saas-starter-kit/internal/geonames"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror"
"github.com/gorilla/schema"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
)
// Account represents the Account API method handler set.
type Account struct {
MasterDB *sqlx.DB
Renderer web.Renderer
}
// View handles displaying the current account profile.
func (h *Account) View(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
data := make(map[string]interface{})
f := func() error {
claims, err := auth.ClaimsFromContext(ctx)
if err != nil {
return err
}
acc, err := account.Read(ctx, claims, h.MasterDB, claims.Audience, false)
if err != nil {
return err
}
data["account"] = acc.Response(ctx)
return nil
}
if err := f(); err != nil {
return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
}
return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "account-view.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data)
}
type AccountUpdateRequest struct {
account.AccountUpdateRequest
PreferenceDatetimeFormat string
PreferenceDateFormat string
PreferenceTimeFormat string
}
// Update handles allowing the current user to update their account.
func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
ctxValues, err := webcontext.ContextValues(ctx)
if err != nil {
return err
}
//
req := new(AccountUpdateRequest)
data := make(map[string]interface{})
f := func() (bool, error) {
claims, err := auth.ClaimsFromContext(ctx)
if err != nil {
return false, err
}
prefs, err := account_preference.FindByAccountID(ctx, claims, h.MasterDB, account_preference.AccountPreferenceFindByAccountIDRequest{
AccountID: claims.Audience,
})
if err != nil {
return false, err
}
var (
preferenceDatetimeFormat string
preferenceDateFormat string
preferenceTimeFormat string
)
for _, pref := range prefs {
switch pref.Name {
case account_preference.AccountPreference_Datetime_Format:
preferenceDatetimeFormat = pref.Value
case account_preference.AccountPreference_Date_Format:
preferenceDateFormat = pref.Value
case account_preference.AccountPreference_Time_Format:
preferenceTimeFormat = pref.Value
}
}
if r.Method == http.MethodPost {
err := r.ParseForm()
if err != nil {
return false, err
}
decoder := schema.NewDecoder()
decoder.IgnoreUnknownKeys(true)
if err := decoder.Decode(req, r.PostForm); err != nil {
return false, err
}
req.ID = claims.Audience
err = account.Update(ctx, claims, h.MasterDB, req.AccountUpdateRequest, ctxValues.Now)
if err != nil {
switch errors.Cause(err) {
default:
if verr, ok := weberror.NewValidationError(ctx, err); ok {
data["validationErrors"] = verr.(*weberror.Error)
return false, nil
} else {
return false, err
}
}
}
sess := webcontext.ContextSession(ctx)
if preferenceDatetimeFormat != req.PreferenceDatetimeFormat {
err = account_preference.Set(ctx, claims, h.MasterDB, account_preference.AccountPreferenceSetRequest{
AccountID: claims.Audience,
Name: account_preference.AccountPreference_Datetime_Format,
Value: req.PreferenceDatetimeFormat,
}, ctxValues.Now)
if err != nil {
if verr, ok := weberror.NewValidationError(ctx, err); ok {
data["validationErrors"] = verr.(*weberror.Error)
return false, nil
} else {
return false, err
}
}
sess.Values[webcontext.SessionKeyPreferenceDatetimeFormat] = req.PreferenceDatetimeFormat
}
if preferenceDateFormat != req.PreferenceDateFormat {
err = account_preference.Set(ctx, claims, h.MasterDB, account_preference.AccountPreferenceSetRequest{
AccountID: claims.Audience,
Name: account_preference.AccountPreference_Date_Format,
Value: req.PreferenceDateFormat,
}, ctxValues.Now)
if err != nil {
if verr, ok := weberror.NewValidationError(ctx, err); ok {
data["validationErrors"] = verr.(*weberror.Error)
return false, nil
} else {
return false, err
}
}
sess.Values[webcontext.SessionKeyPreferenceDateFormat] = req.PreferenceDateFormat
}
if preferenceTimeFormat != req.PreferenceTimeFormat {
err = account_preference.Set(ctx, claims, h.MasterDB, account_preference.AccountPreferenceSetRequest{
AccountID: claims.Audience,
Name: account_preference.AccountPreference_Time_Format,
Value: req.PreferenceTimeFormat,
}, ctxValues.Now)
if err != nil {
if verr, ok := weberror.NewValidationError(ctx, err); ok {
data["validationErrors"] = verr.(*weberror.Error)
return false, nil
} else {
return false, err
}
}
sess.Values[webcontext.SessionKeyPreferenceTimeFormat] = req.PreferenceTimeFormat
}
// Display a success message to the user.
webcontext.SessionFlashSuccess(ctx,
"Account Updated",
"Account profile successfully updated.")
err = webcontext.ContextSession(ctx).Save(r, w)
if err != nil {
return false, err
}
http.Redirect(w, r, "/account", http.StatusFound)
return true, nil
}
acc, err := account.Read(ctx, claims, h.MasterDB, claims.Audience, false)
if err != nil {
return false, err
}
if preferenceDatetimeFormat == "" {
preferenceDatetimeFormat = account_preference.AccountPreference_Datetime_Format_Default
}
if preferenceDateFormat == "" {
preferenceDateFormat = account_preference.AccountPreference_Date_Format_Default
}
if preferenceTimeFormat == "" {
preferenceTimeFormat = account_preference.AccountPreference_Time_Format_Default
}
if req.ID == "" {
req.Name = &acc.Name
req.Address1 = &acc.Address1
req.Address2 = &acc.Address2
req.City = &acc.City
req.Region = &acc.Region
req.Country = &acc.Country
req.Zipcode = &acc.Zipcode
req.Timezone = &acc.Timezone
req.PreferenceDatetimeFormat = preferenceDatetimeFormat
req.PreferenceDateFormat = preferenceDateFormat
req.PreferenceTimeFormat = preferenceTimeFormat
}
data["account"] = acc.Response(ctx)
data["timezones"], err = geonames.ListTimezones(ctx, h.MasterDB)
if err != nil {
return false, err
}
data["geonameCountries"] = geonames.ValidGeonameCountries
data["countries"], err = geonames.FindCountries(ctx, h.MasterDB, "name", "")
if err != nil {
return false, err
}
return false, nil
}
end, err := f()
if err != nil {
return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
} else if end {
return nil
}
data["form"] = req
data["exampleDisplayTime"] = web.NewTimeResponse(ctx, time.Now().UTC())
if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(account.AccountUpdateRequest{})); ok {
data["validationDefaults"] = verr.(*weberror.Error)
}
return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "account-update.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data)
}

View File

@ -110,7 +110,7 @@ func (h *Examples) FlashMessages(ctx context.Context, w http.ResponseWriter, r *
} }
if err := f(); err != nil { if err := f(); err != nil {
return web.RenderError(ctx, w, r, err, h.Renderer, tmplLayoutBase, tmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8) return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
} }
data["form"] = req data["form"] = req
@ -120,7 +120,7 @@ func (h *Examples) FlashMessages(ctx context.Context, w http.ResponseWriter, r *
} }
} }
return h.Renderer.Render(ctx, w, r, tmplLayoutBase, "examples-flash-messages.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data) return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "examples-flash-messages.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data)
} }
// Images provides examples for responsive images that are auto re-sized. // Images provides examples for responsive images that are auto re-sized.
@ -132,5 +132,5 @@ func (h *Examples) Images(ctx context.Context, w http.ResponseWriter, r *http.Re
"imgSizes": []int{100, 200, 300, 400, 500}, "imgSizes": []int{100, 200, 300, 400, 500},
} }
return h.Renderer.Render(ctx, w, r, tmplLayoutBase, "examples-images.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data) return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "examples-images.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data)
} }

View File

@ -17,5 +17,5 @@ type Projects struct {
// List returns all the existing users in the system. // List returns all the existing users in the system.
func (p *Projects) Index(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { func (p *Projects) Index(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
return p.Renderer.Render(ctx, w, r, tmplLayoutBase, "projects-index.tmpl", web.MIMETextHTMLCharsetUTF8, http.StatusOK, nil) return p.Renderer.Render(ctx, w, r, TmplLayoutBase, "projects-index.tmpl", web.MIMETextHTMLCharsetUTF8, http.StatusOK, nil)
} }

View File

@ -35,7 +35,7 @@ func (h *Root) indexDashboard(ctx context.Context, w http.ResponseWriter, r *htt
"imgSizes": []int{100, 200, 300, 400, 500}, "imgSizes": []int{100, 200, 300, 400, 500},
} }
return h.Renderer.Render(ctx, w, r, tmplLayoutBase, "root-dashboard.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data) return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "root-dashboard.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data)
} }
// indexDefault loads the root index page when a user has no authentication. // indexDefault loads the root index page when a user has no authentication.
@ -50,21 +50,21 @@ func (u *Root) SitePage(ctx context.Context, w http.ResponseWriter, r *http.Requ
var tmpName string var tmpName string
switch r.RequestURI { switch r.RequestURI {
case "/": case "/":
tmpName = "site-index.gohtml" tmpName = "site-index.gohtml"
case "/api": case "/api":
tmpName = "site-api.gohtml" tmpName = "site-api.gohtml"
case "/features": case "/features":
tmpName = "site-features.gohtml" tmpName = "site-features.gohtml"
case "/support": case "/support":
tmpName = "site-support.gohtml" tmpName = "site-support.gohtml"
case "/legal/privacy": case "/legal/privacy":
tmpName = "legal-privacy.gohtml" tmpName = "legal-privacy.gohtml"
case "/legal/terms": case "/legal/terms":
tmpName = "legal-terms.gohtml" tmpName = "legal-terms.gohtml"
default: default:
http.Redirect(w, r, "/", http.StatusFound) http.Redirect(w, r, "/", http.StatusFound)
return nil return nil
} }
return u.Renderer.Render(ctx, w, r, tmplLayoutSite, tmpName, web.MIMETextHTMLCharsetUTF8, http.StatusOK, nil) return u.Renderer.Render(ctx, w, r, tmplLayoutSite, tmpName, web.MIMETextHTMLCharsetUTF8, http.StatusOK, nil)

View File

@ -19,9 +19,9 @@ import (
) )
const ( const (
tmplLayoutBase = "base.gohtml" TmplLayoutBase = "base.gohtml"
tmplLayoutSite = "site.gohtml" tmplLayoutSite = "site.gohtml"
tmplContentErrorGeneric = "error-generic.gohtml" TmplContentErrorGeneric = "error-generic.gohtml"
) )
// API returns a handler for a set of routes. // API returns a handler for a set of routes.
@ -29,7 +29,7 @@ func APP(shutdown chan os.Signal, log *log.Logger, env webcontext.Env, staticDir
// Define base middlewares applied to all requests. // Define base middlewares applied to all requests.
middlewares := []web.Middleware{ middlewares := []web.Middleware{
mid.Trace(), mid.Logger(log), mid.Errors(log), mid.Metrics(), mid.Panics(), mid.Trace(), mid.Logger(log), mid.Errors(log, renderer), mid.Metrics(), mid.Panics(),
} }
// Append any global middlewares if they were included. // Append any global middlewares if they were included.
@ -56,7 +56,6 @@ func APP(shutdown chan os.Signal, log *log.Logger, env webcontext.Env, staticDir
NotifyEmail: notifyEmail, NotifyEmail: notifyEmail,
SecretKey: secretKey, SecretKey: secretKey,
} }
// This route is not authenticated
app.Handle("POST", "/user/login", u.Login) app.Handle("POST", "/user/login", u.Login)
app.Handle("GET", "/user/login", u.Login) app.Handle("GET", "/user/login", u.Login)
app.Handle("GET", "/user/logout", u.Logout) app.Handle("GET", "/user/logout", u.Logout)
@ -64,6 +63,21 @@ func APP(shutdown chan os.Signal, log *log.Logger, env webcontext.Env, staticDir
app.Handle("GET", "/user/reset-password/:hash", u.ResetConfirm) app.Handle("GET", "/user/reset-password/:hash", u.ResetConfirm)
app.Handle("POST", "/user/reset-password", u.ResetPassword) app.Handle("POST", "/user/reset-password", u.ResetPassword)
app.Handle("GET", "/user/reset-password", u.ResetPassword) app.Handle("GET", "/user/reset-password", u.ResetPassword)
app.Handle("POST", "/user/update", u.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
app.Handle("GET", "/user/update", u.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
app.Handle("GET", "/user/account", u.Account, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
app.Handle("POST", "/user", u.View, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
app.Handle("GET", "/user", u.View, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth())
// Register account management endpoints.
acc := Account{
MasterDB: masterDB,
Renderer: renderer,
}
app.Handle("POST", "/account/update", acc.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin))
app.Handle("GET", "/account/update", acc.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin))
app.Handle("POST", "/account", acc.View, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin))
app.Handle("GET", "/account", acc.View, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin))
// Register user management and authentication endpoints. // Register user management and authentication endpoints.
s := Signup{ s := Signup{
@ -79,7 +93,6 @@ func APP(shutdown chan os.Signal, log *log.Logger, env webcontext.Env, staticDir
ex := Examples{ ex := Examples{
Renderer: renderer, Renderer: renderer,
} }
// This route is not authenticated
app.Handle("POST", "/examples/flash-messages", ex.FlashMessages) app.Handle("POST", "/examples/flash-messages", ex.FlashMessages)
app.Handle("GET", "/examples/flash-messages", ex.FlashMessages) app.Handle("GET", "/examples/flash-messages", ex.FlashMessages)
app.Handle("GET", "/examples/images", ex.Images) app.Handle("GET", "/examples/images", ex.Images)
@ -89,7 +102,6 @@ func APP(shutdown chan os.Signal, log *log.Logger, env webcontext.Env, staticDir
MasterDB: masterDB, MasterDB: masterDB,
Redis: redis, Redis: redis,
} }
// These routes are not authenticated
app.Handle("GET", "/geo/regions/autocomplete", g.RegionsAutocomplete) app.Handle("GET", "/geo/regions/autocomplete", g.RegionsAutocomplete)
app.Handle("GET", "/geo/postal_codes/autocomplete", g.PostalCodesAutocomplete) app.Handle("GET", "/geo/postal_codes/autocomplete", g.PostalCodesAutocomplete)
app.Handle("GET", "/geo/geonames/postal_code/:postalCode", g.GeonameByPostalCode) app.Handle("GET", "/geo/geonames/postal_code/:postalCode", g.GeonameByPostalCode)
@ -101,8 +113,6 @@ func APP(shutdown chan os.Signal, log *log.Logger, env webcontext.Env, staticDir
Renderer: renderer, Renderer: renderer,
ProjectRoutes: projectRoutes, ProjectRoutes: projectRoutes,
} }
// These routes is not authenticated
app.Handle("GET", "/api", r.SitePage) app.Handle("GET", "/api", r.SitePage)
app.Handle("GET", "/features", r.SitePage) app.Handle("GET", "/features", r.SitePage)
app.Handle("GET", "/support", r.SitePage) app.Handle("GET", "/support", r.SitePage)
@ -131,7 +141,7 @@ func APP(shutdown chan os.Signal, log *log.Logger, env webcontext.Env, staticDir
err = weberror.NewError(ctx, err, http.StatusInternalServerError) err = weberror.NewError(ctx, err, http.StatusInternalServerError)
} }
return web.RenderError(ctx, w, r, err, renderer, tmplLayoutBase, tmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8) return web.RenderError(ctx, w, r, err, renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
} }
return nil return nil

View File

@ -83,6 +83,10 @@ func (h *Signup) Step1(ctx context.Context, w http.ResponseWriter, r *http.Reque
webcontext.SessionFlashSuccess(ctx, webcontext.SessionFlashSuccess(ctx,
"Thank you for Joining", "Thank you for Joining",
"You workflow will be a breeze starting today.") "You workflow will be a breeze starting today.")
err = webcontext.ContextSession(ctx).Save(r, w)
if err != nil {
return err
}
// Redirect the user to the dashboard. // Redirect the user to the dashboard.
http.Redirect(w, r, "/", http.StatusFound) http.Redirect(w, r, "/", http.StatusFound)
@ -100,7 +104,7 @@ func (h *Signup) Step1(ctx context.Context, w http.ResponseWriter, r *http.Reque
} }
if err := f(); err != nil { if err := f(); err != nil {
return web.RenderError(ctx, w, r, err, h.Renderer, tmplLayoutBase, tmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8) return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
} }
data["form"] = req data["form"] = req
@ -109,5 +113,5 @@ func (h *Signup) Step1(ctx context.Context, w http.ResponseWriter, r *http.Reque
data["validationDefaults"] = verr.(*weberror.Error) data["validationDefaults"] = verr.(*weberror.Error)
} }
return h.Renderer.Render(ctx, w, r, tmplLayoutBase, "signup-step1.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data) return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "signup-step1.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data)
} }

View File

@ -3,17 +3,19 @@ package handlers
import ( import (
"context" "context"
"fmt" "fmt"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/notify"
project_routes "geeks-accelerator/oss/saas-starter-kit/internal/project-routes"
"net/http" "net/http"
"time" "time"
"geeks-accelerator/oss/saas-starter-kit/internal/account" "geeks-accelerator/oss/saas-starter-kit/internal/account"
"geeks-accelerator/oss/saas-starter-kit/internal/geonames"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/notify"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror"
project_routes "geeks-accelerator/oss/saas-starter-kit/internal/project-routes"
"geeks-accelerator/oss/saas-starter-kit/internal/user" "geeks-accelerator/oss/saas-starter-kit/internal/user"
"geeks-accelerator/oss/saas-starter-kit/internal/user_account"
"github.com/gorilla/schema" "github.com/gorilla/schema"
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
@ -36,7 +38,7 @@ type UserLoginRequest struct {
RememberMe bool RememberMe bool
} }
// List returns all the existing users in the system. // Login handles authenticating a user into the system.
func (h *User) Login(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { func (h *User) Login(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
ctxValues, err := webcontext.ContextValues(ctx) ctxValues, err := webcontext.ContextValues(ctx)
@ -60,15 +62,6 @@ func (h *User) Login(ctx context.Context, w http.ResponseWriter, r *http.Request
return err return err
} }
if err := webcontext.Validator().Struct(req); err != nil {
if ne, ok := weberror.NewValidationError(ctx, err); ok {
data["validationErrors"] = ne.(*weberror.Error)
return nil
} else {
return err
}
}
sessionTTL := time.Hour sessionTTL := time.Hour
if req.RememberMe { if req.RememberMe {
sessionTTL = time.Hour * 36 sessionTTL = time.Hour * 36
@ -104,7 +97,7 @@ func (h *User) Login(ctx context.Context, w http.ResponseWriter, r *http.Request
} }
if err := f(); err != nil { if err := f(); err != nil {
return web.RenderError(ctx, w, r, err, h.Renderer, tmplLayoutBase, tmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8) return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
} }
data["form"] = req data["form"] = req
@ -113,7 +106,7 @@ func (h *User) Login(ctx context.Context, w http.ResponseWriter, r *http.Request
data["validationDefaults"] = verr.(*weberror.Error) data["validationDefaults"] = verr.(*weberror.Error)
} }
return h.Renderer.Render(ctx, w, r, tmplLayoutBase, "user-login.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data) return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "user-login.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data)
} }
// handleSessionToken persists the access token to the session for request authentication. // handleSessionToken persists the access token to the session for request authentication.
@ -122,16 +115,6 @@ func handleSessionToken(ctx context.Context, db *sqlx.DB, w http.ResponseWriter,
return errors.New("accessToken is required.") return errors.New("accessToken is required.")
} }
usr, err := user.Read(ctx, auth.Claims{}, db, token.UserID, false )
if err != nil {
return err
}
acc, err := account.Read(ctx, auth.Claims{},db, token.AccountID, false )
if err != nil {
return err
}
sess := webcontext.ContextSession(ctx) sess := webcontext.ContextSession(ctx)
if sess.IsNew { if sess.IsNew {
@ -144,8 +127,8 @@ func handleSessionToken(ctx context.Context, db *sqlx.DB, w http.ResponseWriter,
HttpOnly: false, HttpOnly: false,
} }
sess = webcontext.SessionInit(sess, token.AccessToken, usr.Response(ctx), acc.Response(ctx)) sess = webcontext.SessionInit(sess,
token.AccessToken)
if err := sess.Save(r, w); err != nil { if err := sess.Save(r, w); err != nil {
return err return err
} }
@ -171,7 +154,7 @@ func (h *User) Logout(ctx context.Context, w http.ResponseWriter, r *http.Reques
return nil return nil
} }
// List returns all the existing users in the system. // ResetPassword allows a user to perform forgot password.
func (h *User) ResetPassword(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { func (h *User) ResetPassword(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
ctxValues, err := webcontext.ContextValues(ctx) ctxValues, err := webcontext.ContextValues(ctx)
@ -195,15 +178,6 @@ func (h *User) ResetPassword(ctx context.Context, w http.ResponseWriter, r *http
return err return err
} }
if err := webcontext.Validator().Struct(req); err != nil {
if ne, ok := weberror.NewValidationError(ctx, err); ok {
data["validationErrors"] = ne.(*weberror.Error)
return nil
} else {
return err
}
}
_, err = user.ResetPassword(ctx, h.MasterDB, h.ProjectRoutes.UserResetPassword, h.NotifyEmail, *req, h.SecretKey, ctxValues.Now) _, err = user.ResetPassword(ctx, h.MasterDB, h.ProjectRoutes.UserResetPassword, h.NotifyEmail, *req, h.SecretKey, ctxValues.Now)
if err != nil { if err != nil {
switch errors.Cause(err) { switch errors.Cause(err) {
@ -228,7 +202,7 @@ func (h *User) ResetPassword(ctx context.Context, w http.ResponseWriter, r *http
} }
if err := f(); err != nil { if err := f(); err != nil {
return web.RenderError(ctx, w, r, err, h.Renderer, tmplLayoutBase, tmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8) return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
} }
data["form"] = req data["form"] = req
@ -237,10 +211,10 @@ func (h *User) ResetPassword(ctx context.Context, w http.ResponseWriter, r *http
data["validationDefaults"] = verr.(*weberror.Error) data["validationDefaults"] = verr.(*weberror.Error)
} }
return h.Renderer.Render(ctx, w, r, tmplLayoutBase, "user-reset-password.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data) return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "user-reset-password.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data)
} }
// List returns all the existing users in the system. // ResetConfirm handles changing a users password after they have clicked on the link emailed.
func (h *User) ResetConfirm(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { func (h *User) ResetConfirm(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
ctxValues, err := webcontext.ContextValues(ctx) ctxValues, err := webcontext.ContextValues(ctx)
@ -264,15 +238,6 @@ func (h *User) ResetConfirm(ctx context.Context, w http.ResponseWriter, r *http.
return err return err
} }
if err := webcontext.Validator().Struct(req); err != nil {
if ne, ok := weberror.NewValidationError(ctx, err); ok {
data["validationErrors"] = ne.(*weberror.Error)
return nil
} else {
return err
}
}
u, err := user.ResetConfirm(ctx, h.MasterDB, *req, h.SecretKey, ctxValues.Now) u, err := user.ResetConfirm(ctx, h.MasterDB, *req, h.SecretKey, ctxValues.Now)
if err != nil { if err != nil {
switch errors.Cause(err) { switch errors.Cause(err) {
@ -318,7 +283,7 @@ func (h *User) ResetConfirm(ctx context.Context, w http.ResponseWriter, r *http.
} }
if err := f(); err != nil { if err := f(); err != nil {
return web.RenderError(ctx, w, r, err, h.Renderer, tmplLayoutBase, tmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8) return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
} }
data["form"] = req data["form"] = req
@ -327,5 +292,194 @@ func (h *User) ResetConfirm(ctx context.Context, w http.ResponseWriter, r *http.
data["validationDefaults"] = verr.(*weberror.Error) data["validationDefaults"] = verr.(*weberror.Error)
} }
return h.Renderer.Render(ctx, w, r, tmplLayoutBase, "user-reset-confirm.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data) return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "user-reset-confirm.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data)
}
// View handles displaying the current user profile.
func (h *User) View(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
data := make(map[string]interface{})
f := func() error {
claims, err := auth.ClaimsFromContext(ctx)
if err != nil {
return err
}
usr, err := user.Read(ctx, claims, h.MasterDB, claims.Subject, false)
if err != nil {
return err
}
data["user"] = usr.Response(ctx)
usrAccs, err := user_account.FindByUserID(ctx, claims, h.MasterDB, claims.Subject, false)
if err != nil {
return err
}
for _, usrAcc := range usrAccs {
if usrAcc.AccountID == claims.Audience {
data["userAccount"] = usrAcc.Response(ctx)
break
}
}
return nil
}
if err := f(); err != nil {
return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
}
return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "user-view.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data)
}
// Update handles allowing the current user to update their profile.
func (h *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
ctxValues, err := webcontext.ContextValues(ctx)
if err != nil {
return err
}
//
req := new(user.UserUpdateRequest)
data := make(map[string]interface{})
f := func() (bool, error) {
claims, err := auth.ClaimsFromContext(ctx)
if err != nil {
return false, err
}
if r.Method == http.MethodPost {
err := r.ParseForm()
if err != nil {
return false, err
}
decoder := schema.NewDecoder()
decoder.IgnoreUnknownKeys(true)
if err := decoder.Decode(req, r.PostForm); err != nil {
return false, err
}
req.ID = claims.Subject
err = user.Update(ctx, claims, h.MasterDB, *req, ctxValues.Now)
if err != nil {
switch errors.Cause(err) {
default:
if verr, ok := weberror.NewValidationError(ctx, err); ok {
data["validationErrors"] = verr.(*weberror.Error)
return false, nil
} else {
return false, err
}
}
}
if r.PostForm.Get("Password") != "" {
pwdReq := new(user.UserUpdatePasswordRequest)
if err := decoder.Decode(pwdReq, r.PostForm); err != nil {
return false, err
}
pwdReq.ID = claims.Subject
err = user.UpdatePassword(ctx, claims, h.MasterDB, *pwdReq, ctxValues.Now)
if err != nil {
switch errors.Cause(err) {
default:
if verr, ok := weberror.NewValidationError(ctx, err); ok {
data["validationErrors"] = verr.(*weberror.Error)
return false, nil
} else {
return false, err
}
}
}
}
// Display a success message to the user.
webcontext.SessionFlashSuccess(ctx,
"Profile Updated",
"User profile successfully updated.")
err = webcontext.ContextSession(ctx).Save(r, w)
if err != nil {
return false, err
}
http.Redirect(w, r, "/user", http.StatusFound)
return true, nil
}
usr, err := user.Read(ctx, claims, h.MasterDB, claims.Subject, false)
if err != nil {
return false, err
}
if req.ID == "" {
req.FirstName = &usr.FirstName
req.LastName = &usr.LastName
req.Email = &usr.Email
req.Timezone = &usr.Timezone
}
data["user"] = usr.Response(ctx)
data["timezones"], err = geonames.ListTimezones(ctx, h.MasterDB)
if err != nil {
return false, err
}
return false, nil
}
end, err := f()
if err != nil {
return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
} else if end {
return nil
}
data["form"] = req
if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(user.UserUpdateRequest{})); ok {
data["userValidationDefaults"] = verr.(*weberror.Error)
}
if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(user.UserUpdatePasswordRequest{})); ok {
data["passwordValidationDefaults"] = verr.(*weberror.Error)
}
return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "user-update.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data)
}
// Account handles displaying the Account for the current user.
func (h *User) Account(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
data := make(map[string]interface{})
f := func() error {
claims, err := auth.ClaimsFromContext(ctx)
if err != nil {
return err
}
acc, err := account.Read(ctx, claims, h.MasterDB, claims.Audience, false)
if err != nil {
return err
}
data["account"] = acc.Response(ctx)
return nil
}
if err := f(); err != nil {
return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
}
return h.Renderer.Render(ctx, w, r, TmplLayoutBase, "user-account.gohtml", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data)
} }

View File

@ -6,10 +6,6 @@ import (
"encoding/json" "encoding/json"
"expvar" "expvar"
"fmt" "fmt"
"geeks-accelerator/oss/saas-starter-kit/internal/account"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/notify"
"geeks-accelerator/oss/saas-starter-kit/internal/user"
"gopkg.in/gomail.v2"
"html/template" "html/template"
"log" "log"
"net" "net"
@ -25,16 +21,19 @@ import (
"time" "time"
"geeks-accelerator/oss/saas-starter-kit/cmd/web-app/handlers" "geeks-accelerator/oss/saas-starter-kit/cmd/web-app/handlers"
"geeks-accelerator/oss/saas-starter-kit/internal/account"
"geeks-accelerator/oss/saas-starter-kit/internal/mid" "geeks-accelerator/oss/saas-starter-kit/internal/mid"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/devops" "geeks-accelerator/oss/saas-starter-kit/internal/platform/devops"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/flag" "geeks-accelerator/oss/saas-starter-kit/internal/platform/flag"
img_resize "geeks-accelerator/oss/saas-starter-kit/internal/platform/img-resize" img_resize "geeks-accelerator/oss/saas-starter-kit/internal/platform/img-resize"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/notify"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web"
template_renderer "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/template-renderer" template_renderer "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/template-renderer"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror"
project_routes "geeks-accelerator/oss/saas-starter-kit/internal/project-routes" project_routes "geeks-accelerator/oss/saas-starter-kit/internal/project-routes"
"geeks-accelerator/oss/saas-starter-kit/internal/user"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/ec2metadata" "github.com/aws/aws-sdk-go/aws/ec2metadata"
@ -52,6 +51,7 @@ import (
redistrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/go-redis/redis" redistrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/go-redis/redis"
sqlxtrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/jmoiron/sqlx" sqlxtrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/jmoiron/sqlx"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
"gopkg.in/gomail.v2"
) )
// build is the git version of this program. It is set using build flags in the makefile. // build is the git version of this program. It is set using build flags in the makefile.
@ -676,22 +676,75 @@ func main() {
return fmt.Sprintf("%+v", err) return fmt.Sprintf("%+v", err)
}, },
// Returns the current user from the session.
// @TODO: Need to add logging for the errors.
"ContextUser": func(ctx context.Context) *user.UserResponse { "ContextUser": func(ctx context.Context) *user.UserResponse {
sess := webcontext.ContextSession(ctx) sess := webcontext.ContextSession(ctx)
v, _ := webcontext.SessionUser(sess)
if u, ok := v.(*user.UserResponse); ok { cacheKey := "ContextUser" + sess.ID
u := &user.UserResponse{}
if err := redisClient.Get(cacheKey).Scan(u); err != nil && err != redis.Nil {
return nil
}
// Return if found in cache.
if u != nil && u.ID != "" {
return u return u
} }
return nil
claims, err := auth.ClaimsFromContext(ctx)
if err != nil {
return nil
}
usr, err := user.Read(ctx, auth.Claims{}, masterDb, claims.Subject, false)
if err != nil {
return nil
}
u = usr.Response(ctx)
err = redisClient.Set(cacheKey, u, time.Hour).Err()
if err != nil {
return nil
}
return u
}, },
// Returns the current account from the session.
// @TODO: Need to add logging for the errors.
"ContextAccount": func(ctx context.Context) *account.AccountResponse { "ContextAccount": func(ctx context.Context) *account.AccountResponse {
sess := webcontext.ContextSession(ctx) sess := webcontext.ContextSession(ctx)
v, _ := webcontext.SessionAccount(sess)
if acc, ok := v.(*account.AccountResponse); ok { cacheKey := "ContextAccount" + sess.ID
return acc
a := &account.AccountResponse{}
if err := redisClient.Get(cacheKey).Scan(a); err != nil && err != redis.Nil {
return nil
} }
return nil
// Return if found in cache.
if a != nil && a.ID != "" {
return a
}
claims, err := auth.ClaimsFromContext(ctx)
if err != nil {
return nil
}
acc, err := account.Read(ctx, auth.Claims{}, masterDb, claims.Audience, false)
if err != nil {
return nil
}
a = acc.Response(ctx)
err = redisClient.Set(cacheKey, a, time.Hour).Err()
if err != nil {
return nil
}
return a
}, },
} }
@ -766,15 +819,22 @@ func main() {
// Custom error handler to support rendering user friendly error page for improved web experience. // Custom error handler to support rendering user friendly error page for improved web experience.
eh := func(ctx context.Context, w http.ResponseWriter, r *http.Request, renderer web.Renderer, statusCode int, er error) error { eh := func(ctx context.Context, w http.ResponseWriter, r *http.Request, renderer web.Renderer, statusCode int, er error) error {
data := map[string]interface{}{} if statusCode == 0 {
if webErr, ok := er.(*weberror.Error); ok {
statusCode = webErr.Status
}
}
return renderer.Render(ctx, w, r, switch statusCode {
"base.tmpl", // base layout file to be used for rendering of errors case http.StatusUnauthorized:
"error.tmpl", // generic format for errors, could select based on status code // Handle expired sessions that are returned from the auth middleware.
web.MIMETextHTMLCharsetUTF8, if strings.Contains(errors.Cause(er).Error(), "token is expired") {
http.StatusOK, http.Redirect(w, r, "/user/login", http.StatusFound)
data, return nil
) }
}
return web.RenderError(ctx, w, r, er, renderer, handlers.TmplLayoutBase, handlers.TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
} }
// Enable template renderer to reload and parse template files when generating a response of dev // Enable template renderer to reload and parse template files when generating a response of dev

View File

@ -1,13 +1,13 @@
$(document).ready(function() { $(document).ready(function() {
hideDuplicateValidationFieldErrors(); hideDuplicateValidationFieldErrors();
}); });
// Prevent duplicate validation messages. When the validation error is displayed inline // Prevent duplicate validation messages. When the validation error is displayed inline
// when the form value, don't display the form error message at the top of the page. // when the form value, don't display the form error message at the top of the page.
function hideDuplicateValidationFieldErrors() { function hideDuplicateValidationFieldErrors() {
var fieldErrors = 0;
$(document).find('#page-content form').find('input, select, textarea').each(function(index){ $(document).find('#page-content form').find('input, select, textarea').each(function(index){
var fname = $(this).attr('name'); var fname = $(this).attr('name');
if (fname === undefined) { if (fname === undefined) {
@ -19,22 +19,29 @@ function hideDuplicateValidationFieldErrors() {
vnode = $(this).parent().parent().find('div.invalid-feedback'); vnode = $(this).parent().parent().find('div.invalid-feedback');
} }
var feedback_count = 0;
var formField = $(vnode).attr('data-field'); var formField = $(vnode).attr('data-field');
var foundMatch = false;
$(document).find('div.validation-error').find('li').each(function(){ $(document).find('div.validation-error').find('li').each(function(){
if ($(this).attr('data-form-field') == formField) { if ($(this).attr('data-form-field') == formField) {
foundMatch = true ;
if ($(vnode).is(":visible") || $(vnode).css('display') === 'none') { if ($(vnode).is(":visible") || $(vnode).css('display') === 'none') {
$(this).hide(); $(this).hide();
feedback_count++; fieldErrors++;
} else { } else {
console.log('form validation feedback for '+fname+' is not visable, display main.'); console.log('form validation feedback for '+fname+' is not visable, display main.');
} }
} }
}); });
if (feedback_count == 0) { // If there was no matching inline validation message, then still need to display the error.
$(document).find('div.validation-error').find('ul').hide(); if (!foundMatch) {
fieldErrors++;
} }
}); });
if (fieldErrors == 0) {
$(document).find('div.validation-error').find('ul').hide();
}
} }

View File

@ -0,0 +1,303 @@
{{define "title"}}Update Account{{end}}
{{define "style"}}
{{end}}
{{define "content"}}
<form class="user" method="post" novalidate>
<div class="row">
<div class="col-md-6">
<h3>Account Details</h3>
<div class="spacer-15"></div>
<div class="form-group row">
<div class="col-sm-6 mb-3 mb-sm-0">
<input type="text" class="form-control form-control-user {{ ValidationFieldClass $.validationErrors "Name" }}" name="Name" value="{{ $.form.Name }}" placeholder="Company Name" required>
{{template "invalid-feedback" dict "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors "fieldName" "Name" }}
</div>
</div>
<div class="form-group row">
<div class="col-sm-6 mb-3 mb-sm-0">
<input type="text" class="form-control form-control-user {{ ValidationFieldClass $.validationErrors "Address1" }}" name="Address1" value="{{ $.form.Address1 }}" placeholder="Address Line 1" required>
{{template "invalid-feedback" dict "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors "fieldName" "Address1" }}
</div>
<div class="col-sm-6">
<input type="text" class="form-control form-control-user {{ ValidationFieldClass $.validationErrors "Address2" }}" name="Address2" value="{{ $.form.Address2 }}" placeholder="Address Line 2">
{{template "invalid-feedback" dict "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors "fieldName" "Address2" }}
</div>
</div>
<div class="form-group row">
<div class="col-sm-6 mb-3 mb-sm-0">
<div class="form-control-select-wrapper">
<select class="form-control form-control-select-box {{ ValidationFieldClass $.validationErrors "Country" }}" id="selectAccountCountry" name="Country" placeholder="Country" required>
{{ range $i := $.countries }}
{{ $hasGeonames := false }}
{{ range $c := $.geonameCountries }}
{{ if eq $c $i.Code }}{{ $hasGeonames = true }}{{ end }}
{{ end }}
<option value="{{ $i.Code }}" data-geonames="{{ if $hasGeonames }}1{{ else }}0{{ end }}" {{ if CmpString $.form.Country $i.Code }}selected="selected"{{ end }}>{{ $i.Name }}</option>
{{ end }}
</select>
{{template "invalid-feedback" dict "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors "fieldName" "Country" }}
</div>
</div>
</div>
<div class="form-group row">
<div class="col-sm-6 mb-3 mb-sm-0">
<div id="divAccountZipcode"></div>
{{template "invalid-feedback" dict "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors "fieldName" "Zipcode" }}
</div>
<div class="col-sm-6 mb-3 mb-sm-0">
<div id="divAccountRegion"></div>
{{template "invalid-feedback" dict "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors "fieldName" "Region" }}
</div>
</div>
<div class="form-group row mb-4">
<div class="col-sm-6 mb-3 mb-sm-0">
<input type="text" class="form-control form-control-user {{ ValidationFieldClass $.validationErrors "Account.City" }}" id="inputAccountCity" name="City" value="{{ $.form.City }}" placeholder="City" required>
{{template "invalid-feedback" dict "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors "fieldName" "City" }}
</div>
</div>
</div>
<div class="col-md-6">
<h3>Account Settings</h3>
<div class="spacer-15"></div>
<div class="form-group">
<label for="inputTimezone">Timezone</label>
<select class="form-control {{ ValidationFieldClass $.validationErrors "Timezone" }}" name="Timezone">
<option value="">Not set</option>
{{ range $idx, $t := .timezones }}
<option value="{{ $t }}" {{ if CmpString $t $.form.Timezone }}selected="selected"{{ end }}>{{ $t }}</option>
{{ end }}
</select>
{{template "invalid-feedback" dict "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors "fieldName" "Timezone" }}
</div>
<div class="form-group">
<label for="inputDatetimeFormat">Datetime Format</label>
<select style="display: none;" id="selectDatetimeFormat">
<option>2006-01-02 at 3:04PM MST</option>
<option>Mon Jan _2 15:04:05 2006</option>
<option>Mon Jan _2 15:04:05 MST 2006</option>
<option>Mon Jan 02 15:04:05 -0700 2006</option>
<option>02 Jan 06 15:04 MST</option>
<option>02 Jan 06 15:04 -0700</option>
<option>Monday, 02-Jan-06 15:04:05 MST</option>
<option>Mon, 02 Jan 2006 15:04:05 MST</option>
<option>Mon, 02 Jan 2006 15:04:05 -0700</option>
<option>Jan _2 15:04:05</option>
<option value="custom">Custom</option>
</select>
<input type="text" class="form-control" id="inputDatetimeFormat" placeholder="enter datetime format" name="PreferenceDatetimeFormat" value="{{ .form.PreferenceDatetimeFormat }}">
<label class="form-check-label" for="inputDatetimeFormat"><small>Current Datetime {{ .exampleDisplayTime.Local }}</small></label>
</div>
<div class="form-group">
<label for="inputDateFormat">Date Format</label>
<select style="display: none;" id="selectDateFormat">
<option>2006-01-02</option>
<option>Mon Jan _2 2006</option>
<option>Mon Jan 02 2006</option>
<option>02 Jan 06</option>
<option>02 Jan 06</option>
<option>Monday, 02-Jan-06</option>
<option>Mon, 02 Jan 2006</option>
<option>Mon, 02 Jan 2006</option>
<option>Jan _2</option>
<option value="custom">Custom</option>
</select>
<input type="text" class="form-control" id="inputDateFormat" placeholder="enter date format" name="PreferenceDateFormat" value="{{ .form.PreferenceDateFormat }}">
<label class="form-check-label" for="inputDateFormat"><small>Current Date {{ .exampleDisplayTime.LocalDate }}</small></label>
</div>
<div class="form-group">
<label for="inputTimeFormat">Time Format</label>
<select style="display: none;" id="selectTimeFormat">
<option>3:04PM</option>
<option>3:04PM MST</option>
<option>3:04PM -0700</option>
<option>15:04:05</option>
<option>15:04:05 MST</option>
<option>15:04:05 -0700</option>
<option value="custom">Custom</option>
</select>
<input type="text" class="form-control" id="inputTimeFormat" placeholder="enter time format" name="PreferenceTimeFormat" value="{{ .form.PreferenceTimeFormat }}">
<label class="form-check-label" for="inputDatetimeFormat"><small>Current Time {{ .exampleDisplayTime.LocalTime }}</small></label>
</div>
</div>
</div>
<div class="spacer-30"></div>
<div class="row">
<div class="col">
<input id="btnSubmit" type="submit" name="action" value="Save" class="btn btn-primary"/>
</div>
</div>
</form>
{{end}}
{{define "js"}}
<script src="https://cdn.jsdelivr.net/gh/xcash/bootstrap-autocomplete@v2.2.2/dist/latest/bootstrap-autocomplete.min.js"></script>
<script>
$(document).ready(function() {
var selectInit = false;
$('#selectAccountCountry').on('change', function () {
// When a country has data-geonames, then we can perform autocomplete on zipcode and
// populate a list of valid regions.
if ($(this).find('option:selected').attr('data-geonames') == 1) {
// Replace the existing region with an empty dropdown.
$('#divAccountRegion').html('<div class="form-control-select-wrapper"><select class="form-control form-control-select-box {{ ValidationFieldClass $.validationErrors "Region" }}" id="inputAccountRegion" name="Region" placeholder="Region" required></select></div>');
// Query the API for a list of regions for the selected
// country and populate the region dropdown.
$.ajax({
type: 'GET',
contentType: 'application/json',
url: '/geo/regions/autocomplete',
data: {country_code: $(this).val(), select: true},
dataType: 'json'
}).done(function (res) {
if (res !== undefined && res !== null) {
for (var c in res) {
var optSelected = '';
if (res[c].value == '{{ $.form.Region }}') {
optSelected = ' selected="selected"';
}
$('#inputAccountRegion').append('<option value="'+res[c].value+'"'+optSelected+'>'+res[c].text+'</option>');
}
}
});
// Replace the existing zipcode text input with a new one that will supports autocomplete.
$('#divAccountZipcode').html('<input class="form-control form-control-user {{ ValidationFieldClass $.validationErrors "Account.Zipcode" }}" id="inputAccountZipcode" name="Zipcode" value="{{ $.form.Zipcode }}" placeholder="Zipcode" required>');
$('#inputAccountZipcode').autoComplete({
minLength: 2,
events: {
search: function (qry, callback) {
$.ajax({
type: 'GET',
contentType: 'application/json',
url: '/geo/postal_codes/autocomplete',
data: {query: qry, country_code: $('#selectAccountCountry').val()},
dataType: 'json'
}).done(function (res) {
callback(res)
});
}
}
});
// When the value of zipcode changes, try to find an exact match for the zipcode and
// can therefore set the correct region and city.
$('#inputAccountZipcode').on('change', function() {
$.ajax({
type: 'GET',
contentType: 'application/json',
url: '/geo/geonames/postal_code/'+$(this).val(),
data: {country_code: $('#selectAccountCountry').val()},
dataType: 'json'
}).done(function (res) {
if (res !== undefined && res !== null && res.PostalCode !== undefined) {
$('#inputAccountCity').val(res.PlaceName);
$('#inputAccountRegion').val(res.StateCode);
}
});
});
} else {
// Replace the existing zipcode input with no autocomplete.
$('#divAccountZipcode').html('<input type="text" class="form-control form-control-user {{ ValidationFieldClass $.validationErrors "Zipcode" }}" id="inputAccountZipcode" name="Zipcode" value="{{ $.form.Zipcode }}" placeholder="Zipcode" required>');
// Replace the existing region select with a text input.
$('#divAccountRegion').html('<input type="text" class="form-control form-control-user {{ ValidationFieldClass $.validationErrors "Region" }}" id="inputAccountRegion" name="Region" value="{{ $.form.Region }}" placeholder="Region" required>');
}
// Init the form defaults based on the current settings.
if (!selectInit) {
hideDuplicateValidationFieldErrors();
selectInit = true
}
}).change();
var selectedDatetimeFormat = false;
$('#selectDatetimeFormat > option').each(function() {
var curValue = $('#inputDatetimeFormat').val();
if (this.text == curValue || this.value == curValue) {
$(this).attr('selected','selected');
selectedDatetimeFormat = true;
$('#selectDatetimeFormat').show();
$('#inputDatetimeFormat').hide();
}
});
if (!selectedDatetimeFormat) {
$('#selectDatetimeFormat').val('custom');
$('#selectDatetimeFormat').show();
$('#inputDatetimeFormat').show();
}
$('#selectDatetimeFormat').on('change', function() {
if ($(this).val() == 'custom') {
$('#inputDatetimeFormat').show();
} else {
$('#inputDatetimeFormat').hide();
$('#inputDatetimeFormat').val($(this).val());
}
})
var selectedDateFormat = false;
$('#selectDateFormat > option').each(function() {
var curValue = $('#inputDateFormat').val();
if (this.text == curValue || this.value == curValue) {
$(this).attr('selected','selected');
selectedDateFormat = true;
$('#selectDateFormat').show();
$('#inputDateFormat').hide();
}
});
if (!selectedDateFormat) {
$('#selectDateFormat').val('custom');
$('#selectDateFormat').show();
$('#inputDateFormat').show();
}
$('#selectDateFormat').on('change', function() {
if ($(this).val() == 'custom') {
$('#inputDateFormat').show();
} else {
$('#inputDateFormat').hide();
$('#inputDateFormat').val($(this).val());
}
})
var selectedTimeFormat = false;
$('#selectTimeFormat > option').each(function() {
var curValue = $('#inputTimeFormat').val();
if (this.text == curValue || this.value == curValue) {
$(this).attr('selected','selected');
selectedTimeFormat = true;
$('#selectTimeFormat').show();
$('#inputTimeFormat').hide();
}
});
if (!selectedTimeFormat) {
$('#selectTimeFormat').val('custom');
$('#selectTimeFormat').show();
$('#inputTimeFormat').show();
}
$('#selectTimeFormat').on('change', function() {
if ($(this).val() == 'custom') {
$('#inputTimeFormat').show();
} else {
$('#inputTimeFormat').hide();
$('#inputTimeFormat').val($(this).val());
}
})
});
</script>
{{end}}

View File

@ -0,0 +1,53 @@
{{define "title"}}Account Settings{{end}}
{{define "style"}}
{{end}}
{{define "content"}}
<div class="row">
<div class="col-auto">
<a href="/account/update" class="btn btn-outline-success"><i class="fal fa-edit"></i>Edit Details</a>
</div>
</div>
<div class="spacer-30"></div>
<div class="row">
<div class="col-md-6">
<p>
<small>Name</small><br/>
<b>{{ .account.Name }}</b>
</p>
{{ if .account.City }}
<p>
<small>Address</small><br/>
{{if .account.Address1 }}
<b>{{ .account.Address1 }}{{ if .account.Address2 }},{{ .account.Address2 }}{{ end }}</b>
<br/>
{{end}}
<b>{{ .account.City }}, {{ .account.Region }}, {{ .account.Zipcode }}</b>
</p>
{{end}}
<p>
<small>Timezone</small><br/>
<b>{{.account.Timezone }}</b>
</p>
</div>
<div class="col-md-6">
<p>
<small>Status</small><br/>
<b>
{{ if eq .account.Status.Value "active" }}
<span class="text-green"><i class="fas fa-circle"></i>{{ .account.Status.Title }}</span>
{{else}}
<span class="text-orange"><i class="far fa-circle"></i>{{.account.Status.Title }}</span>
{{end}}
</b>
</p>
<p>
<small>ID</small><br/>
<b>{{ .account.ID }}</b>
</p>
</div>
</div>
{{end}}
{{define "js"}}
{{end}}

View File

@ -0,0 +1,46 @@
{{define "title"}}Account{{end}}
{{define "style"}}
{{end}}
{{define "content"}}
<div class="row">
<div class="col-md-6">
<div class="card">
<div class="card-header card-header-white">
<div class="row">
<div class="col">
<h4 class="card-title">Account Details</h4>
</div>
</div>
</div>
<div class="card-body">
<div class="row">
<div class="col">
<p>
<small>Name</small><br/>
<b>{{ .account.Name }}</b>
</p>
{{ if .account.Address1 }}
<p>
<small>Address</small><br/>
<b>{{ .account.Address1 }}{{ if .account.Address2 }},{{ .account.Address2 }}{{ end }}</b>
<br/>
<b>{{ .account.City }}, {{ .account.Region }}, {{ .account.Zipcode }}</b>
</p>
{{end}}
<p>
<small>Timezone</small><br/>
<b>{{.account.Timezone }}</b>
</p>
</div>
</div>
</div>
</div>
</div>
</div>
{{end}}
{{define "js"}}
{{end}}

View File

@ -0,0 +1,83 @@
{{define "title"}}Update Profile{{end}}
{{define "style"}}
{{end}}
{{define "content"}}
<form class="user" method="post" novalidate>
<div class="row">
<div class="col-md-6">
<div class="form-group">
<label for="inputFirstName">First Name</label>
<input type="text" class="form-control {{ ValidationFieldClass $.validationErrors "FirstName" }}" placeholder="enter first name" name="FirstName" value="{{ .form.FirstName }}" required>
{{template "invalid-feedback" dict "validationDefaults" $.userValidationDefaults "validationErrors" $.validationErrors "fieldName" "FirstName" }}
</div>
<div class="form-group">
<label for="inputLastName">Last Name</label>
<input type="text" class="form-control {{ ValidationFieldClass $.validationErrors "LastName" }}" placeholder="enter last name" name="LastName" value="{{ .form.LastName }}" required>
{{template "invalid-feedback" dict "validationDefaults" $.userValidationDefaults "validationErrors" $.validationErrors "fieldName" "LastName" }}
</div>
<div class="form-group">
<label for="inputEmail">Email</label>
<input type="text" class="form-control {{ ValidationFieldClass $.validationErrors "Email" }}" placeholder="enter email" name="Email" value="{{ .form.Email }}" required>
{{template "invalid-feedback" dict "validationDefaults" $.userValidationDefaults "validationErrors" $.validationErrors "fieldName" "Email" }}
</div>
<div class="form-group">
<label for="inputTimezone">Timezone</label>
<select class="form-control {{ ValidationFieldClass $.validationErrors "Timezone" }}" name="Timezone">
<option value="">Not set</option>
{{ range $idx, $t := .timezones }}
<option value="{{ $t }}" {{ if CmpString $t $.form.Timezone }}selected="selected"{{ end }}>{{ $t }}</option>
{{ end }}
</select>
{{template "invalid-feedback" dict "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors "fieldName" "Timezone" }}
</div>
</div>
</div>
<div class="row">
<div class="col-md-6">
<h4 class="card-title">Change Password</h4>
<p><small><b>Optional</b>. You can change your password by specifying a new one below. Otherwise leave the fields empty.</small></p>
<div class="form-group">
<label for="inputPassword">Password</label>
<input type="password" class="form-control" id="inputPassword" placeholder="" name="Password" value="">
<span class="help-block "><small><a a href="javascript:void(0)" id="btnGeneratePassword"><i class="fal fa-random"></i>Generate random password </a></small></span>
{{template "invalid-feedback" dict "validationDefaults" $.passwordValidationDefaults "validationErrors" $.validationErrors "fieldName" "Password" }}
</div>
<div class="form-group">
<label for="inputPasswordConfirm">Confirm Password</label>
<input type="password" class="form-control" id="inputPasswordConfirm" placeholder="" name="PasswordConfirm" value="">
{{template "invalid-feedback" dict "validationDefaults" $.passwordValidationDefaults "validationErrors" $.validationErrors "fieldName" "PasswordConfirm" }}
</div>
</div>
</div>
<div class="spacer-30"></div>
<div class="row">
<div class="col">
<input id="btnSubmit" type="submit" name="action" value="Save" class="btn btn-primary"/>
</div>
</div>
</form>
{{end}}
{{define "js"}}
<script>
function randomPassword(length) {
var chars = "abcdefghijklmnopqrstuvwxyz!@#&*()-+<>ABCDEFGHIJKLMNOP1234567890";
var pass = "";
for (var x = 0; x < length; x++) {
var i = Math.floor(Math.random() * chars.length);
pass += chars.charAt(i);
}
return pass;
}
$(document).ready(function(){
$("#btnGeneratePassword").on("click", function() {
pwd = randomPassword(12);
$("#inputPassword").attr('type', 'text').val(pwd)
$("#inputPasswordConfirm").attr('type', 'text').val(pwd)
return false;
});
});
</script>
{{end}}

View File

@ -0,0 +1,85 @@
{{define "title"}}Profile{{end}}
{{define "style"}}
{{end}}
{{define "content"}}
<div class="row">
<div class="col">
<div class="row">
<div class="col-auto">
<img src="{{ .user.Gravatar.Medium }}" alt="gravatar image" class="rounded">
</div>
<div class="col">
<h4>Name</h4>
<p class="font-14">
{{ .user.Name }}
</p>
</div>
</div>
<div class="spacer-10"></div>
<p class="font-10"><a href="https://gravatar.com" target="_blank">Update Avatar</a></p>
</div>
<div class="col-auto">
<a href="/user/update" class="btn btn-outline-success"><i class="fal fa-edit"></i>Edit Details</a>
</div>
</div>
<div class="spacer-30"></div>
<div class="row">
<div class="col-md-6">
<p>
<small>Name</small><br/>
<b>{{ .user.Name }}</b>
</p>
<p>
<small>Email</small><br/>
<b>{{ .user.Email }}</b>
</p>
{{if .user.Timezone }}
<p>
<small>Timezone</small><br/>
<b>{{.user.Timezone }}</b>
</p>
{{end}}
<div class="spacer-15"></div>
</div>
<div class="col-md-6">
<p>
<small>Role</small><br/>
{{ if .userAccount }}
<b>
{{ range $r := .userAccount.Roles }}
{{ if eq $r "admin" }}
<span class="text-pink-dark"><i class="far fa-user-astronaut"></i>{{ $r }}</span>
{{else}}
<span class="text-purple-dark"><i class="fal fa-user"></i>{{ $r }}</span>
{{end}}
{{ end }}
</b>
{{ end }}
</p>
<p>
<small>Status</small><br/>
{{ if .userAccount }}
<b>
{{ if eq .userAccount.Status.Value "active" }}
<span class="text-green"><i class="fas fa-circle"></i>{{ .userAccount.Status.Title }}</span>
{{ else if eq .userAccount.Status.Value "invited" }}
<span class="text-blue"><i class="fas fa-unicorn"></i>{{ .userAccount.Status.Title }}</span>
{{else}}
<span class="text-orange"><i class="far fa-circle"></i>{{.userAccount.Status.Title }}</span>
{{end}}
</b>
{{ end }}
</p>
<p>
<small>ID</small><br/>
<b>{{ .user.ID }}</b>
</p>
</div>
</div>
{{end}}
{{define "js"}}
{{end}}

View File

@ -10,7 +10,7 @@
{{ if HasAuth $._Ctx }} {{ if HasAuth $._Ctx }}
<!-- Topbar Search --> <!-- Topbar Search -->
<form class="d-none d-sm-inline-block form-inline mr-auto ml-md-3 my-2 my-md-0 mw-100 navbar-search"> <!--- form class="d-none d-sm-inline-block form-inline mr-auto ml-md-3 my-2 my-md-0 mw-100 navbar-search">
<div class="input-group"> <div class="input-group">
<input type="text" class="form-control bg-light border-0 small" placeholder="Search for..." aria-label="Search" aria-describedby="basic-addon2"> <input type="text" class="form-control bg-light border-0 small" placeholder="Search for..." aria-label="Search" aria-describedby="basic-addon2">
<div class="input-group-append"> <div class="input-group-append">
@ -19,7 +19,7 @@
</button> </button>
</div> </div>
</div> </div>
</form> </form -->
<!-- Topbar Navbar --> <!-- Topbar Navbar -->
<ul class="navbar-nav ml-auto"> <ul class="navbar-nav ml-auto">
@ -161,7 +161,7 @@
<img class="img-profile rounded-circle" src="{{ $user.Gravatar.Medium }}"> <img class="img-profile rounded-circle" src="{{ $user.Gravatar.Medium }}">
{{ else }} {{ else }}
<span class="mr-2 d-none d-lg-inline text-gray-600 small">Space Cadet</span> <span class="mr-2 d-none d-lg-inline text-gray-600 small">Space Cadet</span>
<img class="img-profile rounded-circle" src="src="{{ SiteAssetUrl "/assets/images/user-default.jpg"}}"> <img class="img-profile rounded-circle" src="{{ SiteAssetUrl "/assets/images/user-default.jpg" }}">
{{ end }} {{ end }}
</a> </a>
@ -173,7 +173,7 @@
</a> </a>
{{ if HasRole $._Ctx "admin" }} {{ if HasRole $._Ctx "admin" }}
<a class="dropdown-item" href="/admin/account"> <a class="dropdown-item" href="/account">
<i class="fas fa-cogs fa-sm fa-fw mr-2 text-gray-400"></i> <i class="fas fa-cogs fa-sm fa-fw mr-2 text-gray-400"></i>
Account Settings Account Settings
</a> </a>
@ -187,7 +187,7 @@
Invite User Invite User
</a> </a>
{{ else }} {{ else }}
<a class="dropdown-item" href="/account"> <a class="dropdown-item" href="/user/account">
<i class="fas fa-cogs fa-sm fa-fw mr-2 text-gray-400"></i> <i class="fas fa-cogs fa-sm fa-fw mr-2 text-gray-400"></i>
Account Account
</a> </a>

View File

@ -3,10 +3,10 @@ package account
import ( import (
"context" "context"
"database/sql" "database/sql"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"time" "time"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"github.com/huandu/go-sqlbuilder" "github.com/huandu/go-sqlbuilder"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/pborman/uuid" "github.com/pborman/uuid"
@ -19,6 +19,8 @@ const (
accountTableName = "accounts" accountTableName = "accounts"
// The database table for User Account // The database table for User Account
userAccountTableName = "users_accounts" userAccountTableName = "users_accounts"
// The database table for AccountPreference
accountPreferenceTableName = "account_preferences"
) )
var ( var (
@ -29,24 +31,6 @@ var (
ErrForbidden = errors.New("Attempted action is not allowed") ErrForbidden = errors.New("Attempted action is not allowed")
) )
// accountMapColumns is the list of columns needed for mapRowsToAccount
var accountMapColumns = "id,name,address1,address2,city,region,country,zipcode,status,timezone,signup_user_id,billing_user_id,created_at,updated_at,archived_at"
// mapRowsToAccount takes the SQL rows and maps it to the Account struct
// with the columns defined by accountMapColumns
func mapRowsToAccount(rows *sql.Rows) (*Account, error) {
var (
a Account
err error
)
err = rows.Scan(&a.ID, &a.Name, &a.Address1, &a.Address2, &a.City, &a.Region, &a.Country, &a.Zipcode, &a.Status, &a.Timezone, &a.SignupUserID, &a.BillingUserID, &a.CreatedAt, &a.UpdatedAt, &a.ArchivedAt)
if err != nil {
return nil, errors.WithStack(err)
}
return &a, nil
}
// CanReadAccount determines if claims has the authority to access the specified account ID. // CanReadAccount determines if claims has the authority to access the specified account ID.
func CanReadAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, accountID string) error { func CanReadAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, accountID string) error {
// If the request has claims from a specific account, ensure that the claims // If the request has claims from a specific account, ensure that the claims
@ -152,7 +136,10 @@ func applyClaimsSelect(ctx context.Context, claims auth.Claims, query *sqlbuilde
return nil return nil
} }
// selectQuery constructs a base select query for Account // accountMapColumns is the list of columns needed for find.
var accountMapColumns = "id,name,address1,address2,city,region,country,zipcode,status,timezone,signup_user_id,billing_user_id,created_at,updated_at,archived_at"
// selectQuery constructs a base select query for Account.
func selectQuery() *sqlbuilder.SelectBuilder { func selectQuery() *sqlbuilder.SelectBuilder {
query := sqlbuilder.NewSelectBuilder() query := sqlbuilder.NewSelectBuilder()
query.Select(accountMapColumns) query.Select(accountMapColumns)
@ -160,11 +147,12 @@ func selectQuery() *sqlbuilder.SelectBuilder {
return query return query
} }
// findRequestQuery generates the select query for the given find request. // Find gets all the accounts from the database based on the request params.
// TODO: Need to figure out why can't parse the args when appending the where // TODO: Need to figure out why can't parse the args when appending the where
// to the query. // to the query.
func findRequestQuery(req AccountFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) { func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountFindRequest) ([]*Account, error) {
query := selectQuery() query := selectQuery()
if req.Where != nil { if req.Where != nil {
query.Where(query.And(*req.Where)) query.Where(query.And(*req.Where))
} }
@ -178,13 +166,7 @@ func findRequestQuery(req AccountFindRequest) (*sqlbuilder.SelectBuilder, []inte
query.Offset(int(*req.Offset)) query.Offset(int(*req.Offset))
} }
return query, req.Args return find(ctx, claims, dbConn, query, req.Args, req.IncludeArchived)
}
// Find gets all the accounts from the database based on the request params.
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountFindRequest) ([]*Account, error) {
query, args := findRequestQuery(req)
return find(ctx, claims, dbConn, query, args, req.IncludedArchived)
} }
// find internal method for getting all the accounts from the database using a select query. // find internal method for getting all the accounts from the database using a select query.
@ -219,12 +201,15 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu
// iterate over each row // iterate over each row
resp := []*Account{} resp := []*Account{}
for rows.Next() { for rows.Next() {
u, err := mapRowsToAccount(rows) var (
a Account
err error
)
err = rows.Scan(&a.ID, &a.Name, &a.Address1, &a.Address2, &a.City, &a.Region, &a.Country, &a.Zipcode, &a.Status, &a.Timezone, &a.SignupUserID, &a.BillingUserID, &a.CreatedAt, &a.UpdatedAt, &a.ArchivedAt)
if err != nil { if err != nil {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
return nil, err
} }
resp = append(resp, u) resp = append(resp, &a)
} }
return resp, nil return resp, nil
@ -336,20 +321,35 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun
return &a, nil return &a, nil
} }
// ReadByID gets the specified user by ID from the database.
func ReadByID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string) (*Account, error) {
return Read(ctx, claims, dbConn, AccountReadRequest{
ID: id,
IncludeArchived: false,
})
}
// Read gets the specified account from the database. // Read gets the specified account from the database.
func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string, includedArchived bool) (*Account, error) { func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountReadRequest) (*Account, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Read") span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Read")
defer span.Finish() defer span.Finish()
// Filter base select query by ID // Validate the request.
query := selectQuery() v := webcontext.Validator()
query.Where(query.Equal("id", id)) err := v.Struct(req)
if err != nil {
res, err := find(ctx, claims, dbConn, query, []interface{}{}, includedArchived)
if res == nil || len(res) == 0 {
err = errors.WithMessagef(ErrNotFound, "account %s not found", id)
return nil, err return nil, err
} else if err != nil { }
// Filter base select query by ID
query := sqlbuilder.NewSelectBuilder()
query.Where(query.Equal("id", req.ID))
res, err := find(ctx, claims, dbConn, query, []interface{}{}, req.IncludeArchived)
if err != nil {
return nil, err
} else if res == nil || len(res) == 0 {
err = errors.WithMessagef(ErrNotFound, "account %s not found", req.ID)
return nil, err return nil, err
} }
u := res[0] u := res[0]
@ -471,14 +471,6 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun
return nil return nil
} }
// Archive soft deleted the account by ID from the database.
func ArchiveById(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, accountID string, now time.Time) error {
req := AccountArchiveRequest{
ID: accountID,
}
return Archive(ctx, claims, dbConn, req, now)
}
// Archive soft deleted the account from the database. // Archive soft deleted the account from the database.
func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountArchiveRequest, now time.Time) error { func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountArchiveRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Archive") span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Archive")
@ -552,17 +544,10 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accou
} }
// Delete removes an account from the database. // Delete removes an account from the database.
func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, accountID string) error { func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountDeleteRequest) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Delete") span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Delete")
defer span.Finish() defer span.Finish()
// Defines the struct to apply validation
req := struct {
ID string `json:"id" validate:"required,uuid"`
}{
ID: accountID,
}
// Validate the request. // Validate the request.
v := webcontext.Validator() v := webcontext.Validator()
err := v.Struct(req) err := v.Struct(req)
@ -605,6 +590,29 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, accountID
} }
} }
// Delete all the associated account preferences.
// Required to execute first to avoid foreign key constraints.
{
// Build the delete SQL statement.
query := sqlbuilder.NewDeleteBuilder()
query.DeleteFrom(accountPreferenceTableName)
query.Where(query.And(
query.Equal("account_id", req.ID),
))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = tx.ExecContext(ctx, sql, args...)
if err != nil {
tx.Rollback()
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "delete preferences for account %s failed", req.ID)
return err
}
}
// Build the delete SQL statement. // Build the delete SQL statement.
query := sqlbuilder.NewDeleteBuilder() query := sqlbuilder.NewDeleteBuilder()
query.DeleteFrom(accountTableName) query.DeleteFrom(accountTableName)

View File

@ -0,0 +1,426 @@
package account_preference
import (
"context"
"time"
"geeks-accelerator/oss/saas-starter-kit/internal/account"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"github.com/huandu/go-sqlbuilder"
"github.com/jmoiron/sqlx"
"github.com/pborman/uuid"
"github.com/pkg/errors"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
"gopkg.in/go-playground/validator.v9"
)
const (
// The database table for AccountPreference
accountPreferenceTableName = "account_preferences"
// The database table for User Account
userAccountTableName = "users_accounts"
)
var (
// ErrNotFound abstracts the mgo not found error.
ErrNotFound = errors.New("Entity not found")
)
// The list of columns needed for find
var accountPreferenceMapColumns = "account_id,name,value,created_at,updated_at,archived_at"
// applyClaimsSelect applies a sub-query to the provided query to enforce ACL based on
// the claims provided.
// 1. All role types can access their user ID
// 2. Any user with the same account ID
// 3. No claims, request is internal, no ACL applied
func applyClaimsSelect(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder) error {
// Claims are empty, don't apply any ACL
if claims.Audience == "" && claims.Subject == "" {
return nil
}
// Build select statement for users_accounts table
subQuery := sqlbuilder.NewSelectBuilder().Select("account_id").From(userAccountTableName)
var or []string
if claims.Audience != "" {
or = append(or, subQuery.Equal("account_id", claims.Audience))
}
if claims.Subject != "" {
or = append(or, subQuery.Equal("user_id", claims.Subject))
}
// Append sub query
if len(or) > 0 {
subQuery.Where(subQuery.Or(or...))
query.Where(query.In("account_id", subQuery))
}
return nil
}
// Find gets all the account preferences from the database based on the request params.
// TODO: Need to figure out why can't parse the args when appending the where to the query.
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPreferenceFindRequest) ([]*AccountPreference, error) {
query := sqlbuilder.NewSelectBuilder()
if req.Where != nil {
query.Where(query.And(*req.Where))
}
if len(req.Order) > 0 {
query.OrderBy(req.Order...)
}
if req.Limit != nil {
query.Limit(int(*req.Limit))
}
if req.Offset != nil {
query.Offset(int(*req.Offset))
}
return find(ctx, claims, dbConn, query, req.Args, req.IncludeArchived)
}
// FindByAccountID gets the specified account preferences for an account from the database.
func FindByAccountID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPreferenceFindByAccountIDRequest) ([]*AccountPreference, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.account_preference.FindByAccountID")
defer span.Finish()
// Validate the request.
err := Validator().StructCtx(ctx, req)
if err != nil {
return nil, err
}
// Filter base select query by ID
query := sqlbuilder.NewSelectBuilder()
query.Where(query.Equal("account_id", req.AccountID))
if len(req.Order) > 0 {
query.OrderBy(req.Order...)
}
if req.Limit != nil {
query.Limit(int(*req.Limit))
}
if req.Offset != nil {
query.Offset(int(*req.Offset))
}
return find(ctx, claims, dbConn, query, []interface{}{}, req.IncludeArchived)
}
// find internal method for getting all the account preferences from the database using a select query.
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*AccountPreference, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.account_preference.Find")
defer span.Finish()
query.Select(accountPreferenceMapColumns)
query.From(accountPreferenceTableName)
if !includedArchived {
query.Where(query.IsNull("archived_at"))
}
// Check to see if a sub query needs to be applied for the claims
err := applyClaimsSelect(ctx, claims, query)
if err != nil {
return nil, err
}
queryStr, queryArgs := query.Build()
queryStr = dbConn.Rebind(queryStr)
args = append(args, queryArgs...)
// fetch all places from the db
rows, err := dbConn.QueryContext(ctx, queryStr, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "find account preferences failed")
return nil, err
}
// iterate over each row
resp := []*AccountPreference{}
for rows.Next() {
var (
a AccountPreference
err error
)
err = rows.Scan(&a.AccountID, &a.Name, &a.Value, &a.CreatedAt, &a.UpdatedAt, &a.ArchivedAt)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
return nil, err
}
resp = append(resp, &a)
}
return resp, nil
}
// Read gets the specified account preference from the database.
func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPreferenceReadRequest) (*AccountPreference, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.account_preference.Read")
defer span.Finish()
// Validate the request.
err := Validator().StructCtx(ctx, req)
if err != nil {
return nil, err
}
// Filter base select query by ID
query := sqlbuilder.NewSelectBuilder()
query.Where(query.And(
query.Equal("account_id", req.AccountID)),
query.Equal("name", req.Name))
res, err := find(ctx, claims, dbConn, query, []interface{}{}, req.IncludeArchived)
if err != nil {
return nil, err
} else if res == nil || len(res) == 0 {
err = errors.WithMessagef(ErrNotFound, "account preference %s for account %s not found", req.Name, req.AccountID)
return nil, err
}
u := res[0]
return u, nil
}
type ctxKeyPreferenceName int
const KeyPreferenceName ctxKeyPreferenceName = 1
// Validator registers a custom validation function for tag preference_value.
func Validator() *validator.Validate {
v := webcontext.Validator()
fctx := func(ctx context.Context, fl validator.FieldLevel) bool {
if fl.Field().String() == "invalid" {
return false
}
name, ok := ctx.Value(KeyPreferenceName).(AccountPreferenceName)
if !ok {
return false
}
val := fl.Field().String()
switch name {
case AccountPreference_Datetime_Format:
loc, _ := time.LoadLocation("MST")
tv, _ := time.Parse(time.RFC3339, "2006-01-02T15:04:05Z")
tv = tv.In(loc)
pv, err := time.Parse(val, tv.Format(val))
if err != nil {
return false
}
if pv.Format(val) != tv.Format(val) || pv.Format("2006-01-02") != tv.Format("2006-01-02") || pv.IsZero() {
return false
}
return true
case AccountPreference_Date_Format:
loc, _ := time.LoadLocation("MST")
tv, _ := time.Parse(time.RFC3339, "2006-01-02T15:04:05Z")
tv = tv.In(loc)
pv, err := time.Parse(val, tv.Format(val))
if err != nil {
return false
}
if pv.Format(val) != tv.Format(val) || pv.UTC().Format("2006-01-02") != tv.UTC().Format("2006-01-02") || pv.IsZero() {
return false
}
return true
case AccountPreference_Time_Format:
//loc, _ := time.LoadLocation("MST")
tv, _ := time.Parse(time.RFC3339, "2006-01-02T15:04:05Z")
//tv = tv.In(loc)
pv, err := time.Parse(val, tv.Format(val))
if err != nil {
return false
}
if pv.Format(val) != tv.Format(val) || pv.UTC().Format("15:04") != tv.UTC().Format("15:04") || pv.IsZero() {
return false
}
return true
}
return false
}
v.RegisterValidationCtx("preference_value", fctx)
return v
}
// Set inserts a new account preference or updates an existing on.
func Set(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPreferenceSetRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.account_preference.Set")
defer span.Finish()
ctx = context.WithValue(ctx, KeyPreferenceName, req.Name)
// Validate the request.
err := Validator().StructCtx(ctx, req)
if err != nil {
return err
}
// Ensure the claims can modify the account specified in the request.
err = account.CanModifyAccount(ctx, claims, dbConn, req.AccountID)
if err != nil {
return err
}
// If now empty set it to the current time.
if now.IsZero() {
now = time.Now()
}
// Always store the time as UTC.
now = now.UTC()
// Postgres truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
// Build the insert SQL statement.
query := sqlbuilder.NewInsertBuilder()
query.InsertInto(accountPreferenceTableName)
query.Cols("account_id", "name", "value", "created_at", "updated_at")
query.Values(req.AccountID, req.Name, req.Value, now, now)
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
sql = sql + " ON CONFLICT ON CONSTRAINT account_preferences_pkey DO UPDATE set value = EXCLUDED.value "
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "set account preference failed")
return err
}
return nil
}
// Archive soft deleted the account preference from the database.
func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPreferenceArchiveRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.account_preference.Archive")
defer span.Finish()
// Validate the request.
v := webcontext.Validator()
err := v.Struct(req)
if err != nil {
return err
}
// Ensure the claims can modify the account specified in the request.
err = account.CanModifyAccount(ctx, claims, dbConn, req.AccountID)
if err != nil {
return err
}
// If now empty set it to the current time.
if now.IsZero() {
now = time.Now()
}
// Always store the time as UTC.
now = now.UTC()
// Postgres truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
// Build the update SQL statement.
query := sqlbuilder.NewUpdateBuilder()
query.Update(accountPreferenceTableName)
query.Set(
query.Assign("archived_at", now),
)
query.Where(query.Equal("account_id", req.AccountID))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "archive account preference %s for account %s failed", req.Name, req.AccountID)
return err
}
return nil
}
// Delete removes an account preference from the database.
func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPreferenceDeleteRequest) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.account_preference.Delete")
defer span.Finish()
// Validate the request.
v := webcontext.Validator()
err := v.Struct(req)
if err != nil {
return err
}
// Ensure the claims can modify the account specified in the request.
err = account.CanModifyAccount(ctx, claims, dbConn, req.AccountID)
if err != nil {
return err
}
// Start a new transaction to handle rollbacks on error.
tx, err := dbConn.Begin()
if err != nil {
return errors.WithStack(err)
}
// Build the delete SQL statement.
query := sqlbuilder.NewDeleteBuilder()
query.DeleteFrom(accountPreferenceTableName)
query.Where(query.Equal("account_id", req.AccountID))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = tx.ExecContext(ctx, sql, args...)
if err != nil {
tx.Rollback()
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "delete account preference %s for account %s failed", req.Name, req.AccountID)
return err
}
err = tx.Commit()
if err != nil {
return errors.WithStack(err)
}
return nil
}
// MockAccountPreference returns a fake AccountPreference for testing.
func MockAccountPreference(ctx context.Context, dbConn *sqlx.DB, now time.Time) error {
req := AccountPreferenceSetRequest{
AccountID: uuid.NewRandom().String(),
Name: AccountPreference_Datetime_Format,
Value: AccountPreference_Datetime_Format_Default,
}
return Set(ctx, auth.Claims{}, dbConn, req, now)
}

View File

@ -0,0 +1,505 @@
package account_preference
import (
"geeks-accelerator/oss/saas-starter-kit/internal/account"
"math/rand"
"os"
"strings"
"testing"
"time"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/tests"
"geeks-accelerator/oss/saas-starter-kit/internal/user_account"
"github.com/dgrijalva/jwt-go"
"github.com/google/go-cmp/cmp"
"github.com/pborman/uuid"
"github.com/pkg/errors"
)
var test *tests.Test
// TestMain is the entry point for testing.
func TestMain(m *testing.M) {
os.Exit(testMain(m))
}
func testMain(m *testing.M) int {
test = tests.New()
defer test.TearDown()
return m.Run()
}
// TestSetValidation ensures all the validation tags work on Set.
func TestSetValidation(t *testing.T) {
invalidName := AccountPreferenceName("xxxxxx")
var prefTests = []struct {
name string
req AccountPreferenceSetRequest
error error
}{
{"Required Fields",
AccountPreferenceSetRequest{},
errors.New("Key: 'AccountPreferenceSetRequest.{{account_id}}' Error:Field validation for '{{account_id}}' failed on the 'required' tag\n" +
"Key: 'AccountPreferenceSetRequest.{{name}}' Error:Field validation for '{{name}}' failed on the 'required' tag\n" +
"Key: 'AccountPreferenceSetRequest.{{value}}' Error:Field validation for '{{value}}' failed on the 'required' tag"),
},
{"Valid Name",
AccountPreferenceSetRequest{
AccountID: uuid.NewRandom().String(),
Name: invalidName,
Value: uuid.NewRandom().String(),
},
errors.New("Key: 'AccountPreferenceSetRequest.{{name}}' Error:Field validation for '{{name}}' failed on the 'oneof' tag\n" +
"Key: 'AccountPreferenceSetRequest.{{value}}' Error:Field validation for '{{value}}' failed on the 'preference_value' tag"),
},
}
now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC)
t.Log("Given the need ensure all validation tags are working for account preference set.")
{
for i, tt := range prefTests {
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
{
ctx := tests.Context()
err := Set(ctx, auth.Claims{}, test.MasterDB, tt.req, now)
if err != tt.error {
// TODO: need a better way to handle validation errors as they are
// of type interface validator.ValidationErrorsTranslations
var errStr string
if err != nil {
errStr = strings.Replace(err.Error(), "{{", "", -1)
errStr = strings.Replace(errStr, "}}", "", -1)
}
var expectStr string
if tt.error != nil {
expectStr = strings.Replace(tt.error.Error(), "{{", "", -1)
expectStr = strings.Replace(expectStr, "}}", "", -1)
}
if errStr != expectStr {
t.Logf("\t\tGot : %+v", errStr)
t.Logf("\t\tWant: %+v", expectStr)
t.Fatalf("\t%s\tSet failed.", tests.Failed)
}
}
// If there was an error that was expected, then don't go any further
if tt.error != nil {
t.Logf("\t%s\tSet ok.", tests.Success)
continue
}
t.Logf("\t%s\tSet ok.", tests.Success)
}
}
}
}
// TestCrud validates the full set of CRUD operations for account preferences and ensures ACLs are correctly applied
// by claims.
func TestCrud(t *testing.T) {
defer tests.Recover(t)
now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC)
// Create a test user and account.
usrAcc, err := user_account.MockUserAccount(tests.Context(), test.MasterDB, now, user_account.UserAccountRole_Admin)
if err != nil {
t.Log("Got :", err)
t.Fatalf("%s\tCreate account failed.", tests.Failed)
}
type prefTest struct {
name string
claims func(string, string) auth.Claims
set AccountPreferenceSetRequest
writeErr error
findErr error
}
var prefTests []prefTest
// Internal request, should bypass ACL.
prefTests = append(prefTests, prefTest{"EmptyClaims",
func(accountID, userId string) auth.Claims {
return auth.Claims{}
},
AccountPreferenceSetRequest{
AccountID: usrAcc.AccountID,
Name: AccountPreference_Datetime_Format,
Value: AccountPreference_Datetime_Format_Default,
},
nil,
nil,
})
// Role of account but claim account does not match update account so forbidden.
prefTests = append(prefTests, prefTest{"RoleAccountPreferenceDiffAccountPreference",
func(accountID, userId string) auth.Claims {
return auth.Claims{
Roles: []string{auth.RoleAdmin},
StandardClaims: jwt.StandardClaims{
Audience: uuid.NewRandom().String(),
Subject: userId,
},
}
},
AccountPreferenceSetRequest{
AccountID: usrAcc.AccountID,
Name: AccountPreference_Datetime_Format,
Value: AccountPreference_Datetime_Format_Default,
},
account.ErrForbidden,
ErrNotFound,
})
// Role of account AND claim account matches update account so OK.
prefTests = append(prefTests, prefTest{"RoleAccountPreferenceSameAccountPreference",
func(accountID, userId string) auth.Claims {
return auth.Claims{
Roles: []string{auth.RoleAdmin},
StandardClaims: jwt.StandardClaims{
Audience: accountID,
Subject: userId,
},
}
},
AccountPreferenceSetRequest{
AccountID: usrAcc.AccountID,
Name: AccountPreference_Date_Format,
Value: AccountPreference_Date_Format_Default,
},
nil,
nil,
})
// Role of admin but claim account does not match update account so forbidden.
prefTests = append(prefTests, prefTest{"RoleAdminDiffAccountPreference",
func(accountID, userID string) auth.Claims {
return auth.Claims{
Roles: []string{auth.RoleAdmin},
StandardClaims: jwt.StandardClaims{
Audience: uuid.NewRandom().String(),
Subject: uuid.NewRandom().String(),
},
}
},
AccountPreferenceSetRequest{
AccountID: usrAcc.AccountID,
Name: AccountPreference_Time_Format,
Value: AccountPreference_Time_Format_Default,
},
account.ErrForbidden,
ErrNotFound,
})
// Role of admin and claim account matches update account so ok.
prefTests = append(prefTests, prefTest{"RoleAdminSameAccountPreference",
func(accountID, userId string) auth.Claims {
return auth.Claims{
Roles: []string{auth.RoleAdmin},
StandardClaims: jwt.StandardClaims{
Audience: uuid.NewRandom().String(),
Subject: userId,
},
}
},
AccountPreferenceSetRequest{
AccountID: usrAcc.AccountID,
Name: AccountPreference_Time_Format,
Value: AccountPreference_Time_Format_Default,
},
account.ErrForbidden,
ErrNotFound,
})
t.Log("Given the need to ensure claims are applied as ACL for set account preference.")
{
for i, tt := range prefTests {
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
{
ctx := tests.Context()
err := Set(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), test.MasterDB, tt.set, now)
if err != nil && errors.Cause(err) != tt.writeErr {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.writeErr)
t.Fatalf("\t%s\tFind failed.", tests.Failed)
}
// If user doesn't have access to set, create one anyways to test the other endpoints.
if tt.writeErr != nil {
err := Set(ctx, auth.Claims{}, test.MasterDB, tt.set, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate failed.", tests.Failed)
}
}
// Find the account and make sure the set where made.
readRes, err := Read(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), test.MasterDB, AccountPreferenceReadRequest{
AccountID: tt.set.AccountID,
Name: tt.set.Name,
})
if err != nil && errors.Cause(err) != tt.findErr {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.findErr)
t.Fatalf("\t%s\tFind failed.", tests.Failed)
} else if tt.findErr == nil {
findExpected := &AccountPreference{
AccountID: tt.set.AccountID,
Name: tt.set.Name,
Value: tt.set.Value,
CreatedAt: readRes.CreatedAt,
UpdatedAt: readRes.UpdatedAt,
}
if diff := cmp.Diff(readRes, findExpected); diff != "" {
t.Fatalf("\t%s\tExpected find result to match update. Diff:\n%s", tests.Failed, diff)
}
t.Logf("\t%s\tRead ok.", tests.Success)
}
// Archive (soft-delete) the account.
err = Archive(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), test.MasterDB, AccountPreferenceArchiveRequest{
AccountID: tt.set.AccountID,
Name: tt.set.Name,
}, now)
if err != nil && errors.Cause(err) != tt.writeErr {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.writeErr)
t.Fatalf("\t%s\tArchive failed.", tests.Failed)
} else if tt.findErr == nil {
// Trying to find the archived account with the includeArchived false should result in not found.
_, err = Read(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), test.MasterDB, AccountPreferenceReadRequest{
AccountID: tt.set.AccountID,
Name: tt.set.Name,
})
if err != nil && errors.Cause(err) != ErrNotFound {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", ErrNotFound)
t.Fatalf("\t%s\tArchive Read failed.", tests.Failed)
}
// Trying to find the archived account with the includeArchived true should result no error.
_, err = Read(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), test.MasterDB, AccountPreferenceReadRequest{
AccountID: tt.set.AccountID,
Name: tt.set.Name,
IncludeArchived: true,
})
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tArchive Read failed.", tests.Failed)
}
}
t.Logf("\t%s\tArchive ok.", tests.Success)
// Delete (hard-delete) the account.
err = Delete(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), test.MasterDB, AccountPreferenceDeleteRequest{
AccountID: tt.set.AccountID,
Name: tt.set.Name,
})
if err != nil && errors.Cause(err) != tt.writeErr {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.writeErr)
t.Fatalf("\t%s\tDelete failed.", tests.Failed)
} else if tt.writeErr == nil {
// Trying to find the deleted account with the includeArchived true should result in not found.
_, err = Read(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), test.MasterDB, AccountPreferenceReadRequest{
AccountID: tt.set.AccountID,
Name: tt.set.Name,
IncludeArchived: true,
})
if errors.Cause(err) != ErrNotFound {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", ErrNotFound)
t.Fatalf("\t%s\tDelete Read failed.", tests.Failed)
}
}
t.Logf("\t%s\tDelete ok.", tests.Success)
}
}
}
}
// TestFind validates all the request params are correctly parsed into a select query.
func TestFind(t *testing.T) {
now := time.Now().Add(time.Hour * -1).UTC()
// Create a test user and account.
usrAcc, err := user_account.MockUserAccount(tests.Context(), test.MasterDB, now, user_account.UserAccountRole_Admin)
if err != nil {
t.Log("Got :", err)
t.Fatalf("%s\tCreate account failed.", tests.Failed)
}
startTime := now.Truncate(time.Millisecond)
var endTime time.Time
reqs := []AccountPreferenceSetRequest{
{
AccountID: usrAcc.AccountID,
Name: AccountPreference_Datetime_Format,
Value: AccountPreference_Datetime_Format_Default,
},
{
AccountID: usrAcc.AccountID,
Name: AccountPreference_Date_Format,
Value: AccountPreference_Date_Format_Default,
},
{
AccountID: usrAcc.AccountID,
Name: AccountPreference_Time_Format,
Value: AccountPreference_Time_Format_Default,
},
}
var prefs []*AccountPreference
for idx, req := range reqs {
err = Set(tests.Context(), auth.Claims{}, test.MasterDB, req, now.Add(time.Second*time.Duration(idx)))
if err != nil {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tRequest : %+v", req)
t.Fatalf("\t%s\tSet failed.", tests.Failed)
}
pref, err := Read(tests.Context(), auth.Claims{}, test.MasterDB, AccountPreferenceReadRequest{
AccountID: req.AccountID,
Name: req.Name,
})
if err != nil {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tRequest : %+v", req)
t.Fatalf("\t%s\tSet failed.", tests.Failed)
}
prefs = append(prefs, pref)
endTime = pref.CreatedAt
}
type accountTest struct {
name string
req AccountPreferenceFindRequest
expected []*AccountPreference
error error
}
var prefTests []accountTest
createdFilter := "created_at BETWEEN ? AND ?"
// Test sort accounts.
prefTests = append(prefTests, accountTest{"Find all order by created_at asc",
AccountPreferenceFindRequest{
Where: &createdFilter,
Args: []interface{}{startTime, endTime},
Order: []string{"created_at"},
},
prefs,
nil,
})
// Test reverse sorted accounts.
var expected []*AccountPreference
for i := len(prefs) - 1; i >= 0; i-- {
expected = append(expected, prefs[i])
}
prefTests = append(prefTests, accountTest{"Find all order by created_at desc",
AccountPreferenceFindRequest{
Where: &createdFilter,
Args: []interface{}{startTime, endTime},
Order: []string{"created_at desc"},
},
expected,
nil,
})
// Test limit.
var limit uint = 2
prefTests = append(prefTests, accountTest{"Find limit",
AccountPreferenceFindRequest{
Where: &createdFilter,
Args: []interface{}{startTime, endTime},
Order: []string{"created_at"},
Limit: &limit,
},
prefs[0:2],
nil,
})
// Test offset.
var offset uint = 1
prefTests = append(prefTests, accountTest{"Find limit, offset",
AccountPreferenceFindRequest{
Where: &createdFilter,
Args: []interface{}{startTime, endTime},
Order: []string{"created_at"},
Limit: &limit,
Offset: &offset,
},
prefs[1:3],
nil,
})
// Test where filter.
whereParts := []string{}
whereArgs := []interface{}{startTime, endTime}
expected = []*AccountPreference{}
for i := 0; i < len(prefs); i++ {
if rand.Intn(100) < 50 {
continue
}
u := *prefs[i]
whereParts = append(whereParts, "name = ?")
whereArgs = append(whereArgs, u.Name)
expected = append(expected, &u)
}
where := createdFilter + " AND (" + strings.Join(whereParts, " OR ") + ")"
prefTests = append(prefTests, accountTest{"Find where",
AccountPreferenceFindRequest{
Where: &where,
Args: whereArgs,
Order: []string{"created_at"},
},
expected,
nil,
})
t.Log("Given the need to ensure find account preferences returns the expected results.")
{
for i, tt := range prefTests {
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
{
ctx := tests.Context()
res, err := Find(ctx, auth.Claims{}, test.MasterDB, tt.req)
if errors.Cause(err) != tt.error {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.error)
t.Fatalf("\t%s\tFind failed.", tests.Failed)
} else if diff := cmp.Diff(res, tt.expected); diff != "" {
t.Logf("\t\tGot: %d items", len(res))
t.Logf("\t\tWant: %d items", len(tt.expected))
for _, u := range res {
t.Logf("\t\tGot: %s ID", u.Name)
}
for _, u := range tt.expected {
t.Logf("\t\tExpected: %s ID", u.Name)
}
t.Fatalf("\t%s\tExpected find result to match expected. Diff:\n%s", tests.Failed, diff)
}
t.Logf("\t%s\tFind ok.", tests.Success)
}
}
}
}

View File

@ -0,0 +1,150 @@
package account_preference
import (
"context"
"github.com/pkg/errors"
"time"
"database/sql/driver"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web"
"github.com/lib/pq"
"gopkg.in/go-playground/validator.v9"
)
// AccountPreference represents an account setting.
type AccountPreference struct {
AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
Name AccountPreferenceName `json:"name" validate:"required,oneof=datetime_format date_format time_format" swaggertype:"string" enums:"datetime_format,date_format,time_format" example:"datetime_format"`
Value string `json:"value" validate:"required,preference_value" example:"2006-01-02 at 3:04PM MST"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ArchivedAt *pq.NullTime `json:"archived_at,omitempty"`
}
// AccountPreferenceResponse represents an account setting that is returned for display.
type AccountPreferenceResponse struct {
AccountID string `json:"account_id" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
Name web.EnumResponse `json:"name" example:"datetime_format"`
Value string `json:"value" example:"2006-01-02 at 3:04PM MST"`
CreatedAt web.TimeResponse `json:"created_at"` // CreatedAt contains multiple format options for display.
UpdatedAt web.TimeResponse `json:"updated_at"` // UpdatedAt contains multiple format options for display.
ArchivedAt *web.TimeResponse `json:"archived_at,omitempty"` // ArchivedAt contains multiple format options for display.
}
// Response transforms AccountPreference and AccountPreferenceResponse that is used for display.
// Additional filtering by context values or translations could be applied.
func (m *AccountPreference) Response(ctx context.Context) *AccountPreferenceResponse {
if m == nil {
return nil
}
r := &AccountPreferenceResponse{
AccountID: m.AccountID,
Name: web.NewEnumResponse(ctx, m.Name, AccountPreferenceName_Values),
Value: m.Value,
CreatedAt: web.NewTimeResponse(ctx, m.CreatedAt),
UpdatedAt: web.NewTimeResponse(ctx, m.UpdatedAt),
}
if m.ArchivedAt != nil && !m.ArchivedAt.Time.IsZero() {
at := web.NewTimeResponse(ctx, m.ArchivedAt.Time)
r.ArchivedAt = &at
}
return r
}
// AccountPreferenceReadRequest contains information needed to read an Account Preference.
type AccountPreferenceReadRequest struct {
AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
Name AccountPreferenceName `json:"name" validate:"required,oneof=datetime_format date_format time_format" swaggertype:"string" enums:"datetime_format,date_format,time_format" example:"datetime_format"`
IncludeArchived bool `json:"include-archived" example:"false"`
}
// AccountPreferenceSetRequest contains information needed to create a new Account Preference.
type AccountPreferenceSetRequest struct {
AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
Name AccountPreferenceName `json:"name" validate:"required,oneof=datetime_format date_format time_format" swaggertype:"string" enums:"datetime_format,date_format,time_format" example:"datetime_format"`
Value string `json:"value" validate:"required,preference_value" example:"2006-01-02 at 3:04PM MST"`
}
// AccountPreferenceArchiveRequest defines the information needed to archive an account preference.
// This will archive (soft-delete) the existing database entry.
type AccountPreferenceArchiveRequest struct {
AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
Name AccountPreferenceName `json:"name" validate:"required,oneof=datetime_format date_format time_format" swaggertype:"string" enums:"datetime_format,date_format,time_format" example:"datetime_format"`
}
// AccountPreferenceDeleteRequest defines the information needed to delete an account preference.
type AccountPreferenceDeleteRequest struct {
AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
Name AccountPreferenceName `json:"name" validate:"required,oneof=datetime_format date_format time_format" swaggertype:"string" enums:"datetime_format,date_format,time_format" example:"datetime_format"`
}
// AccountPreferenceFindRequest defines the possible options to search for accounts. By default
// archived accounts will be excluded from response.
type AccountPreferenceFindRequest struct {
Where *string `json:"where" example:"name = ?"`
Args []interface{} `json:"args" swaggertype:"array,string" example:"Company Name,active"`
Order []string `json:"order" example:"created_at desc"`
Limit *uint `json:"limit" example:"10"`
Offset *uint `json:"offset" example:"20"`
IncludeArchived bool `json:"include-archived" example:"false"`
}
// AccountPreferenceFindByAccountIDRequest defines the possible options to search for accounts. By default
// archived account preferences will be excluded from response.
type AccountPreferenceFindByAccountIDRequest struct {
AccountID string `json:"id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
Order []string `json:"order" example:"created_at desc"`
Limit *uint `json:"limit" example:"10"`
Offset *uint `json:"offset" example:"20"`
IncludeArchived bool `json:"include-archived" example:"false"`
}
// AccountPreferenceName represents the name of an account preference.
type AccountPreferenceName string
// Account Preference Datetime Format
var (
AccountPreference_Datetime_Format AccountPreferenceName = "datetime_format"
AccountPreference_Date_Format AccountPreferenceName = "date_format"
AccountPreference_Time_Format AccountPreferenceName = "time_format"
AccountPreference_Datetime_Format_Default = "2006-01-02 at 3:04PM MST"
AccountPreference_Date_Format_Default = "2006-01-02"
AccountPreference_Time_Format_Default = "3:04PM MST"
)
// AccountPreferenceName_Values provides list of valid AccountPreferenceName values.
var AccountPreferenceName_Values = []AccountPreferenceName{
AccountPreference_Datetime_Format,
AccountPreference_Date_Format,
AccountPreference_Time_Format,
}
// Scan supports reading the AccountPreferenceName value from the database.
func (s *AccountPreferenceName) Scan(value interface{}) error {
asBytes, ok := value.(string)
if !ok {
return errors.New("Scan source is not []byte")
}
*s = AccountPreferenceName(string(asBytes))
return nil
}
// Value converts the AccountPreferenceName value to be stored in the database.
func (s AccountPreferenceName) Value() (driver.Value, error) {
v := validator.New()
errs := v.Var(s, "required,oneof=datetime_format date_format time_format")
if errs != nil {
return nil, errs
}
return string(s), nil
}
// String converts the AccountPreferenceName value to a string.
func (s AccountPreferenceName) String() string {
return string(s)
}

View File

@ -30,39 +30,6 @@ func testMain(m *testing.M) int {
return m.Run() return m.Run()
} }
// TestFindRequestQuery validates findRequestQuery
func TestFindRequestQuery(t *testing.T) {
where := "first_name = ? or address1 = ?"
var (
limit uint = 12
offset uint = 34
)
req := AccountFindRequest{
Where: &where,
Args: []interface{}{
"lee",
"103 East Main St.",
},
Order: []string{
"id asc",
"created_at desc",
},
Limit: &limit,
Offset: &offset,
}
expected := "SELECT " + accountMapColumns + " FROM " + accountTableName + " WHERE (first_name = ? or address1 = ?) ORDER BY id asc, created_at desc LIMIT 12 OFFSET 34"
res, args := findRequestQuery(req)
if diff := cmp.Diff(res.String(), expected); diff != "" {
t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff)
}
if diff := cmp.Diff(args, req.Args); diff != "" {
t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff)
}
}
// TestApplyClaimsSelect validates applyClaimsSelect // TestApplyClaimsSelect validates applyClaimsSelect
func TestApplyClaimsSelect(t *testing.T) { func TestApplyClaimsSelect(t *testing.T) {
var claimTests = []struct { var claimTests = []struct {
@ -786,7 +753,7 @@ func TestCrud(t *testing.T) {
t.Logf("\t%s\tUpdate ok.", tests.Success) t.Logf("\t%s\tUpdate ok.", tests.Success)
// Find the account and make sure the updates where made. // Find the account and make sure the updates where made.
findRes, err := Read(ctx, tt.claims(account, userId), test.MasterDB, account.ID, false) findRes, err := ReadByID(ctx, tt.claims(account, userId), test.MasterDB, account.ID)
if err != nil && errors.Cause(err) != tt.findErr { if err != nil && errors.Cause(err) != tt.findErr {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.findErr) t.Logf("\t\tWant: %+v", tt.findErr)
@ -800,14 +767,14 @@ func TestCrud(t *testing.T) {
} }
// Archive (soft-delete) the account. // Archive (soft-delete) the account.
err = ArchiveById(ctx, tt.claims(account, userId), test.MasterDB, account.ID, now) err = Archive(ctx, tt.claims(account, userId), test.MasterDB, AccountArchiveRequest{ID: account.ID}, now)
if err != nil && errors.Cause(err) != tt.updateErr { if err != nil && errors.Cause(err) != tt.updateErr {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.updateErr) t.Logf("\t\tWant: %+v", tt.updateErr)
t.Fatalf("\t%s\tArchive failed.", tests.Failed) t.Fatalf("\t%s\tArchive failed.", tests.Failed)
} else if tt.updateErr == nil { } else if tt.updateErr == nil {
// Trying to find the archived account with the includeArchived false should result in not found. // Trying to find the archived account with the includeArchived false should result in not found.
_, err = Read(ctx, tt.claims(account, userId), test.MasterDB, account.ID, false) _, err = ReadByID(ctx, tt.claims(account, userId), test.MasterDB, account.ID)
if err != nil && errors.Cause(err) != ErrNotFound { if err != nil && errors.Cause(err) != ErrNotFound {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", ErrNotFound) t.Logf("\t\tWant: %+v", ErrNotFound)
@ -815,7 +782,8 @@ func TestCrud(t *testing.T) {
} }
// Trying to find the archived account with the includeArchived true should result no error. // Trying to find the archived account with the includeArchived true should result no error.
_, err = Read(ctx, tt.claims(account, userId), test.MasterDB, account.ID, true) _, err = Read(ctx, tt.claims(account, userId), test.MasterDB,
AccountReadRequest{ID: account.ID, IncludeArchived: true})
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tArchive Read failed.", tests.Failed) t.Fatalf("\t%s\tArchive Read failed.", tests.Failed)
@ -824,14 +792,14 @@ func TestCrud(t *testing.T) {
t.Logf("\t%s\tArchive ok.", tests.Success) t.Logf("\t%s\tArchive ok.", tests.Success)
// Delete (hard-delete) the account. // Delete (hard-delete) the account.
err = Delete(ctx, tt.claims(account, userId), test.MasterDB, account.ID) err = Delete(ctx, tt.claims(account, userId), test.MasterDB, AccountDeleteRequest{ID: account.ID})
if err != nil && errors.Cause(err) != tt.updateErr { if err != nil && errors.Cause(err) != tt.updateErr {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.updateErr) t.Logf("\t\tWant: %+v", tt.updateErr)
t.Fatalf("\t%s\tUpdate failed.", tests.Failed) t.Fatalf("\t%s\tUpdate failed.", tests.Failed)
} else if tt.updateErr == nil { } else if tt.updateErr == nil {
// Trying to find the deleted account with the includeArchived true should result in not found. // Trying to find the deleted account with the includeArchived true should result in not found.
_, err = Read(ctx, tt.claims(account, userId), test.MasterDB, account.ID, true) _, err = ReadByID(ctx, tt.claims(account, userId), test.MasterDB, account.ID)
if errors.Cause(err) != ErrNotFound { if errors.Cause(err) != ErrNotFound {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", ErrNotFound) t.Logf("\t\tWant: %+v", ErrNotFound)

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"encoding/json"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web"
"time" "time"
@ -87,6 +88,17 @@ func (m *Account) Response(ctx context.Context) *AccountResponse {
return r return r
} }
func (m *AccountResponse) UnmarshalBinary(data []byte) error {
if data == nil || len(data) == 0 {
return nil
}
return json.Unmarshal(data, m)
}
func (m *AccountResponse) MarshalBinary() ([]byte, error) {
return json.Marshal(m)
}
// AccountCreateRequest contains information needed to create a new Account. // AccountCreateRequest contains information needed to create a new Account.
type AccountCreateRequest struct { type AccountCreateRequest struct {
Name string `json:"name" validate:"required,unique" example:"Company Name"` Name string `json:"name" validate:"required,unique" example:"Company Name"`
@ -102,6 +114,12 @@ type AccountCreateRequest struct {
BillingUserID *string `json:"billing_user_id,omitempty" validate:"omitempty,uuid" swaggertype:"string" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"` BillingUserID *string `json:"billing_user_id,omitempty" validate:"omitempty,uuid" swaggertype:"string" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"`
} }
// AccountReadRequest defines the information needed to read an account.
type AccountReadRequest struct {
ID string `json:"id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
IncludeArchived bool `json:"include-archived" example:"false"`
}
// AccountUpdateRequest defines what information may be provided to modify an existing // AccountUpdateRequest defines what information may be provided to modify an existing
// Account. All fields are optional so clients can send just the fields they want // Account. All fields are optional so clients can send just the fields they want
// changed. It uses pointer fields so we can differentiate between a field that // changed. It uses pointer fields so we can differentiate between a field that
@ -129,15 +147,20 @@ type AccountArchiveRequest struct {
ID string `json:"id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"` ID string `json:"id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
} }
// AccountDeleteRequest defines the information needed to delete a user.
type AccountDeleteRequest struct {
ID string `json:"id" validate:"required,uuid" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"`
}
// AccountFindRequest defines the possible options to search for accounts. By default // AccountFindRequest defines the possible options to search for accounts. By default
// archived accounts will be excluded from response. // archived accounts will be excluded from response.
type AccountFindRequest struct { type AccountFindRequest struct {
Where *string `json:"where" example:"name = ? and status = ?"` Where *string `json:"where" example:"name = ? and status = ?"`
Args []interface{} `json:"args" swaggertype:"array,string" example:"Company Name,active"` Args []interface{} `json:"args" swaggertype:"array,string" example:"Company Name,active"`
Order []string `json:"order" example:"created_at desc"` Order []string `json:"order" example:"created_at desc"`
Limit *uint `json:"limit" example:"10"` Limit *uint `json:"limit" example:"10"`
Offset *uint `json:"offset" example:"20"` Offset *uint `json:"offset" example:"20"`
IncludedArchived bool `json:"included-archived" example:"false"` IncludeArchived bool `json:"include-archived" example:"false"`
} }
// AccountStatus represents the status of an account. // AccountStatus represents the status of an account.

View File

@ -63,3 +63,27 @@ func FindCountryTimezones(ctx context.Context, dbConn *sqlx.DB, orderBy, where s
return resp, nil return resp, nil
} }
func ListTimezones(ctx context.Context, dbConn *sqlx.DB) ([]string, error) {
res, err := FindCountryTimezones(ctx, dbConn, "timezone_id", "")
if err != nil {
return nil, err
}
resp := []string{}
for _, ct := range res {
var exists bool
for _, t := range resp {
if ct.TimezoneId == t {
exists = true
break
}
}
if !exists {
resp = append(resp, ct.TimezoneId)
}
}
return resp, nil
}

View File

@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
) )
@ -13,7 +14,7 @@ import (
// Errors handles errors coming out of the call chain. It detects normal // Errors handles errors coming out of the call chain. It detects normal
// application errors which are used to respond to the client in a uniform way. // application errors which are used to respond to the client in a uniform way.
// Unexpected errors (status >= 500) are logged. // Unexpected errors (status >= 500) are logged.
func Errors(log *log.Logger) web.Middleware { func Errors(log *log.Logger, renderer web.Renderer) web.Middleware {
// This is the actual middleware function to be executed. // This is the actual middleware function to be executed.
f := func(before web.Handler) web.Handler { f := func(before web.Handler) web.Handler {
@ -23,26 +24,35 @@ func Errors(log *log.Logger) web.Middleware {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.mid.Errors") span, ctx := tracer.StartSpanFromContext(ctx, "internal.mid.Errors")
defer span.Finish() defer span.Finish()
if err := before(ctx, w, r, params); err != nil { if er := before(ctx, w, r, params); er != nil {
// Log the error. // Log the error.
log.Printf("%d : ERROR : %+v", span.Context().TraceID(), err) log.Printf("%d : ERROR : %+v", span.Context().TraceID(), er)
// Respond to the error. // Respond to the error.
if web.RequestIsJson(r) { if web.RequestIsJson(r) {
if err := web.RespondJsonError(ctx, w, err); err != nil { if err := web.RespondJsonError(ctx, w, er); err != nil {
return err
}
} else if renderer != nil {
v, err := webcontext.ContextValues(ctx)
if err != nil {
return err
}
if err := renderer.Error(ctx, w, r, v.StatusCode, er); err != nil {
return err return err
} }
} else { } else {
if err := web.RespondError(ctx, w, err); err != nil { if err := web.RespondError(ctx, w, er); err != nil {
return err return err
} }
} }
// If we receive the shutdown err we need to return it // If we receive the shutdown err we need to return it
// back to the base handler to shutdown the service. // back to the base handler to shutdown the service.
if ok := weberror.IsShutdown(err); ok { if ok := weberror.IsShutdown(er); ok {
return err return er
} }
} }

View File

@ -137,3 +137,67 @@ func (a *Authenticator) ParseClaims(tknStr string) (Claims, error) {
return claims, nil return claims, nil
} }
// mockTokenGenerator is used for testing that Authenticate calls its provided
// token generator in a specific way.
type MockTokenGenerator struct {
// Private key generated by GenerateToken that is need for ParseClaims
key *rsa.PrivateKey
// algorithm is the method used to generate the private key.
algorithm string
}
// GenerateToken implements the TokenGenerator interface. It returns a "token"
// that includes some information about the claims it was passed.
func (g *MockTokenGenerator) GenerateToken(claims Claims) (string, error) {
privateKey, err := KeyGen()
if err != nil {
return "", err
}
g.key, err = jwt.ParseRSAPrivateKeyFromPEM(privateKey)
if err != nil {
return "", err
}
g.algorithm = "RS256"
method := jwt.GetSigningMethod(g.algorithm)
tkn := jwt.NewWithClaims(method, claims)
tkn.Header["kid"] = "1"
str, err := tkn.SignedString(g.key)
if err != nil {
return "", err
}
return str, nil
}
// ParseClaims recreates the Claims that were used to generate a token. It
// verifies that the token was signed using our key.
func (g *MockTokenGenerator) ParseClaims(tknStr string) (Claims, error) {
parser := jwt.Parser{
ValidMethods: []string{g.algorithm},
}
if g.key == nil {
return Claims{}, errors.New("Private key is empty.")
}
f := func(t *jwt.Token) (interface{}, error) {
return g.key.Public().(*rsa.PublicKey), nil
}
var claims Claims
tkn, err := parser.ParseWithClaims(tknStr, &claims, f)
if err != nil {
return Claims{}, errors.Wrap(err, "parsing token")
}
if !tkn.Valid {
return Claims{}, errors.New("Invalid token")
}
return claims, nil
}

View File

@ -23,20 +23,29 @@ const Key ctxKey = 1
// Claims represents the authorization claims transmitted via a JWT. // Claims represents the authorization claims transmitted via a JWT.
type Claims struct { type Claims struct {
AccountIds []string `json:"accounts"` AccountIds []string `json:"accounts"`
Roles []string `json:"roles"` Roles []string `json:"roles"`
Timezone string `json:"timezone"` Preferences ClaimPreferences `json:"prefs"`
tz *time.Location
jwt.StandardClaims jwt.StandardClaims
} }
// ClaimPreferences defines preferences for the user.
type ClaimPreferences struct {
Timezone string `json:"timezone"`
DatetimeFormat string `json:"pref_datetime_format"`
DateFormat string `json:"pref_date_format"`
TimeFormat string `json:"pref_time_format"`
tz *time.Location
}
// NewClaims constructs a Claims value for the identified user. The Claims // NewClaims constructs a Claims value for the identified user. The Claims
// expire within a specified duration of the provided time. Additional fields // expire within a specified duration of the provided time. Additional fields
// of the Claims can be set after calling NewClaims is desired. // of the Claims can be set after calling NewClaims is desired.
func NewClaims(userId, accountId string, accountIds []string, roles []string, userTimezone *time.Location, now time.Time, expires time.Duration) Claims { func NewClaims(userId, accountId string, accountIds []string, roles []string, prefs ClaimPreferences, now time.Time, expires time.Duration) Claims {
c := Claims{ c := Claims{
AccountIds: accountIds, AccountIds: accountIds,
Roles: roles, Roles: roles,
Preferences: prefs,
StandardClaims: jwt.StandardClaims{ StandardClaims: jwt.StandardClaims{
Subject: userId, Subject: userId,
Audience: accountId, Audience: accountId,
@ -45,11 +54,22 @@ func NewClaims(userId, accountId string, accountIds []string, roles []string, us
}, },
} }
if userTimezone != nil { return c
c.Timezone = userTimezone.String() }
// NewClaimPreferences constructs ClaimPreferences for the user/account.
func NewClaimPreferences(timezone *time.Location, datetimeFormat, dateFormat, timeFormat string) ClaimPreferences {
p := ClaimPreferences{
DatetimeFormat: datetimeFormat,
DateFormat: dateFormat,
TimeFormat: timeFormat,
} }
return c if timezone != nil {
p.Timezone = timezone.String()
}
return p
} }
// Valid is called during the parsing of a token. // Valid is called during the parsing of a token.
@ -88,13 +108,18 @@ func (c Claims) HasRole(roles ...string) bool {
} }
// TimeLocation returns the timezone used to format datetimes for the user. // TimeLocation returns the timezone used to format datetimes for the user.
func (c Claims) TimeLocation() *time.Location { func (c ClaimPreferences) TimeLocation() *time.Location {
if c.tz == nil && c.Timezone != "" { if c.tz == nil && c.Timezone != "" {
c.tz, _ = time.LoadLocation(c.Timezone) c.tz, _ = time.LoadLocation(c.Timezone)
} }
return c.tz return c.tz
} }
// TimeLocation returns the timezone used to format datetimes for the user.
func (c Claims) TimeLocation() *time.Location {
return c.Preferences.TimeLocation()
}
// ClaimsFromContext loads the claims from context. // ClaimsFromContext loads the claims from context.
func ClaimsFromContext(ctx context.Context) (Claims, error) { func ClaimsFromContext(ctx context.Context) (Claims, error) {
claims, ok := ctx.Value(Key).(Claims) claims, ok := ctx.Value(Key).(Claims)

View File

@ -1,16 +0,0 @@
package session
import (
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
)
// ctxKey represents the type of value for the context key.
type ctxKey int
// Key is used to store/retrieve a Claims value from a context.Context.
const Key ctxKey = 1
// Session represents a user with authentication.
type Session struct {
Claims auth.Claims `json:"claims"`
}

View File

@ -130,6 +130,7 @@ func Context() context.Context {
TraceID: uint64(time.Now().UnixNano()), TraceID: uint64(time.Now().UnixNano()),
Now: time.Now(), Now: time.Now(),
RequestIP: "68.69.35.104", RequestIP: "68.69.35.104",
Env: "dev",
} }
return context.WithValue(context.Background(), webcontext.KeyValues, &values) return context.WithValue(context.Background(), webcontext.KeyValues, &values)

View File

@ -12,6 +12,7 @@ import (
const DatetimeFormatLocal = "Mon Jan _2 3:04PM" const DatetimeFormatLocal = "Mon Jan _2 3:04PM"
const DateFormatLocal = "Mon Jan _2" const DateFormatLocal = "Mon Jan _2"
const TimeFormatLocal = time.Kitchen
// TimeResponse is a response friendly format for displaying the value of a time. // TimeResponse is a response friendly format for displaying the value of a time.
type TimeResponse struct { type TimeResponse struct {
@ -23,6 +24,7 @@ type TimeResponse struct {
RFC1123 string `json:"rfc1123" example:"Tue, 25 Jun 2019 03:00:53 AKDT"` RFC1123 string `json:"rfc1123" example:"Tue, 25 Jun 2019 03:00:53 AKDT"`
Local string `json:"local" example:"Tue Jun 25 3:00AM"` Local string `json:"local" example:"Tue Jun 25 3:00AM"`
LocalDate string `json:"local_date" example:"Tue Jun 25"` LocalDate string `json:"local_date" example:"Tue Jun 25"`
LocalTime string `json:"local_time" example:"3:00AM"`
NowTime string `json:"now_time" example:"5 hours ago"` NowTime string `json:"now_time" example:"5 hours ago"`
NowRelTime string `json:"now_rel_time" example:"15 hours from now"` NowRelTime string `json:"now_rel_time" example:"15 hours from now"`
Timezone string `json:"timezone" example:"America/Anchorage"` Timezone string `json:"timezone" example:"America/Anchorage"`
@ -39,6 +41,21 @@ func NewTimeResponse(ctx context.Context, t time.Time) TimeResponse {
t = t.In(claims.TimeLocation()) t = t.In(claims.TimeLocation())
} }
var formatDatetime = DatetimeFormatLocal
if claims.Preferences.DatetimeFormat != "" {
formatDatetime = claims.Preferences.DatetimeFormat
}
var formatDate = DatetimeFormatLocal
if claims.Preferences.DateFormat != "" {
formatDate = claims.Preferences.DateFormat
}
var formatTime = DatetimeFormatLocal
if claims.Preferences.DatetimeFormat != "" {
formatTime = claims.Preferences.TimeFormat
}
tr := TimeResponse{ tr := TimeResponse{
Value: t, Value: t,
ValueUTC: t.UTC(), ValueUTC: t.UTC(),
@ -46,8 +63,9 @@ func NewTimeResponse(ctx context.Context, t time.Time) TimeResponse {
Time: t.Format("15:04:05"), Time: t.Format("15:04:05"),
Kitchen: t.Format(time.Kitchen), Kitchen: t.Format(time.Kitchen),
RFC1123: t.Format(time.RFC1123), RFC1123: t.Format(time.RFC1123),
Local: t.Format(DatetimeFormatLocal), Local: t.Format(formatDatetime),
LocalDate: t.Format(DateFormatLocal), LocalDate: t.Format(formatDate),
LocalTime: t.Format(formatTime),
NowTime: humanize.Time(t.UTC()), NowTime: humanize.Time(t.UTC()),
NowRelTime: humanize.RelTime(time.Now().UTC(), t.UTC(), "ago", "from now"), NowRelTime: humanize.RelTime(time.Now().UTC(), t.UTC(), "ago", "from now"),
} }
@ -100,15 +118,15 @@ func EnumValueTitle(v string) string {
} }
type GravatarResponse struct { type GravatarResponse struct {
Small string `json:"small" example:"https://www.gravatar.com/avatar/xy7.jpg?s=30"` Small string `json:"small" example:"https://www.gravatar.com/avatar/xy7.jpg?s=30"`
Medium string `json:"medium" example:"https://www.gravatar.com/avatar/xy7.jpg?s=80"` Medium string `json:"medium" example:"https://www.gravatar.com/avatar/xy7.jpg?s=80"`
} }
func NewGravatarResponse(ctx context.Context, email string) GravatarResponse { func NewGravatarResponse(ctx context.Context, email string) GravatarResponse {
u := fmt.Sprintf("https://www.gravatar.com/avatar/%x.jpg?s=", md5.Sum([]byte(strings.ToLower(email)))) u := fmt.Sprintf("https://www.gravatar.com/avatar/%x.jpg?s=", md5.Sum([]byte(strings.ToLower(email))))
return GravatarResponse{ return GravatarResponse{
Small: u+"30", Small: u + "30",
Medium: u+"80", Medium: u + "80",
} }
} }

View File

@ -179,19 +179,28 @@ func RenderError(ctx context.Context, w http.ResponseWriter, r *http.Request, er
return err return err
} }
// If the error was of the type *Error, the handler has webErr, ok := er.(*weberror.Error)
// a specific status code and error to return. if !ok {
webErr := weberror.NewError(ctx, er, v.StatusCode).(*weberror.Error).Response(ctx, true) if v.StatusCode == 0 {
v.StatusCode = webErr.StatusCode v.StatusCode = http.StatusInternalServerError
}
// If the error was of the type *Error, the handler has
// a specific status code and error to return.
webErr = weberror.NewError(ctx, er, v.StatusCode).(*weberror.Error)
}
v.StatusCode = webErr.Status
resp := webErr.Response(ctx, true)
data := map[string]interface{}{ data := map[string]interface{}{
"StatusCode": webErr.StatusCode, "StatusCode": resp.StatusCode,
"Error": webErr.Error, "Error": resp.Error,
"Details": webErr.Details, "Details": resp.Details,
"Fields": webErr.Fields, "Fields": resp.Fields,
} }
return renderer.Render(ctx, w, r, templateLayoutName, templateContentName, contentType, webErr.StatusCode, data) return renderer.Render(ctx, w, r, templateLayoutName, templateContentName, contentType, webErr.Status, data)
} }
// Static registers a new route with path prefix to serve static files from the // Static registers a new route with path prefix to serve static files from the

View File

@ -2,6 +2,7 @@ package template_renderer
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"html/template" "html/template"
"math" "math"
@ -123,6 +124,18 @@ func NewTemplate(templateFuncs template.FuncMap) *Template {
} }
return claims.HasRole(roles...) return claims.HasRole(roles...)
}, },
"CmpString": func(str1 string, str2Ptr *string) bool {
var str2 string
if str2Ptr != nil {
str2 = *str2Ptr
}
if str1 == str2 {
return true
}
return false
},
"dict": func(values ...interface{}) (map[string]interface{}, error) { "dict": func(values ...interface{}) (map[string]interface{}, error) {
if len(values) == 0 { if len(values) == 0 {
return nil, errors.New("invalid dict call") return nil, errors.New("invalid dict call")
@ -307,7 +320,7 @@ func (r *TemplateRenderer) Render(ctx context.Context, w http.ResponseWriter, re
// Specific new data map for render to allow values to be overwritten on a request // Specific new data map for render to allow values to be overwritten on a request
// basis. // basis.
// append the global key/pairs // append the global key/pairs
renderData := make(map[string]interface{}, len(r.globalViewData)) renderData := make(map[string]interface{}, len(r.globalViewData))
for k, v := range r.globalViewData { for k, v := range r.globalViewData {
renderData[k] = v renderData[k] = v
} }
@ -356,7 +369,20 @@ func (r *TemplateRenderer) Render(ctx context.Context, w http.ResponseWriter, re
sess := webcontext.ContextSession(ctx) sess := webcontext.ContextSession(ctx)
if sess != nil { if sess != nil {
// Load any flash messages and append to response data to be included in the rendered template. // Load any flash messages and append to response data to be included in the rendered template.
if flashes := sess.Flashes(); len(flashes) > 0 { if msgs := sess.Flashes(); len(msgs) > 0 {
var flashes []webcontext.FlashMsgResponse
for _, mv := range msgs {
dat, ok := mv.([]byte)
if !ok {
continue
}
var msg webcontext.FlashMsgResponse
if err := json.Unmarshal(dat, &msg); err != nil {
continue
}
flashes = append(flashes, msg)
}
renderData["flashes"] = flashes renderData["flashes"] = flashes
} }

View File

@ -76,6 +76,7 @@ func (a *App) Handle(verb, path string, handler Handler, mw ...Middleware) {
// Call the wrapped handler functions. // Call the wrapped handler functions.
err := handler(ctx, w, r, params) err := handler(ctx, w, r, params)
if err != nil { if err != nil {
// If we have specifically handled the error, then no need // If we have specifically handled the error, then no need
// to initiate a shutdown. // to initiate a shutdown.
if webErr, ok := err.(*weberror.Error); ok { if webErr, ok := err.(*weberror.Error); ok {

View File

@ -3,6 +3,7 @@ package webcontext
import ( import (
"context" "context"
"encoding/gob" "encoding/gob"
"encoding/json"
"html/template" "html/template"
) )
@ -23,18 +24,26 @@ type FlashMsg struct {
Details string `json:"details"` Details string `json:"details"`
} }
func (r FlashMsg) Response(ctx context.Context) map[string]interface{} { type FlashMsgResponse struct {
Type FlashType `json:"type"`
Title template.HTML `json:"title"`
Text template.HTML `json:"text"`
Items []template.HTML `json:"items"`
Details template.HTML `json:"details"`
}
func (r FlashMsg) Response(ctx context.Context) FlashMsgResponse {
var items []template.HTML var items []template.HTML
for _, i := range r.Items { for _, i := range r.Items {
items = append(items, template.HTML(i)) items = append(items, template.HTML(i))
} }
return map[string]interface{}{ return FlashMsgResponse{
"Type": r.Type, Type: r.Type,
"Title": r.Title, Title: template.HTML(r.Title),
"Text": template.HTML(r.Text), Text: template.HTML(r.Text),
"Items": items, Items: items,
"Details": template.HTML(r.Details), Details: template.HTML(r.Details),
} }
} }
@ -46,7 +55,8 @@ func init() {
// adds the message to the session. The renderer should save the session before writing the response // adds the message to the session. The renderer should save the session before writing the response
// to the client or save be directly invoked. // to the client or save be directly invoked.
func SessionAddFlash(ctx context.Context, msg FlashMsg) { func SessionAddFlash(ctx context.Context, msg FlashMsg) {
ContextSession(ctx).AddFlash(msg.Response(ctx)) dat, _ := json.Marshal(msg.Response(ctx))
ContextSession(ctx).AddFlash(dat)
} }
// SessionFlashSuccess add a message with type Success. // SessionFlashSuccess add a message with type Success.

View File

@ -2,7 +2,6 @@ package webcontext
import ( import (
"context" "context"
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
) )
@ -12,14 +11,23 @@ type ctxKeySession int
// KeySession is used to store/retrieve a Session from a context.Context. // KeySession is used to store/retrieve a Session from a context.Context.
const KeySession ctxKeySession = 1 const KeySession ctxKeySession = 1
// KeyAccessToken is used to store the access token for the user in their session. // Session keys used to store values.
const KeyAccessToken = "AccessToken" const (
SessionKeyAccessToken = iota
//SessionKeyPreferenceDatetimeFormat
//SessionKeyPreferenceDateFormat
//SessionKeyPreferenceTimeFormat
//SessionKeyTimezone
)
// KeyUser is used to store the user in the session. func init() {
const KeyUser = "User" //gob.Register(&Session{})
}
// KeyAccount is used to store the account in the session. // Session represents a user with authentication.
const KeyAccount = "Account" type Session struct {
*sessions.Session
}
// ContextWithSession appends a universal translator to a context. // ContextWithSession appends a universal translator to a context.
func ContextWithSession(ctx context.Context, session *sessions.Session) context.Context { func ContextWithSession(ctx context.Context, session *sessions.Session) context.Context {
@ -27,69 +35,83 @@ func ContextWithSession(ctx context.Context, session *sessions.Session) context.
} }
// ContextSession returns the session from a context. // ContextSession returns the session from a context.
func ContextSession(ctx context.Context) *sessions.Session { func ContextSession(ctx context.Context) *Session {
return ctx.Value(KeySession).(*sessions.Session) if s, ok := ctx.Value(KeySession).(*Session); ok {
return s
}
return nil
} }
func ContextAccessToken(ctx context.Context) (string, bool) { func ContextAccessToken(ctx context.Context) (string, bool) {
session := ContextSession(ctx) return ContextSession(ctx).AccessToken()
return SessionAccessToken(session)
} }
func SessionAccessToken(session *sessions.Session) (string, bool) { func (sess *Session) AccessToken() (string, bool) {
if sv, ok := session.Values[KeyAccessToken].(string); ok { if sess == nil {
return "", false
}
if sv, ok := sess.Values[SessionKeyAccessToken].(string); ok {
return sv, true return sv, true
} }
return "", false return "", false
} }
func SessionUser(session *sessions.Session) ( interface{}, bool) { /*
if sv, ok := session.Values[KeyUser]; ok && sv != nil { func(sess *Session) PreferenceDatetimeFormat() (string, bool) {
if sess == nil {
return "", false
}
if sv, ok := sess.Values[SessionKeyPreferenceDatetimeFormat].(string); ok {
return sv, true return sv, true
} }
return "", false
}
func(sess *Session) PreferenceDateFormat() (string, bool) {
if sess == nil {
return "", false
}
if sv, ok := sess.Values[SessionKeyPreferenceDateFormat].(string); ok {
return sv, true
}
return "", false
}
func(sess *Session) PreferenceTimeFormat() (string, bool) {
if sess == nil {
return "", false
}
if sv, ok := sess.Values[SessionKeyPreferenceTimeFormat].(string); ok {
return sv, true
}
return "", false
}
func(sess *Session) Timezone() (*time.Location, bool) {
if sess != nil {
if sv, ok := sess.Values[SessionKeyTimezone].(*time.Location); ok {
return sv, true
}
}
return nil, false return nil, false
} }
*/
func SessionAccount(session *sessions.Session) (interface{}, bool) { func SessionInit(session *Session, accessToken string) *Session {
if sv, ok := session.Values[KeyAccount]; ok && sv != nil {
return sv, true
}
return nil, false session.Values[SessionKeyAccessToken] = accessToken
} //session.Values[SessionKeyPreferenceDatetimeFormat] = datetimeFormat
//session.Values[SessionKeyPreferenceDateFormat] = dateFormat
func SessionInit(session *sessions.Session, accessToken string, usr interface{}, acc interface{}) *sessions.Session { //session.Values[SessionKeyPreferenceTimeFormat] = timeFormat
//session.Values[SessionKeyTimezone] = timezone
if accessToken != "" {
session.Values[KeyAccessToken] = accessToken
} else {
delete(session.Values, KeyAccessToken)
}
if usr != nil {
session.Values[KeyUser] = usr
} else {
delete(session.Values, KeyUser)
}
if acc != nil {
session.Values[KeyAccount] = acc
} else {
delete(session.Values, KeyAccount)
}
return session return session
} }
func SessionDestroy(session *sessions.Session) *sessions.Session { func SessionDestroy(session *Session) *Session {
delete(session.Values, KeyAccessToken) delete(session.Values, SessionKeyAccessToken)
delete(session.Values, KeyUser)
delete(session.Values, KeyAccount)
return session return session
} }

View File

@ -83,6 +83,8 @@ func (err *Error) Error() string {
func (er *Error) Response(ctx context.Context, htmlEntities bool) ErrorResponse { func (er *Error) Response(ctx context.Context, htmlEntities bool) ErrorResponse {
var r ErrorResponse var r ErrorResponse
r.StatusCode = er.Status
if er.Message != "" { if er.Message != "" {
r.Error = er.Message r.Error = er.Message
} else { } else {

View File

@ -63,6 +63,12 @@ type ProjectCreateRequest struct {
Status *ProjectStatus `json:"status,omitempty" validate:"omitempty,oneof=active disabled" enums:"active,disabled" swaggertype:"string" example:"active"` Status *ProjectStatus `json:"status,omitempty" validate:"omitempty,oneof=active disabled" enums:"active,disabled" swaggertype:"string" example:"active"`
} }
// ProjectReadRequest defines the information needed to read a project.
type ProjectReadRequest struct {
ID string `json:"id" validate:"required,uuid" example:"985f1746-1d9f-459f-a2d9-fc53ece5ae86"`
IncludeArchived bool `json:"include-archived" example:"false"`
}
// ProjectUpdateRequest defines what information may be provided to modify an existing // ProjectUpdateRequest defines what information may be provided to modify an existing
// Project. All fields are optional so clients can send just the fields they want // Project. All fields are optional so clients can send just the fields they want
// changed. It uses pointer fields so we can differentiate between a field that // changed. It uses pointer fields so we can differentiate between a field that
@ -79,15 +85,20 @@ type ProjectArchiveRequest struct {
ID string `json:"id" validate:"required,uuid" example:"985f1746-1d9f-459f-a2d9-fc53ece5ae86"` ID string `json:"id" validate:"required,uuid" example:"985f1746-1d9f-459f-a2d9-fc53ece5ae86"`
} }
// ProjectDeleteRequest defines the information needed to delete a project.
type ProjectDeleteRequest struct {
ID string `json:"id" validate:"required,uuid" example:"985f1746-1d9f-459f-a2d9-fc53ece5ae86"`
}
// ProjectFindRequest defines the possible options to search for projects. By default // ProjectFindRequest defines the possible options to search for projects. By default
// archived project will be excluded from response. // archived project will be excluded from response.
type ProjectFindRequest struct { type ProjectFindRequest struct {
Where *string `json:"where" example:"name = ? and status = ?"` Where *string `json:"where" example:"name = ? and status = ?"`
Args []interface{} `json:"args" swaggertype:"array,string" example:"Moon Launch,active"` Args []interface{} `json:"args" swaggertype:"array,string" example:"Moon Launch,active"`
Order []string `json:"order" example:"created_at desc"` Order []string `json:"order" example:"created_at desc"`
Limit *uint `json:"limit" example:"10"` Limit *uint `json:"limit" example:"10"`
Offset *uint `json:"offset" example:"20"` Offset *uint `json:"offset" example:"20"`
IncludedArchived bool `json:"included-archived" example:"false"` IncludeArchived bool `json:"include-archived" example:"false"`
} }
// ProjectStatus represents the status of project. // ProjectStatus represents the status of project.

View File

@ -26,25 +26,6 @@ var (
ErrForbidden = errors.New("Attempted action is not allowed") ErrForbidden = errors.New("Attempted action is not allowed")
) )
// projectMapColumns is the list of columns needed for mapRowsToProject
var projectMapColumns = "id,account_id,name,status,created_at,updated_at,archived_at"
// mapRowsToProject takes the SQL rows and maps it to the Project struct
// with the columns defined by projectMapColumns
func mapRowsToProject(rows *sql.Rows) (*Project, error) {
var (
m Project
err error
)
err = rows.Scan(&m.ID, &m.AccountID, &m.Name, &m.Status, &m.CreatedAt, &m.UpdatedAt, &m.ArchivedAt)
if err != nil {
return nil, errors.WithStack(err)
}
return &m, nil
}
// CanReadProject determines if claims has the authority to access the specified project by id. // CanReadProject determines if claims has the authority to access the specified project by id.
func CanReadProject(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string) error { func CanReadProject(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string) error {
@ -106,7 +87,10 @@ func applyClaimsSelect(ctx context.Context, claims auth.Claims, query *sqlbuilde
return nil return nil
} }
// selectQuery constructs a base select query for Project // projectMapColumns is the list of columns needed for find.
var projectMapColumns = "id,account_id,name,status,created_at,updated_at,archived_at"
// selectQuery constructs a base select query for Project.
func selectQuery() *sqlbuilder.SelectBuilder { func selectQuery() *sqlbuilder.SelectBuilder {
query := sqlbuilder.NewSelectBuilder() query := sqlbuilder.NewSelectBuilder()
query.Select(projectMapColumns) query.Select(projectMapColumns)
@ -119,6 +103,7 @@ func selectQuery() *sqlbuilder.SelectBuilder {
// to the query. // to the query.
func findRequestQuery(req ProjectFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) { func findRequestQuery(req ProjectFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) {
query := selectQuery() query := selectQuery()
if req.Where != nil { if req.Where != nil {
query.Where(query.And(*req.Where)) query.Where(query.And(*req.Where))
} }
@ -141,13 +126,14 @@ func findRequestQuery(req ProjectFindRequest) (*sqlbuilder.SelectBuilder, []inte
// Find gets all the projects from the database based on the request params. // Find gets all the projects from the database based on the request params.
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectFindRequest) ([]*Project, error) { func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectFindRequest) ([]*Project, error) {
query, args := findRequestQuery(req) query, args := findRequestQuery(req)
return find(ctx, claims, dbConn, query, args, req.IncludedArchived) return find(ctx, claims, dbConn, query, args, req.IncludeArchived)
} }
// find internal method for getting all the projects from the database using a select query. // find internal method for getting all the projects from the database using a select query.
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*Project, error) { func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*Project, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Find") span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Find")
defer span.Finish() defer span.Finish()
query.Select(projectMapColumns) query.Select(projectMapColumns)
query.From(projectTableName) query.From(projectTableName)
if !includedArchived { if !includedArchived {
@ -174,32 +160,51 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu
// Iterate over each row. // Iterate over each row.
resp := []*Project{} resp := []*Project{}
for rows.Next() { for rows.Next() {
u, err := mapRowsToProject(rows) var (
m Project
err error
)
err = rows.Scan(&m.ID, &m.AccountID, &m.Name, &m.Status, &m.CreatedAt, &m.UpdatedAt, &m.ArchivedAt)
if err != nil { if err != nil {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
return nil, err return nil, err
} }
resp = append(resp, u) resp = append(resp, &m)
} }
return resp, nil return resp, nil
} }
// ReadByID gets the specified project by ID from the database.
func ReadByID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string) (*Project, error) {
return Read(ctx, claims, dbConn, ProjectReadRequest{
ID: id,
IncludeArchived: false,
})
}
// Read gets the specified project from the database. // Read gets the specified project from the database.
func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string, includedArchived bool) (*Project, error) { func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectReadRequest) (*Project, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Read") span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Read")
defer span.Finish() defer span.Finish()
// Filter base select query by id // Validate the request.
query := selectQuery() v := webcontext.Validator()
query.Where(query.Equal("id", id)) err := v.Struct(req)
if err != nil {
res, err := find(ctx, claims, dbConn, query, []interface{}{}, includedArchived)
if res == nil || len(res) == 0 {
err = errors.WithMessagef(ErrNotFound, "project %s not found", id)
return nil, err return nil, err
} else if err != nil { }
// Filter base select query by id
query := sqlbuilder.NewSelectBuilder()
query.Where(query.Equal("id", req.ID))
res, err := find(ctx, claims, dbConn, query, []interface{}{}, req.IncludeArchived)
if err != nil {
return nil, err
} else if res == nil || len(res) == 0 {
err = errors.WithMessagef(ErrNotFound, "project %s not found", req.ID)
return nil, err return nil, err
} }
@ -358,14 +363,6 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Projec
return nil return nil
} }
// Archive soft deleted the project by ID from the database.
func ArchiveById(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string, now time.Time) error {
req := ProjectArchiveRequest{
ID: id,
}
return Archive(ctx, claims, dbConn, req, now)
}
// Archive soft deleted the project from the database. // Archive soft deleted the project from the database.
func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectArchiveRequest, now time.Time) error { func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectArchiveRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Archive") span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Archive")
@ -416,17 +413,10 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Proje
} }
// Delete removes an project from the database. // Delete removes an project from the database.
func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string) error { func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectDeleteRequest) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Delete") span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Delete")
defer span.Finish() defer span.Finish()
// Defines the struct to apply validation
req := struct {
ID string `json:"id" validate:"required,uuid"`
}{
ID: id,
}
// Validate the request. // Validate the request.
v := webcontext.Validator() v := webcontext.Validator()
err := v.Struct(req) err := v.Struct(req)

View File

@ -561,5 +561,29 @@ func migrationList(db *sqlx.DB, log *log.Logger, isUnittest bool) []*sqlxmigrate
return nil return nil
}, },
}, },
// Create new table account_preferences.
{
ID: "20190801-01",
Migrate: func(tx *sql.Tx) error {
q := `CREATE TABLE IF NOT EXISTS account_preferences (
account_id char(36) NOT NULL REFERENCES accounts(id) ON DELETE NO ACTION,
name varchar(200) NOT NULL DEFAULT '',
value varchar(200) NOT NULL DEFAULT '',
created_at TIMESTAMP WITH TIME ZONE NOT NULL,
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
archived_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
CONSTRAINT account_preferences_pkey UNIQUE (account_id,name)
)`
if _, err := tx.Exec(q); err != nil {
return errors.WithMessagef(err, "Query failed %s", q)
}
return nil
},
Rollback: func(tx *sql.Tx) error {
return nil
},
},
} }
} }

View File

@ -1,148 +0,0 @@
package user
import (
"encoding/json"
"testing"
"time"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/tests"
"github.com/google/go-cmp/cmp"
"github.com/pborman/uuid"
"github.com/pkg/errors"
)
// TestAuthenticate validates the behavior around authenticating users.
func TestAuthenticate(t *testing.T) {
defer tests.Recover(t)
t.Log("Given the need to authenticate users")
{
t.Log("\tWhen handling a single User.")
{
ctx := tests.Context()
tknGen := &MockTokenGenerator{}
// Auth tokens are valid for an our and is verified against current time.
// Issue the token one hour ago.
now := time.Now().Add(time.Hour * -1)
// Try to authenticate an invalid user.
_, err := Authenticate(ctx, test.MasterDB, tknGen, "doesnotexist@gmail.com", "xy7", time.Hour, now)
if errors.Cause(err) != ErrAuthenticationFailure {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", ErrAuthenticationFailure)
t.Fatalf("\t%s\tAuthenticate non existant user failed.", tests.Failed)
}
t.Logf("\t%s\tAuthenticate non existant user ok.", tests.Success)
// Create a new user for testing.
initPass := uuid.NewRandom().String()
user, err := Create(ctx, auth.Claims{}, test.MasterDB, UserCreateRequest{
FirstName: "Lee",
LastName: "Brown",
Email: uuid.NewRandom().String() + "@geeksinthewoods.com",
Password: initPass,
PasswordConfirm: initPass,
}, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate user failed.", tests.Failed)
}
t.Logf("\t%s\tCreate user ok.", tests.Success)
// Create a new random account.
account1Id := uuid.NewRandom().String()
err = mockAccount(account1Id, user.CreatedAt)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate account failed.", tests.Failed)
}
// Associate new account with user user. This defined role should be the claims.
account1Role := auth.RoleAdmin
err = mockUserAccount(user.ID, account1Id, user.CreatedAt, account1Role)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
}
// Create a second new random account. Need to ensure
account2Id := uuid.NewRandom().String()
err = mockAccount(account2Id, user.CreatedAt)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate account failed.", tests.Failed)
}
// Associate second new account with user user. Need to ensure that now
// is always greater than the first user_account entry created so it will
// be returned consistently back in the same order, last.
account2Role := auth.RoleUser
err = mockUserAccount(user.ID, account2Id, user.CreatedAt.Add(time.Second), account2Role)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
}
// Add 30 minutes to now to simulate time passing.
now = now.Add(time.Minute * 30)
// Try to authenticate valid user with invalid password.
_, err = Authenticate(ctx, test.MasterDB, tknGen, user.Email, "xy7", time.Hour, now)
if errors.Cause(err) != ErrAuthenticationFailure {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", ErrAuthenticationFailure)
t.Fatalf("\t%s\tAuthenticate user w/invalid password failed.", tests.Failed)
}
t.Logf("\t%s\tAuthenticate user w/invalid password ok.", tests.Success)
// Verify that the user can be authenticated with the created user.
tkn1, err := Authenticate(ctx, test.MasterDB, tknGen, user.Email, initPass, time.Hour, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tAuthenticate user failed.", tests.Failed)
}
t.Logf("\t%s\tAuthenticate user ok.", tests.Success)
// Ensure the token string was correctly generated.
claims1, err := tknGen.ParseClaims(tkn1.AccessToken)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
}
// Hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229
resClaims, _ := json.Marshal(claims1)
expectClaims, _ := json.Marshal(tkn1.claims)
if diff := cmp.Diff(string(resClaims), string(expectClaims)); diff != "" {
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
}
t.Logf("\t%s\tAuthenticate parse claims from token ok.", tests.Success)
// Try switching to a second account using the first set of claims.
tkn2, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims1, account2Id, time.Hour, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tSwitchAccount user failed.", tests.Failed)
}
t.Logf("\t%s\tSwitchAccount user ok.", tests.Success)
// Ensure the token string was correctly generated.
claims2, err := tknGen.ParseClaims(tkn2.AccessToken)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
}
// Hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229
resClaims, _ = json.Marshal(claims2)
expectClaims, _ = json.Marshal(tkn2.claims)
if diff := cmp.Diff(string(resClaims), string(expectClaims)); diff != "" {
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
}
t.Logf("\t%s\tSwitchAccount parse claims from token ok.", tests.Success)
}
}
}

View File

@ -3,7 +3,7 @@ package user
import ( import (
"context" "context"
"database/sql" "database/sql"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "encoding/json"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web"
"time" "time"
@ -27,14 +27,15 @@ type User struct {
// UserResponse represents someone with access to our system that is returned for display. // UserResponse represents someone with access to our system that is returned for display.
type UserResponse struct { type UserResponse struct {
ID string `json:"id" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"` ID string `json:"id" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"`
FirstName string `json:"first_name" example:"Gabi"` Name string `json:"name" example:"Gabi"`
LastName string `json:"last_name" example:"May"` FirstName string `json:"first_name" example:"Gabi"`
Email string `json:"email" example:"gabi@geeksinthewoods.com"` LastName string `json:"last_name" example:"May"`
Timezone string `json:"timezone" example:"America/Anchorage"` Email string `json:"email" example:"gabi@geeksinthewoods.com"`
CreatedAt web.TimeResponse `json:"created_at"` // CreatedAt contains multiple format options for display. Timezone string `json:"timezone" example:"America/Anchorage"`
UpdatedAt web.TimeResponse `json:"updated_at"` // UpdatedAt contains multiple format options for display. CreatedAt web.TimeResponse `json:"created_at"` // CreatedAt contains multiple format options for display.
ArchivedAt *web.TimeResponse `json:"archived_at,omitempty"` // ArchivedAt contains multiple format options for display. UpdatedAt web.TimeResponse `json:"updated_at"` // UpdatedAt contains multiple format options for display.
ArchivedAt *web.TimeResponse `json:"archived_at,omitempty"` // ArchivedAt contains multiple format options for display.
Gravatar web.GravatarResponse `json:"gravatar"` Gravatar web.GravatarResponse `json:"gravatar"`
} }
@ -47,13 +48,14 @@ func (m *User) Response(ctx context.Context) *UserResponse {
r := &UserResponse{ r := &UserResponse{
ID: m.ID, ID: m.ID,
Name: m.FirstName + " " + m.LastName,
FirstName: m.FirstName, FirstName: m.FirstName,
LastName: m.LastName, LastName: m.LastName,
Email: m.Email, Email: m.Email,
Timezone: m.Timezone, Timezone: m.Timezone,
CreatedAt: web.NewTimeResponse(ctx, m.CreatedAt), CreatedAt: web.NewTimeResponse(ctx, m.CreatedAt),
UpdatedAt: web.NewTimeResponse(ctx, m.UpdatedAt), UpdatedAt: web.NewTimeResponse(ctx, m.UpdatedAt),
Gravatar: web.NewGravatarResponse(ctx, m.Email), Gravatar: web.NewGravatarResponse(ctx, m.Email),
} }
if m.ArchivedAt != nil && !m.ArchivedAt.Time.IsZero() { if m.ArchivedAt != nil && !m.ArchivedAt.Time.IsZero() {
@ -64,6 +66,18 @@ func (m *User) Response(ctx context.Context) *UserResponse {
return r return r
} }
func (m *UserResponse) UnmarshalBinary(data []byte) error {
if data == nil || len(data) == 0 {
return nil
}
// convert data to yours, let's assume its json data
return json.Unmarshal(data, m)
}
func (m *UserResponse) MarshalBinary() ([]byte, error) {
return json.Marshal(m)
}
// UserCreateRequest contains information needed to create a new User. // UserCreateRequest contains information needed to create a new User.
type UserCreateRequest struct { type UserCreateRequest struct {
FirstName string `json:"first_name" validate:"required" example:"Gabi"` FirstName string `json:"first_name" validate:"required" example:"Gabi"`
@ -79,6 +93,12 @@ type UserCreateInviteRequest struct {
Email string `json:"email" validate:"required,email,unique" example:"gabi@geeksinthewoods.com"` Email string `json:"email" validate:"required,email,unique" example:"gabi@geeksinthewoods.com"`
} }
// UserReadRequest defines the information needed to read an user.
type UserReadRequest struct {
ID string `json:"id" validate:"required,uuid" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"`
IncludeArchived bool `json:"include-archived" example:"false"`
}
// UserUpdateRequest defines what information may be provided to modify an existing // UserUpdateRequest defines what information may be provided to modify an existing
// User. All fields are optional so clients can send just the fields they want // User. All fields are optional so clients can send just the fields they want
// changed. It uses pointer fields so we can differentiate between a field that // changed. It uses pointer fields so we can differentiate between a field that
@ -106,20 +126,25 @@ type UserArchiveRequest struct {
ID string `json:"id" validate:"required,uuid" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"` ID string `json:"id" validate:"required,uuid" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"`
} }
// UserUnarchiveRequest defines the information needed to unarchive an user. // UserRestoreRequest defines the information needed to restore an user.
type UserUnarchiveRequest struct { type UserRestoreRequest struct {
ID string `json:"id" validate:"required,uuid" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"`
}
// UserDeleteRequest defines the information needed to delete a user.
type UserDeleteRequest struct {
ID string `json:"id" validate:"required,uuid" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"` ID string `json:"id" validate:"required,uuid" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"`
} }
// UserFindRequest defines the possible options to search for users. By default // UserFindRequest defines the possible options to search for users. By default
// archived users will be excluded from response. // archived users will be excluded from response.
type UserFindRequest struct { type UserFindRequest struct {
Where *string `json:"where" example:"name = ? and email = ?"` Where *string `json:"where" example:"name = ? and email = ?"`
Args []interface{} `json:"args" swaggertype:"array,string" example:"Company Name,gabi.may@geeksinthewoods.com"` Args []interface{} `json:"args" swaggertype:"array,string" example:"Company Name,gabi.may@geeksinthewoods.com"`
Order []string `json:"order" example:"created_at desc"` Order []string `json:"order" example:"created_at desc"`
Limit *uint `json:"limit" example:"10"` Limit *uint `json:"limit" example:"10"`
Offset *uint `json:"offset" example:"20"` Offset *uint `json:"offset" example:"20"`
IncludedArchived bool `json:"included-archived" example:"false"` IncludeArchived bool `json:"include-archived" example:"false"`
} }
// UserResetPasswordRequest defines the fields need to reset a user password. // UserResetPasswordRequest defines the fields need to reset a user password.
@ -142,32 +167,3 @@ type UserResetConfirmRequest struct {
Password string `json:"password" validate:"required" example:"SecretString"` Password string `json:"password" validate:"required" example:"SecretString"`
PasswordConfirm string `json:"password_confirm" validate:"required,eqfield=Password" example:"SecretString"` PasswordConfirm string `json:"password_confirm" validate:"required,eqfield=Password" example:"SecretString"`
} }
// AuthenticateRequest defines what information is required to authenticate a user.
type AuthenticateRequest struct {
Email string `json:"email" validate:"required,email" example:"gabi.may@geeksinthewoods.com"`
Password string `json:"password" validate:"required" example:"NeverTellSecret"`
}
// Token is the payload we deliver to users when they authenticate.
type Token struct {
// AccessToken is the token that authorizes and authenticates
// the requests.
AccessToken string `json:"access_token"`
// TokenType is the type of token.
// The Type method returns either this or "Bearer", the default.
TokenType string `json:"token_type,omitempty"`
// Expiry is the optional expiration time of the access token.
//
// If zero, TokenSource implementations will reuse the same
// token forever and RefreshToken or equivalent
// mechanisms for that TokenSource will not be used.
Expiry time.Time `json:"expiry,omitempty"`
TTL time.Duration `json:"ttl,omitempty"`
// contains filtered or unexported fields
claims auth.Claims `json:"-"`
// UserId is the ID of the user authenticated.
UserID string `json:"user_id" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"`
// AccountID is the ID of the account for the user authenticated.
AccountID string `json:"account_id"example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
}

View File

@ -35,10 +35,6 @@ var (
// ErrForbidden occurs when a user tries to do something that is forbidden to them according to our access control policies. // ErrForbidden occurs when a user tries to do something that is forbidden to them according to our access control policies.
ErrForbidden = errors.New("Attempted action is not allowed") ErrForbidden = errors.New("Attempted action is not allowed")
// ErrAuthenticationFailure occurs when a user attempts to authenticate but
// anything goes wrong.
ErrAuthenticationFailure = errors.New("Authentication failed")
// ErrResetExpired occurs when the the reset hash exceeds the expiration. // ErrResetExpired occurs when the the reset hash exceeds the expiration.
ErrResetExpired = errors.New("Reset expired") ErrResetExpired = errors.New("Reset expired")
) )
@ -208,7 +204,7 @@ func findRequestQuery(req UserFindRequest) (*sqlbuilder.SelectBuilder, []interfa
// Find gets all the users from the database based on the request params. // Find gets all the users from the database based on the request params.
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindRequest) ([]*User, error) { func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindRequest) ([]*User, error) {
query, args := findRequestQuery(req) query, args := findRequestQuery(req)
return find(ctx, claims, dbConn, query, args, req.IncludedArchived) return find(ctx, claims, dbConn, query, args, req.IncludeArchived)
} }
// find internal method for getting all the users from the database using a select query. // find internal method for getting all the users from the database using a select query.
@ -432,20 +428,56 @@ func CreateInvite(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req
return &u, nil return &u, nil
} }
// ReadByID gets the specified user by ID from the database.
func ReadByID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string) (*User, error) {
return Read(ctx, claims, dbConn, UserReadRequest{
ID: id,
IncludeArchived: false,
})
}
// Read gets the specified user from the database. // Read gets the specified user from the database.
func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string, includedArchived bool) (*User, error) { func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserReadRequest) (*User, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Read") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Read")
defer span.Finish() defer span.Finish()
// Validate the request.
v := webcontext.Validator()
err := v.Struct(req)
if err != nil {
return nil, err
}
// Filter base select query by ID
query := selectQuery()
query.Where(query.Equal("id", req.ID))
res, err := find(ctx, claims, dbConn, query, []interface{}{}, req.IncludeArchived)
if err != nil {
return nil, err
} else if res == nil || len(res) == 0 {
err = errors.WithMessagef(ErrNotFound, "user %s not found", req.ID)
return nil, err
}
u := res[0]
return u, nil
}
// ReadByEmail gets the specified user from the database.
func ReadByEmail(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, email string, includedArchived bool) (*User, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.ReadByEmail")
defer span.Finish()
// Filter base select query by ID // Filter base select query by ID
query := selectQuery() query := selectQuery()
query.Where(query.Equal("id", id)) query.Where(query.Equal("email", email))
res, err := find(ctx, claims, dbConn, query, []interface{}{}, includedArchived) res, err := find(ctx, claims, dbConn, query, []interface{}{}, includedArchived)
if err != nil { if err != nil {
return nil, err return nil, err
} else if res == nil || len(res) == 0 { } else if res == nil || len(res) == 0 {
err = errors.WithMessagef(ErrNotFound, "user %s not found", id) err = errors.WithMessagef(ErrNotFound, "user %s not found", email)
return nil, err return nil, err
} }
u := res[0] u := res[0]
@ -599,14 +631,6 @@ func UpdatePassword(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, re
return nil return nil
} }
// Archive soft deleted the user by ID from the database.
func ArchiveById(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string, now time.Time) error {
req := UserArchiveRequest{
ID: id,
}
return Archive(ctx, claims, dbConn, req, now)
}
// Archive soft deleted the user from the database. // Archive soft deleted the user from the database.
func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserArchiveRequest, now time.Time) error { func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserArchiveRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Archive") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Archive")
@ -679,9 +703,9 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserA
return nil return nil
} }
// Unarchive undeletes the user from the database. // Restore undeletes the user from the database.
func Unarchive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserUnarchiveRequest, now time.Time) error { func Restore(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserRestoreRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Unarchive") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Restore")
defer span.Finish() defer span.Finish()
// Validate the request. // Validate the request.
@ -731,17 +755,10 @@ func Unarchive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Use
} }
// Delete removes a user from the database. // Delete removes a user from the database.
func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string) error { func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserDeleteRequest) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Delete") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Delete")
defer span.Finish() defer span.Finish()
// Defines the struct to apply validation
req := struct {
ID string `json:"id" validate:"required,uuid"`
}{
ID: userID,
}
// Validate the request. // Validate the request.
v := webcontext.Validator() v := webcontext.Validator()
err := v.Struct(req) err := v.Struct(req)
@ -1011,3 +1028,30 @@ func ResetConfirm(ctx context.Context, dbConn *sqlx.DB, req UserResetConfirmRequ
return u, nil return u, nil
} }
type MockUserResponse struct {
*User
Password string
}
// MockUser returns a fake User for testing.
func MockUser(ctx context.Context, dbConn *sqlx.DB, now time.Time) (*MockUserResponse, error) {
pass := uuid.NewRandom().String()
req := UserCreateRequest{
FirstName: "Lee",
LastName: "Brown",
Email: uuid.NewRandom().String() + "@geeksinthewoods.com",
Password: pass,
PasswordConfirm: pass,
}
u, err := Create(ctx, auth.Claims{}, dbConn, req, now)
if err != nil {
return nil, err
}
return &MockUserResponse{
User: u,
Password: pass,
}, nil
}

View File

@ -517,8 +517,6 @@ func TestUpdatePassword(t *testing.T) {
now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC) now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC)
tknGen := &MockTokenGenerator{}
// Create a new user for testing. // Create a new user for testing.
initPass := uuid.NewRandom().String() initPass := uuid.NewRandom().String()
user, err := Create(ctx, auth.Claims{}, test.MasterDB, UserCreateRequest{ user, err := Create(ctx, auth.Claims{}, test.MasterDB, UserCreateRequest{
@ -548,13 +546,6 @@ func TestUpdatePassword(t *testing.T) {
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed) t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
} }
// Verify that the user can be authenticated with the created user.
_, err = Authenticate(ctx, test.MasterDB, tknGen, user.Email, initPass, time.Hour, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tAuthenticate failed.", tests.Failed)
}
// Ensure validation is working by trying UpdatePassword with an empty request. // Ensure validation is working by trying UpdatePassword with an empty request.
expectedErr := errors.New("Key: 'UserUpdatePasswordRequest.id' Error:Field validation for 'id' failed on the 'required' tag\n" + expectedErr := errors.New("Key: 'UserUpdatePasswordRequest.id' Error:Field validation for 'id' failed on the 'required' tag\n" +
"Key: 'UserUpdatePasswordRequest.password' Error:Field validation for 'password' failed on the 'required' tag\n" + "Key: 'UserUpdatePasswordRequest.password' Error:Field validation for 'password' failed on the 'required' tag\n" +
@ -587,14 +578,6 @@ func TestUpdatePassword(t *testing.T) {
t.Fatalf("\t%s\tUpdate password failed.", tests.Failed) t.Fatalf("\t%s\tUpdate password failed.", tests.Failed)
} }
t.Logf("\t%s\tUpdatePassword ok.", tests.Success) t.Logf("\t%s\tUpdatePassword ok.", tests.Success)
// Verify that the user can be authenticated with the updated password.
_, err = Authenticate(ctx, test.MasterDB, tknGen, user.Email, newPass, time.Hour, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tAuthenticate failed.", tests.Failed)
}
t.Logf("\t%s\tAuthenticate ok.", tests.Success)
} }
} }
@ -850,7 +833,7 @@ func TestCrud(t *testing.T) {
t.Logf("\t%s\tUpdate ok.", tests.Success) t.Logf("\t%s\tUpdate ok.", tests.Success)
// Find the user and make sure the updates where made. // Find the user and make sure the updates where made.
findRes, err := Read(ctx, tt.claims(user, accountId), test.MasterDB, user.ID, false) findRes, err := ReadByID(ctx, tt.claims(user, accountId), test.MasterDB, user.ID)
if err != nil && errors.Cause(err) != tt.findErr { if err != nil && errors.Cause(err) != tt.findErr {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.findErr) t.Logf("\t\tWant: %+v", tt.findErr)
@ -864,14 +847,14 @@ func TestCrud(t *testing.T) {
} }
// Archive (soft-delete) the user. // Archive (soft-delete) the user.
err = ArchiveById(ctx, tt.claims(user, accountId), test.MasterDB, user.ID, now) err = Archive(ctx, tt.claims(user, accountId), test.MasterDB, UserArchiveRequest{ID: user.ID}, now)
if err != nil && errors.Cause(err) != tt.updateErr { if err != nil && errors.Cause(err) != tt.updateErr {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.updateErr) t.Logf("\t\tWant: %+v", tt.updateErr)
t.Fatalf("\t%s\tArchive failed.", tests.Failed) t.Fatalf("\t%s\tArchive failed.", tests.Failed)
} else if tt.updateErr == nil { } else if tt.updateErr == nil {
// Trying to find the archived user with the includeArchived false should result in not found. // Trying to find the archived user with the includeArchived false should result in not found.
_, err = Read(ctx, tt.claims(user, accountId), test.MasterDB, user.ID, false) _, err = ReadByID(ctx, tt.claims(user, accountId), test.MasterDB, user.ID)
if err != nil && errors.Cause(err) != ErrNotFound { if err != nil && errors.Cause(err) != ErrNotFound {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", ErrNotFound) t.Logf("\t\tWant: %+v", ErrNotFound)
@ -879,7 +862,8 @@ func TestCrud(t *testing.T) {
} }
// Trying to find the archived user with the includeArchived true should result no error. // Trying to find the archived user with the includeArchived true should result no error.
_, err = Read(ctx, tt.claims(user, accountId), test.MasterDB, user.ID, true) _, err = Read(ctx, tt.claims(user, accountId), test.MasterDB,
UserReadRequest{ID: user.ID, IncludeArchived: true})
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tArchive Read failed.", tests.Failed) t.Fatalf("\t%s\tArchive Read failed.", tests.Failed)
@ -887,15 +871,15 @@ func TestCrud(t *testing.T) {
} }
t.Logf("\t%s\tArchive ok.", tests.Success) t.Logf("\t%s\tArchive ok.", tests.Success)
// Unarchive (un-delete) the user. // Restore (un-delete) the user.
err = Unarchive(ctx, tt.claims(user, accountId), test.MasterDB, UserUnarchiveRequest{ID: user.ID}, now) err = Restore(ctx, tt.claims(user, accountId), test.MasterDB, UserRestoreRequest{ID: user.ID}, now)
if err != nil && errors.Cause(err) != tt.updateErr { if err != nil && errors.Cause(err) != tt.updateErr {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.updateErr) t.Logf("\t\tWant: %+v", tt.updateErr)
t.Fatalf("\t%s\tUnarchive failed.", tests.Failed) t.Fatalf("\t%s\tUnarchive failed.", tests.Failed)
} else if tt.updateErr == nil { } else if tt.updateErr == nil {
// Trying to find the archived user with the includeArchived false should result no error. // Trying to find the archived user with the includeArchived false should result no error.
_, err = Read(ctx, tt.claims(user, accountId), test.MasterDB, user.ID, false) _, err = ReadByID(ctx, tt.claims(user, accountId), test.MasterDB, user.ID)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tUnarchive Read failed.", tests.Failed) t.Fatalf("\t%s\tUnarchive Read failed.", tests.Failed)
@ -904,14 +888,14 @@ func TestCrud(t *testing.T) {
t.Logf("\t%s\tUnarchive ok.", tests.Success) t.Logf("\t%s\tUnarchive ok.", tests.Success)
// Delete (hard-delete) the user. // Delete (hard-delete) the user.
err = Delete(ctx, tt.claims(user, accountId), test.MasterDB, user.ID) err = Delete(ctx, tt.claims(user, accountId), test.MasterDB, UserDeleteRequest{ID: user.ID})
if err != nil && errors.Cause(err) != tt.updateErr { if err != nil && errors.Cause(err) != tt.updateErr {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.updateErr) t.Logf("\t\tWant: %+v", tt.updateErr)
t.Fatalf("\t%s\tUpdate failed.", tests.Failed) t.Fatalf("\t%s\tUpdate failed.", tests.Failed)
} else if tt.updateErr == nil { } else if tt.updateErr == nil {
// Trying to find the deleted user with the includeArchived true should result in not found. // Trying to find the deleted user with the includeArchived true should result in not found.
_, err = Read(ctx, tt.claims(user, accountId), test.MasterDB, user.ID, true) _, err = ReadByID(ctx, tt.claims(user, accountId), test.MasterDB, user.ID)
if errors.Cause(err) != ErrNotFound { if errors.Cause(err) != ErrNotFound {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", ErrNotFound) t.Logf("\t\tWant: %+v", ErrNotFound)
@ -1079,8 +1063,6 @@ func TestResetPassword(t *testing.T) {
now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC) now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC)
tknGen := &MockTokenGenerator{}
// Create a new user for testing. // Create a new user for testing.
initPass := uuid.NewRandom().String() initPass := uuid.NewRandom().String()
user, err := Create(ctx, auth.Claims{}, test.MasterDB, UserCreateRequest{ user, err := Create(ctx, auth.Claims{}, test.MasterDB, UserCreateRequest{
@ -1152,7 +1134,7 @@ func TestResetPassword(t *testing.T) {
t.Logf("\t%s\tResetPassword ok.", tests.Success) t.Logf("\t%s\tResetPassword ok.", tests.Success)
// Read the user to ensure the password_reset field was set. // Read the user to ensure the password_reset field was set.
user, err = Read(ctx, auth.Claims{}, test.MasterDB, user.ID, false) user, err = ReadByID(ctx, auth.Claims{}, test.MasterDB, user.ID)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tRead failed.", tests.Failed) t.Fatalf("\t%s\tRead failed.", tests.Failed)
@ -1215,14 +1197,6 @@ func TestResetPassword(t *testing.T) {
} }
t.Logf("\t%s\tResetConfirm ok.", tests.Success) t.Logf("\t%s\tResetConfirm ok.", tests.Success)
// Verify that the user can be authenticated with the updated password.
_, err = Authenticate(ctx, test.MasterDB, tknGen, user.Email, newPass, time.Hour, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tAuthenticate failed.", tests.Failed)
}
t.Logf("\t%s\tAuthenticate ok.", tests.Success)
// Ensure the reset hash does not work after its used. // Ensure the reset hash does not work after its used.
{ {
newPass := uuid.NewRandom().String() newPass := uuid.NewRandom().String()

View File

@ -27,9 +27,9 @@ var (
ErrInviteUserPasswordSet = errors.New("User password set") ErrInviteUserPasswordSet = errors.New("User password set")
) )
// InviteUsers sends emails to the users inviting them to join an account. // SendUserInvites sends emails to the users inviting them to join an account.
func InviteUsers(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, resetUrl func(string) string, notify notify.Email, req InviteUsersRequest, secretKey string, now time.Time) ([]string, error) { func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, resetUrl func(string) string, notify notify.Email, req SendUserInvitesRequest, secretKey string, now time.Time) ([]string, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.invite.InviteUsers") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.invite.SendUserInvites")
defer span.Finish() defer span.Finish()
v := webcontext.Validator() v := webcontext.Validator()
@ -131,12 +131,12 @@ func InviteUsers(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, reset
req.TTL = time.Minute * 90 req.TTL = time.Minute * 90
} }
fromUser, err := user.Read(ctx, claims, dbConn, req.UserID, false) fromUser, err := user.ReadByID(ctx, claims, dbConn, req.UserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
account, err := account.Read(ctx, claims, dbConn, req.AccountID, false) account, err := account.ReadByID(ctx, claims, dbConn, req.AccountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -190,9 +190,9 @@ func InviteUsers(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, reset
return inviteHashes, nil return inviteHashes, nil
} }
// InviteAccept updates the password for a user using the provided reset password ID. // AcceptInvite updates the user using the provided invite hash.
func InviteAccept(ctx context.Context, dbConn *sqlx.DB, req InviteAcceptRequest, secretKey string, now time.Time) error { func AcceptInvite(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteRequest, secretKey string, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.invite.InviteAccept") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.invite.AcceptInvite")
defer span.Finish() defer span.Finish()
v := webcontext.Validator() v := webcontext.Validator()
@ -232,13 +232,14 @@ func InviteAccept(ctx context.Context, dbConn *sqlx.DB, req InviteAcceptRequest,
return err return err
} }
u, err := user.Read(ctx, auth.Claims{}, dbConn, hash.UserID, true) u, err := user.Read(ctx, auth.Claims{}, dbConn,
user.UserReadRequest{ID: hash.UserID, IncludeArchived: true})
if err != nil { if err != nil {
return err return err
} }
if u.ArchivedAt != nil && !u.ArchivedAt.Time.IsZero() { if u.ArchivedAt != nil && !u.ArchivedAt.Time.IsZero() {
err = user.Unarchive(ctx, auth.Claims{}, dbConn, user.UserUnarchiveRequest{ID: hash.UserID}, now) err = user.Restore(ctx, auth.Claims{}, dbConn, user.UserRestoreRequest{ID: hash.UserID}, now)
if err != nil { if err != nil {
return err return err
} }

View File

@ -13,7 +13,6 @@ import (
"geeks-accelerator/oss/saas-starter-kit/internal/user" "geeks-accelerator/oss/saas-starter-kit/internal/user"
"geeks-accelerator/oss/saas-starter-kit/internal/user_account" "geeks-accelerator/oss/saas-starter-kit/internal/user_account"
"github.com/dgrijalva/jwt-go" "github.com/dgrijalva/jwt-go"
"github.com/huandu/go-sqlbuilder"
"github.com/pborman/uuid" "github.com/pborman/uuid"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@ -31,8 +30,8 @@ func testMain(m *testing.M) int {
return m.Run() return m.Run()
} }
// TestInviteUsers validates that invite users works. // TestSendUserInvites validates that invite users works.
func TestInviteUsers(t *testing.T) { func TestSendUserInvites(t *testing.T) {
t.Log("Given the need ensure a user an invite users to their account.") t.Log("Given the need ensure a user an invite users to their account.")
{ {
@ -101,11 +100,11 @@ func TestInviteUsers(t *testing.T) {
// Ensure validation is working by trying ResetPassword with an empty request. // Ensure validation is working by trying ResetPassword with an empty request.
{ {
expectedErr := errors.New("Key: 'InviteUsersRequest.account_id' Error:Field validation for 'account_id' failed on the 'required' tag\n" + expectedErr := errors.New("Key: 'SendUserInvitesRequest.account_id' Error:Field validation for 'account_id' failed on the 'required' tag\n" +
"Key: 'InviteUsersRequest.user_id' Error:Field validation for 'user_id' failed on the 'required' tag\n" + "Key: 'SendUserInvitesRequest.user_id' Error:Field validation for 'user_id' failed on the 'required' tag\n" +
"Key: 'InviteUsersRequest.emails' Error:Field validation for 'emails' failed on the 'required' tag\n" + "Key: 'SendUserInvitesRequest.emails' Error:Field validation for 'emails' failed on the 'required' tag\n" +
"Key: 'InviteUsersRequest.roles' Error:Field validation for 'roles' failed on the 'required' tag") "Key: 'SendUserInvitesRequest.roles' Error:Field validation for 'roles' failed on the 'required' tag")
_, err = InviteUsers(ctx, claims, test.MasterDB, resetUrl, notify, InviteUsersRequest{}, secretKey, now) _, err = SendUserInvites(ctx, claims, test.MasterDB, resetUrl, notify, SendUserInvitesRequest{}, secretKey, now)
if err == nil { if err == nil {
t.Logf("\t\tWant: %+v", expectedErr) t.Logf("\t\tWant: %+v", expectedErr)
t.Fatalf("\t%s\tInviteUsers failed.", tests.Failed) t.Fatalf("\t%s\tInviteUsers failed.", tests.Failed)
@ -129,7 +128,7 @@ func TestInviteUsers(t *testing.T) {
} }
// Make the reset password request. // Make the reset password request.
inviteHashes, err := InviteUsers(ctx, claims, test.MasterDB, resetUrl, notify, InviteUsersRequest{ inviteHashes, err := SendUserInvites(ctx, claims, test.MasterDB, resetUrl, notify, SendUserInvitesRequest{
UserID: u.ID, UserID: u.ID,
AccountID: a.ID, AccountID: a.ID,
Emails: inviteEmails, Emails: inviteEmails,
@ -148,12 +147,12 @@ func TestInviteUsers(t *testing.T) {
// Ensure validation is working by trying ResetConfirm with an empty request. // Ensure validation is working by trying ResetConfirm with an empty request.
{ {
expectedErr := errors.New("Key: 'InviteAcceptRequest.invite_hash' Error:Field validation for 'invite_hash' failed on the 'required' tag\n" + expectedErr := errors.New("Key: 'AcceptInviteRequest.invite_hash' Error:Field validation for 'invite_hash' failed on the 'required' tag\n" +
"Key: 'InviteAcceptRequest.first_name' Error:Field validation for 'first_name' failed on the 'required' tag\n" + "Key: 'AcceptInviteRequest.first_name' Error:Field validation for 'first_name' failed on the 'required' tag\n" +
"Key: 'InviteAcceptRequest.last_name' Error:Field validation for 'last_name' failed on the 'required' tag\n" + "Key: 'AcceptInviteRequest.last_name' Error:Field validation for 'last_name' failed on the 'required' tag\n" +
"Key: 'InviteAcceptRequest.password' Error:Field validation for 'password' failed on the 'required' tag\n" + "Key: 'AcceptInviteRequest.password' Error:Field validation for 'password' failed on the 'required' tag\n" +
"Key: 'InviteAcceptRequest.password_confirm' Error:Field validation for 'password_confirm' failed on the 'required' tag") "Key: 'AcceptInviteRequest.password_confirm' Error:Field validation for 'password_confirm' failed on the 'required' tag")
err = InviteAccept(ctx, test.MasterDB, InviteAcceptRequest{}, secretKey, now) err = AcceptInvite(ctx, test.MasterDB, AcceptInviteRequest{}, secretKey, now)
if err == nil { if err == nil {
t.Logf("\t\tWant: %+v", expectedErr) t.Logf("\t\tWant: %+v", expectedErr)
t.Fatalf("\t%s\tResetConfirm failed.", tests.Failed) t.Fatalf("\t%s\tResetConfirm failed.", tests.Failed)
@ -173,7 +172,7 @@ func TestInviteUsers(t *testing.T) {
// Ensure the TTL is enforced. // Ensure the TTL is enforced.
{ {
newPass := uuid.NewRandom().String() newPass := uuid.NewRandom().String()
err = InviteAccept(ctx, test.MasterDB, InviteAcceptRequest{ err = AcceptInvite(ctx, test.MasterDB, AcceptInviteRequest{
InviteHash: inviteHashes[0], InviteHash: inviteHashes[0],
FirstName: "Foo", FirstName: "Foo",
LastName: "Bar", LastName: "Bar",
@ -191,7 +190,7 @@ func TestInviteUsers(t *testing.T) {
// Assuming we have received the email and clicked the link, we now can ensure accept works. // Assuming we have received the email and clicked the link, we now can ensure accept works.
for _, inviteHash := range inviteHashes { for _, inviteHash := range inviteHashes {
newPass := uuid.NewRandom().String() newPass := uuid.NewRandom().String()
err = InviteAccept(ctx, test.MasterDB, InviteAcceptRequest{ err = AcceptInvite(ctx, test.MasterDB, AcceptInviteRequest{
InviteHash: inviteHash, InviteHash: inviteHash,
FirstName: "Foo", FirstName: "Foo",
LastName: "Bar", LastName: "Bar",
@ -208,7 +207,7 @@ func TestInviteUsers(t *testing.T) {
// Ensure the reset hash does not work after its used. // Ensure the reset hash does not work after its used.
{ {
newPass := uuid.NewRandom().String() newPass := uuid.NewRandom().String()
err = InviteAccept(ctx, test.MasterDB, InviteAcceptRequest{ err = AcceptInvite(ctx, test.MasterDB, AcceptInviteRequest{
InviteHash: inviteHashes[0], InviteHash: inviteHashes[0],
FirstName: "Foo", FirstName: "Foo",
LastName: "Bar", LastName: "Bar",
@ -224,43 +223,3 @@ func TestInviteUsers(t *testing.T) {
} }
} }
} }
func mockAccount(accountId string, now time.Time) error {
// Build the insert SQL statement.
query := sqlbuilder.NewInsertBuilder()
query.InsertInto("accounts")
query.Cols("id", "name", "created_at", "updated_at")
query.Values(accountId, uuid.NewRandom().String(), now, now)
// Execute the query with the provided context.
sql, args := query.Build()
sql = test.MasterDB.Rebind(sql)
_, err := test.MasterDB.ExecContext(tests.Context(), sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
return err
}
return nil
}
func mockUser(userId string, now time.Time) error {
// Build the insert SQL statement.
query := sqlbuilder.NewInsertBuilder()
query.InsertInto("users")
query.Cols("id", "email", "password_hash", "password_salt", "created_at", "updated_at")
query.Values(userId, uuid.NewRandom().String(), "-", "-", now, now)
// Execute the query with the provided context.
sql, args := query.Build()
sql = test.MasterDB.Rebind(sql)
_, err := test.MasterDB.ExecContext(tests.Context(), sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
return err
}
return nil
}

View File

@ -6,8 +6,8 @@ import (
"geeks-accelerator/oss/saas-starter-kit/internal/user_account" "geeks-accelerator/oss/saas-starter-kit/internal/user_account"
) )
// InviteUsersRequest defines the data needed to make an invite request. // SendUserInvitesRequest defines the data needed to make an invite request.
type InviteUsersRequest struct { type SendUserInvitesRequest struct {
AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"` AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
UserID string `json:"user_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"` UserID string `json:"user_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
Emails []string `json:"emails" validate:"required,dive,email"` Emails []string `json:"emails" validate:"required,dive,email"`
@ -23,8 +23,8 @@ type InviteHash struct {
RequestIP string `json:"request_ip" validate:"required,ip" example:"69.56.104.36"` RequestIP string `json:"request_ip" validate:"required,ip" example:"69.56.104.36"`
} }
// InviteAcceptRequest defines the fields need to complete an invite request. // AcceptInviteRequest defines the fields need to complete an invite request.
type InviteAcceptRequest struct { type AcceptInviteRequest struct {
InviteHash string `json:"invite_hash" validate:"required" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"` InviteHash string `json:"invite_hash" validate:"required" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"`
FirstName string `json:"first_name" validate:"required" example:"Gabi"` FirstName string `json:"first_name" validate:"required" example:"Gabi"`
LastName string `json:"last_name" validate:"required" example:"May"` LastName string `json:"last_name" validate:"required" example:"May"`

View File

@ -19,7 +19,7 @@ import (
// application. The status will allow users to be managed on by account with users // application. The status will allow users to be managed on by account with users
// being global to the application. // being global to the application.
type UserAccount struct { type UserAccount struct {
ID string `json:"id" validate:"required,uuid" example:"72938896-a998-4258-a17b-6418dcdb80e3"` //ID string `json:"id" validate:"required,uuid" example:"72938896-a998-4258-a17b-6418dcdb80e3"`
UserID string `json:"user_id" validate:"required,uuid" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"` UserID string `json:"user_id" validate:"required,uuid" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"`
AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"` AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
Roles UserAccountRoles `json:"roles" validate:"required,dive,oneof=admin user" enums:"admin,user" swaggertype:"array,string" example:"admin"` Roles UserAccountRoles `json:"roles" validate:"required,dive,oneof=admin user" enums:"admin,user" swaggertype:"array,string" example:"admin"`
@ -31,7 +31,7 @@ type UserAccount struct {
// UserAccountResponse defines the one to many relationship of an user to an account that is returned for display. // UserAccountResponse defines the one to many relationship of an user to an account that is returned for display.
type UserAccountResponse struct { type UserAccountResponse struct {
ID string `json:"id" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"` //ID string `json:"id" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"`
UserID string `json:"user_id" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"` UserID string `json:"user_id" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"`
AccountID string `json:"account_id" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"` AccountID string `json:"account_id" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
Roles UserAccountRoles `json:"roles" validate:"required,dive,oneof=admin user" enums:"admin,user" swaggertype:"array,string" example:"admin"` Roles UserAccountRoles `json:"roles" validate:"required,dive,oneof=admin user" enums:"admin,user" swaggertype:"array,string" example:"admin"`
@ -49,7 +49,7 @@ func (m *UserAccount) Response(ctx context.Context) *UserAccountResponse {
} }
r := &UserAccountResponse{ r := &UserAccountResponse{
ID: m.ID, //ID: m.ID,
UserID: m.UserID, UserID: m.UserID,
AccountID: m.AccountID, AccountID: m.AccountID,
Roles: m.Roles, Roles: m.Roles,
@ -77,6 +77,13 @@ type UserAccountCreateRequest struct {
Status *UserAccountStatus `json:"status,omitempty" validate:"omitempty,oneof=active invited disabled" enums:"active,invited,disabled" swaggertype:"string" example:"active"` Status *UserAccountStatus `json:"status,omitempty" validate:"omitempty,oneof=active invited disabled" enums:"active,invited,disabled" swaggertype:"string" example:"active"`
} }
// UserAccountReadRequest defines the information needed to read a user account.
type UserAccountReadRequest struct {
UserID string `json:"user_id" validate:"required,uuid" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"`
AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
IncludeArchived bool `json:"include-archived" example:"false"`
}
// UserAccountUpdateRequest defines the information needed to update the roles or the // UserAccountUpdateRequest defines the information needed to update the roles or the
// status for an existing user account. // status for an existing user account.
type UserAccountUpdateRequest struct { type UserAccountUpdateRequest struct {
@ -104,12 +111,12 @@ type UserAccountDeleteRequest struct {
// UserAccountFindRequest defines the possible options to search for users accounts. // UserAccountFindRequest defines the possible options to search for users accounts.
// By default archived user accounts will be excluded from response. // By default archived user accounts will be excluded from response.
type UserAccountFindRequest struct { type UserAccountFindRequest struct {
Where *string `json:"where" example:"user_id = ? and account_id = ?"` Where *string `json:"where" example:"user_id = ? and account_id = ?"`
Args []interface{} `json:"args" swaggertype:"array,string" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2,c4653bf9-5978-48b7-89c5-95704aebb7e2"` Args []interface{} `json:"args" swaggertype:"array,string" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2,c4653bf9-5978-48b7-89c5-95704aebb7e2"`
Order []string `json:"order" example:"created_at desc"` Order []string `json:"order" example:"created_at desc"`
Limit *uint `json:"limit" example:"10"` Limit *uint `json:"limit" example:"10"`
Offset *uint `json:"offset" example:"20"` Offset *uint `json:"offset" example:"20"`
IncludedArchived bool `json:"included-archived" example:"false"` IncludeArchived bool `json:"include-archived" example:"false"`
} }
// UserAccountStatus represents the status of a user for an account. // UserAccountStatus represents the status of a user for an account.

View File

@ -3,6 +3,7 @@ package user_account
import ( import (
"context" "context"
"database/sql" "database/sql"
"geeks-accelerator/oss/saas-starter-kit/internal/user"
"time" "time"
"geeks-accelerator/oss/saas-starter-kit/internal/account" "geeks-accelerator/oss/saas-starter-kit/internal/account"
@ -30,7 +31,7 @@ const userAccountTableName = "users_accounts"
const userTableName = "users" const userTableName = "users"
// The list of columns needed for mapRowsToUserAccount // The list of columns needed for mapRowsToUserAccount
var userAccountMapColumns = "id,user_id,account_id,roles,status,created_at,updated_at,archived_at" var userAccountMapColumns = "user_id,account_id,roles,status,created_at,updated_at,archived_at"
// mapRowsToUserAccount takes the SQL rows and maps it to the UserAccount struct // mapRowsToUserAccount takes the SQL rows and maps it to the UserAccount struct
// with the columns defined by userAccountMapColumns // with the columns defined by userAccountMapColumns
@ -39,7 +40,7 @@ func mapRowsToUserAccount(rows *sql.Rows) (*UserAccount, error) {
ua UserAccount ua UserAccount
err error err error
) )
err = rows.Scan(&ua.ID, &ua.UserID, &ua.AccountID, &ua.Roles, &ua.Status, &ua.CreatedAt, &ua.UpdatedAt, &ua.ArchivedAt) err = rows.Scan(&ua.UserID, &ua.AccountID, &ua.Roles, &ua.Status, &ua.CreatedAt, &ua.UpdatedAt, &ua.ArchivedAt)
if err != nil { if err != nil {
return nil, errors.WithStack(err) return nil, errors.WithStack(err)
} }
@ -132,7 +133,7 @@ func findRequestQuery(req UserAccountFindRequest) (*sqlbuilder.SelectBuilder, []
// Find gets all the user accounts from the database based on the request params. // Find gets all the user accounts from the database based on the request params.
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountFindRequest) ([]*UserAccount, error) { func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountFindRequest) ([]*UserAccount, error) {
query, args := findRequestQuery(req) query, args := findRequestQuery(req)
return find(ctx, claims, dbConn, query, args, req.IncludedArchived) return find(ctx, claims, dbConn, query, args, req.IncludeArchived)
} }
// Find gets all the user accounts from the database based on the select query // Find gets all the user accounts from the database based on the select query
@ -260,8 +261,10 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc
ua.UpdatedAt = now ua.UpdatedAt = now
ua.ArchivedAt = nil ua.ArchivedAt = nil
} else { } else {
uaID := uuid.NewRandom().String()
ua = UserAccount{ ua = UserAccount{
ID: uuid.NewRandom().String(), //ID: uaID,
UserID: req.UserID, UserID: req.UserID,
AccountID: req.AccountID, AccountID: req.AccountID,
Roles: req.Roles, Roles: req.Roles,
@ -278,7 +281,7 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc
query := sqlbuilder.NewInsertBuilder() query := sqlbuilder.NewInsertBuilder()
query.InsertInto(userAccountTableName) query.InsertInto(userAccountTableName)
query.Cols("id", "user_id", "account_id", "roles", "status", "created_at", "updated_at") query.Cols("id", "user_id", "account_id", "roles", "status", "created_at", "updated_at")
query.Values(ua.ID, ua.UserID, ua.AccountID, ua.Roles, ua.Status.String(), ua.CreatedAt, ua.UpdatedAt) query.Values(uaID, ua.UserID, ua.AccountID, ua.Roles, ua.Status.String(), ua.CreatedAt, ua.UpdatedAt)
// Execute the query with the provided context. // Execute the query with the provided context.
sql, args := query.Build() sql, args := query.Build()
@ -295,19 +298,28 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc
} }
// Read gets the specified user account from the database. // Read gets the specified user account from the database.
func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string, includedArchived bool) (*UserAccount, error) { func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountReadRequest) (*UserAccount, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Read") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Read")
defer span.Finish() defer span.Finish()
// Validate the request.
v := webcontext.Validator()
err := v.Struct(req)
if err != nil {
return nil, err
}
// Filter base select query by ID // Filter base select query by ID
query := selectQuery() query := selectQuery()
query.Where(query.Equal("id", id)) query.Where(query.And(
query.Equal("user_id", req.UserID),
query.Equal("account_id", req.AccountID)))
res, err := find(ctx, claims, dbConn, query, []interface{}{}, includedArchived) res, err := find(ctx, claims, dbConn, query, []interface{}{}, req.IncludeArchived)
if res == nil || len(res) == 0 { if err != nil {
err = errors.WithMessagef(ErrNotFound, "user account %s not found", id)
return nil, err return nil, err
} else if err != nil { } else if res == nil || len(res) == 0 {
err = errors.WithMessagef(ErrNotFound, "entry for user %s account %s not found", req.UserID, req.AccountID)
return nil, err return nil, err
} }
u := res[0] u := res[0]
@ -478,3 +490,41 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc
return nil return nil
} }
type MockUserAccountResponse struct {
*UserAccount
User *user.MockUserResponse
Account *account.Account
}
// MockUserAccount returns a fake UserAccount for testing.
func MockUserAccount(ctx context.Context, dbConn *sqlx.DB, now time.Time, roles ...UserAccountRole) (*MockUserAccountResponse, error) {
usr, err := user.MockUser(ctx, dbConn, now)
if err != nil {
return nil, err
}
acc, err := account.MockAccount(ctx, dbConn, now)
if err != nil {
return nil, err
}
status := UserAccountStatus_Active
req := UserAccountCreateRequest{
UserID: usr.ID,
AccountID: acc.ID,
Status: &status,
Roles: roles,
}
ua, err := Create(ctx, auth.Claims{}, dbConn, req, now)
if err != nil {
return nil, err
}
return &MockUserAccountResponse{
UserAccount: ua,
User: usr,
Account: acc,
}, nil
}

View File

@ -193,7 +193,7 @@ func TestCreateValidation(t *testing.T) {
Status: UserAccountStatus_Active, Status: UserAccountStatus_Active,
// Copy this fields from the result. // Copy this fields from the result.
ID: res.ID, //ID: res.ID,
CreatedAt: res.CreatedAt, CreatedAt: res.CreatedAt,
UpdatedAt: res.UpdatedAt, UpdatedAt: res.UpdatedAt,
//ArchivedAt: nil, //ArchivedAt: nil,
@ -326,7 +326,8 @@ func TestCreateExistingEntry(t *testing.T) {
} }
// Find the archived user account // Find the archived user account
arcRes, err := Read(tests.Context(), auth.Claims{}, test.MasterDB, ua2.ID, true) arcRes, err := Read(tests.Context(), auth.Claims{}, test.MasterDB,
UserAccountReadRequest{UserID: req1.UserID, AccountID: req1.AccountID, IncludeArchived: true})
if err != nil || arcRes == nil { if err != nil || arcRes == nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tFind user account failed.", tests.Failed) t.Fatalf("\t%s\tFind user account failed.", tests.Failed)
@ -349,7 +350,8 @@ func TestCreateExistingEntry(t *testing.T) {
} }
// Ensure the user account has archived_at empty // Ensure the user account has archived_at empty
findRes, err := Read(tests.Context(), auth.Claims{}, test.MasterDB, ua3.ID, false) findRes, err := Read(tests.Context(), auth.Claims{}, test.MasterDB,
UserAccountReadRequest{UserID: req1.UserID, AccountID: req1.AccountID})
if err != nil || arcRes == nil { if err != nil || arcRes == nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tFind user account failed.", tests.Failed) t.Fatalf("\t%s\tFind user account failed.", tests.Failed)
@ -609,7 +611,7 @@ func TestCrud(t *testing.T) {
} else if tt.findErr == nil { } else if tt.findErr == nil {
expected := []*UserAccount{ expected := []*UserAccount{
&UserAccount{ &UserAccount{
ID: ua.ID, //ID: ua.ID,
UserID: ua.UserID, UserID: ua.UserID,
AccountID: ua.AccountID, AccountID: ua.AccountID,
Roles: ua.Roles, Roles: ua.Roles,
@ -651,7 +653,7 @@ func TestCrud(t *testing.T) {
expected := []*UserAccount{ expected := []*UserAccount{
&UserAccount{ &UserAccount{
ID: ua.ID, //ID: ua.ID,
UserID: ua.UserID, UserID: ua.UserID,
AccountID: ua.AccountID, AccountID: ua.AccountID,
Roles: *updateReq.Roles, Roles: *updateReq.Roles,
@ -806,8 +808,9 @@ func TestFind(t *testing.T) {
} }
ua := *userAccounts[i] ua := *userAccounts[i]
whereParts = append(whereParts, "id = ?") whereParts = append(whereParts, "(user_id = ? and account_id = ?)")
whereArgs = append(whereArgs, ua.ID) whereArgs = append(whereArgs, ua.UserID)
whereArgs = append(whereArgs, ua.AccountID)
expected = append(expected, &ua) expected = append(expected, &ua)
} }
where := createdFilter + " AND (" + strings.Join(whereParts, " OR ") + ")" where := createdFilter + " AND (" + strings.Join(whereParts, " OR ") + ")"

View File

@ -1,15 +1,15 @@
package user package user_auth
import ( import (
"context" "context"
"crypto/rsa"
"database/sql" "database/sql"
"strings" "strings"
"time" "time"
"geeks-accelerator/oss/saas-starter-kit/internal/account/account_preference"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"github.com/dgrijalva/jwt-go" "geeks-accelerator/oss/saas-starter-kit/internal/user"
"github.com/huandu/go-sqlbuilder" "github.com/huandu/go-sqlbuilder"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/lib/pq" "github.com/lib/pq"
@ -18,35 +18,37 @@ import (
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
) )
// TokenGenerator is the behavior we need in our Authenticate to generate tokens for var (
// authenticated users. // ErrAuthenticationFailure occurs when a user attempts to authenticate but
type TokenGenerator interface { // anything goes wrong.
GenerateToken(auth.Claims) (string, error) ErrAuthenticationFailure = errors.New("Authentication failed")
ParseClaims(string) (auth.Claims, error) )
}
const (
// The database table for User
userTableName = "users"
// The database table for Account
accountTableName = "accounts"
// The database table for User Account
userAccountTableName = "users_accounts"
)
// Authenticate finds a user by their email and verifies their password. On success // Authenticate finds a user by their email and verifies their password. On success
// it returns a Token that can be used to authenticate access to the application in // it returns a Token that can be used to authenticate access to the application in
// the future. // the future.
func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, email, password string, expires time.Duration, now time.Time, scopes ...string) (Token, error) { func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, email, password string, expires time.Duration, now time.Time, scopes ...string) (Token, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Authenticate") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.Authenticate")
defer span.Finish() defer span.Finish()
// Generate sql query to select user by email address. u, err := user.ReadByEmail(ctx, auth.Claims{}, dbConn, email, false)
query := sqlbuilder.NewSelectBuilder()
query.Where(query.Equal("email", email))
// Run the find, use empty claims to bypass ACLs since this in an internal request
// and the current user is not authenticated at this point. If the email is
// invalid, return the same error as when an invalid password is supplied.
res, err := find(ctx, auth.Claims{}, dbConn, query, []interface{}{}, false)
if err != nil { if err != nil {
return Token{}, err if errors.Cause(err) == user.ErrNotFound {
} else if res == nil || len(res) == 0 { err = errors.WithStack(ErrAuthenticationFailure)
err = errors.WithStack(ErrAuthenticationFailure) return Token{}, err
return Token{}, err } else {
return Token{}, err
}
} }
u := res[0]
// Append the salt from the user record to the supplied password. // Append the salt from the user record to the supplied password.
saltedPassword := password + u.PasswordSalt saltedPassword := password + u.PasswordSalt
@ -67,7 +69,7 @@ func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, e
// it returns a Token that can be used to authenticate access to the application in // it returns a Token that can be used to authenticate access to the application in
// the future. // the future.
func SwitchAccount(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, accountID string, expires time.Duration, now time.Time, scopes ...string) (Token, error) { func SwitchAccount(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, accountID string, expires time.Duration, now time.Time, scopes ...string) (Token, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.SwitchAccount") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.SwitchAccount")
defer span.Finish() defer span.Finish()
// Defines struct to apply validation for the supplied claims and account ID. // Defines struct to apply validation for the supplied claims and account ID.
@ -221,61 +223,102 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator,
// Allow the scope to be defined for the claims. This enables testing via the API when a user has the role of admin // Allow the scope to be defined for the claims. This enables testing via the API when a user has the role of admin
// and would like to limit their role to user. // and would like to limit their role to user.
var roles []string var roles []string
if len(scopes) > 0 && scopes[0] != "" { {
// Parse scopes, handle when one value has a list of scopes if len(scopes) > 0 && scopes[0] != "" {
// separated by a space. // Parse scopes, handle when one value has a list of scopes
var scopeList []string // separated by a space.
for _, vs := range scopes { var scopeList []string
for _, v := range strings.Split(vs, " ") { for _, vs := range scopes {
v = strings.TrimSpace(v) for _, v := range strings.Split(vs, " ") {
if v == "" { v = strings.TrimSpace(v)
continue if v == "" {
} continue
scopeList = append(scopeList, v) }
} scopeList = append(scopeList, v)
}
for _, s := range scopeList {
var scopeValid bool
for _, r := range account.Roles {
if r == s || (s == auth.RoleUser && r == auth.RoleAdmin) {
scopeValid = true
break
} }
} }
if scopeValid { for _, s := range scopeList {
roles = append(roles, s) var scopeValid bool
} else { for _, r := range account.Roles {
err := errors.Errorf("invalid scope '%s'", s) if r == s || (s == auth.RoleUser && r == auth.RoleAdmin) {
return Token{}, err scopeValid = true
break
}
}
if scopeValid {
roles = append(roles, s)
} else {
err := errors.Errorf("invalid scope '%s'", s)
return Token{}, err
}
}
} else {
roles = account.Roles
}
if len(roles) == 0 {
err := errors.New("no roles defined for user")
return Token{}, err
}
}
var claimPref auth.ClaimPreferences
{
// Set the timezone if one is specifically set on the user.
var tz *time.Location
if account.UserTimezone.Valid && account.UserTimezone.String != "" {
tz, _ = time.LoadLocation(account.UserTimezone.String)
}
// If user timezone failed to parse or none is set, check the timezone set on the account.
if tz == nil && account.AccountTimezone.Valid && account.AccountTimezone.String != "" {
tz, _ = time.LoadLocation(account.AccountTimezone.String)
}
prefs, err := account_preference.FindByAccountID(ctx, auth.Claims{}, dbConn, account_preference.AccountPreferenceFindByAccountIDRequest{
AccountID: accountID,
})
if err != nil {
return Token{}, err
}
var (
preferenceDatetimeFormat string
preferenceDateFormat string
preferenceTimeFormat string
)
for _, pref := range prefs {
switch pref.Name {
case account_preference.AccountPreference_Datetime_Format:
preferenceDatetimeFormat = pref.Value
case account_preference.AccountPreference_Date_Format:
preferenceDateFormat = pref.Value
case account_preference.AccountPreference_Time_Format:
preferenceTimeFormat = pref.Value
} }
} }
} else {
roles = account.Roles
}
if len(roles) == 0 { if preferenceDatetimeFormat == "" {
err := errors.New("no roles defined for user") preferenceDatetimeFormat = account_preference.AccountPreference_Datetime_Format_Default
return Token{}, err }
} if preferenceDateFormat == "" {
preferenceDateFormat = account_preference.AccountPreference_Date_Format_Default
}
if preferenceTimeFormat == "" {
preferenceTimeFormat = account_preference.AccountPreference_Time_Format_Default
}
// Set the timezone if one is specifically set on the user. claimPref = auth.NewClaimPreferences(tz, preferenceDatetimeFormat, preferenceDateFormat, preferenceTimeFormat)
var tz *time.Location
if account.UserTimezone.Valid && account.UserTimezone.String != "" {
tz, _ = time.LoadLocation(account.UserTimezone.String)
}
// If user timezone failed to parse or none is set, check the timezone set on the account.
if tz == nil && account.AccountTimezone.Valid && account.AccountTimezone.String != "" {
tz, _ = time.LoadLocation(account.AccountTimezone.String)
} }
// JWT claims requires both an audience and a subject. For this application: // JWT claims requires both an audience and a subject. For this application:
// Subject: The ID of the user authenticated. // Subject: The ID of the user authenticated.
// Audience: The ID of the account the user is accessing. A list of account IDs // Audience: The ID of the account the user is accessing. A list of account IDs
// will also be included to support the user switching between them. // will also be included to support the user switching between them.
claims = auth.NewClaims(userID, accountID, accountIds, roles, tz, now, expires) claims = auth.NewClaims(userID, accountID, accountIds, roles, claimPref, now, expires)
// Generate a token for the user with the defined claims. // Generate a token for the user with the defined claims.
tknStr, err := tknGen.GenerateToken(claims) tknStr, err := tknGen.GenerateToken(claims)
@ -287,8 +330,8 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator,
AccessToken: tknStr, AccessToken: tknStr,
TokenType: "Bearer", TokenType: "Bearer",
claims: claims, claims: claims,
UserID: claims.Subject, UserID: claims.Subject,
AccountID: claims.Audience, AccountID: claims.Audience,
} }
if expires.Seconds() > 0 { if expires.Seconds() > 0 {
@ -298,72 +341,3 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator,
return tkn, nil return tkn, nil
} }
// AuthorizationHeader returns the header authorization value.
func (t Token) AuthorizationHeader() string {
return "Bearer " + t.AccessToken
}
// mockTokenGenerator is used for testing that Authenticate calls its provided
// token generator in a specific way.
type MockTokenGenerator struct {
// Private key generated by GenerateToken that is need for ParseClaims
key *rsa.PrivateKey
// algorithm is the method used to generate the private key.
algorithm string
}
// GenerateToken implements the TokenGenerator interface. It returns a "token"
// that includes some information about the claims it was passed.
func (g *MockTokenGenerator) GenerateToken(claims auth.Claims) (string, error) {
privateKey, err := auth.KeyGen()
if err != nil {
return "", err
}
g.key, err = jwt.ParseRSAPrivateKeyFromPEM(privateKey)
if err != nil {
return "", err
}
g.algorithm = "RS256"
method := jwt.GetSigningMethod(g.algorithm)
tkn := jwt.NewWithClaims(method, claims)
tkn.Header["kid"] = "1"
str, err := tkn.SignedString(g.key)
if err != nil {
return "", err
}
return str, nil
}
// ParseClaims recreates the Claims that were used to generate a token. It
// verifies that the token was signed using our key.
func (g *MockTokenGenerator) ParseClaims(tknStr string) (auth.Claims, error) {
parser := jwt.Parser{
ValidMethods: []string{g.algorithm},
}
if g.key == nil {
return auth.Claims{}, errors.New("Private key is empty.")
}
f := func(t *jwt.Token) (interface{}, error) {
return g.key.Public().(*rsa.PublicKey), nil
}
var claims auth.Claims
tkn, err := parser.ParseWithClaims(tknStr, &claims, f)
if err != nil {
return auth.Claims{}, errors.Wrap(err, "parsing token")
}
if !tkn.Valid {
return auth.Claims{}, errors.New("Invalid token")
}
return claims, nil
}

View File

@ -0,0 +1,258 @@
package user_auth
import (
"encoding/json"
"os"
"testing"
"time"
"geeks-accelerator/oss/saas-starter-kit/internal/account"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/notify"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/tests"
"geeks-accelerator/oss/saas-starter-kit/internal/user"
"geeks-accelerator/oss/saas-starter-kit/internal/user_account"
"github.com/google/go-cmp/cmp"
"github.com/pborman/uuid"
"github.com/pkg/errors"
)
var test *tests.Test
// TestMain is the entry point for testing.
func TestMain(m *testing.M) {
os.Exit(testMain(m))
}
func testMain(m *testing.M) int {
test = tests.New()
defer test.TearDown()
return m.Run()
}
// TestAuthenticate validates the behavior around authenticating users.
func TestAuthenticate(t *testing.T) {
defer tests.Recover(t)
t.Log("Given the need to authenticate users")
{
t.Log("\tWhen handling a single User.")
{
ctx := tests.Context()
tknGen := &auth.MockTokenGenerator{}
// Auth tokens are valid for an our and is verified against current time.
// Issue the token one hour ago.
now := time.Now().Add(time.Hour * -1)
// Try to authenticate an invalid user.
_, err := Authenticate(ctx, test.MasterDB, tknGen, "doesnotexist@gmail.com", "xy7", time.Hour, now)
if errors.Cause(err) != ErrAuthenticationFailure {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", ErrAuthenticationFailure)
t.Fatalf("\t%s\tAuthenticate non existant user failed.", tests.Failed)
}
t.Logf("\t%s\tAuthenticate non existant user ok.", tests.Success)
// Create a new user for testing.
usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_User)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
}
t.Logf("\t%s\tCreate user account ok.", tests.Success)
acc2, err := account.MockAccount(ctx, test.MasterDB, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate second account failed.", tests.Failed)
}
t.Logf("\t%s\tCreate second account ok.", tests.Success)
// Associate second new account with user user. Need to ensure that now
// is always greater than the first user_account entry created so it will
// be returned consistently back in the same order, last.
account2Role := auth.RoleUser
_, err = user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{
UserID: usrAcc.UserID,
AccountID: acc2.ID,
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(account2Role)},
}, now)
// Add 30 minutes to now to simulate time passing.
now = now.Add(time.Minute * 30)
// Try to authenticate valid user with invalid password.
_, err = Authenticate(ctx, test.MasterDB, tknGen, usrAcc.User.Email, "xy7", time.Hour, now)
if errors.Cause(err) != ErrAuthenticationFailure {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", ErrAuthenticationFailure)
t.Fatalf("\t%s\tAuthenticate user w/invalid password failed.", tests.Failed)
}
t.Logf("\t%s\tAuthenticate user w/invalid password ok.", tests.Success)
// Verify that the user can be authenticated with the created user.
tkn1, err := Authenticate(ctx, test.MasterDB, tknGen, usrAcc.User.Email, usrAcc.User.Password, time.Hour, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tAuthenticate user failed.", tests.Failed)
}
t.Logf("\t%s\tAuthenticate user ok.", tests.Success)
// Ensure the token string was correctly generated.
claims1, err := tknGen.ParseClaims(tkn1.AccessToken)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
}
// Hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229
resClaims, _ := json.Marshal(claims1)
expectClaims, _ := json.Marshal(tkn1.claims)
if diff := cmp.Diff(string(resClaims), string(expectClaims)); diff != "" {
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
}
t.Logf("\t%s\tAuthenticate parse claims from token ok.", tests.Success)
// Try switching to a second account using the first set of claims.
tkn2, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims1, acc2.ID, time.Hour, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tSwitchAccount user failed.", tests.Failed)
}
t.Logf("\t%s\tSwitchAccount user ok.", tests.Success)
// Ensure the token string was correctly generated.
claims2, err := tknGen.ParseClaims(tkn2.AccessToken)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
}
// Hack for Unhandled Exception in go-cmp@v0.3.0/cmp/options.go:229
resClaims, _ = json.Marshal(claims2)
expectClaims, _ = json.Marshal(tkn2.claims)
if diff := cmp.Diff(string(resClaims), string(expectClaims)); diff != "" {
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
}
t.Logf("\t%s\tSwitchAccount parse claims from token ok.", tests.Success)
}
}
}
// TestUserUpdatePassword validates update user password works.
func TestUserUpdatePassword(t *testing.T) {
t.Log("Given the need ensure a user password can be updated.")
{
ctx := tests.Context()
now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC)
tknGen := &auth.MockTokenGenerator{}
// Create a new user for testing.
usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_User)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
}
t.Logf("\t%s\tCreate user account ok.", tests.Success)
// Verify that the user can be authenticated with the created user.
_, err = Authenticate(ctx, test.MasterDB, tknGen, usrAcc.User.Email, usrAcc.User.Password, time.Hour, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tAuthenticate failed.", tests.Failed)
}
// Update the users password.
newPass := uuid.NewRandom().String()
err = user.UpdatePassword(ctx, auth.Claims{}, test.MasterDB, user.UserUpdatePasswordRequest{
ID: usrAcc.UserID,
Password: newPass,
PasswordConfirm: newPass,
}, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tUpdate password failed.", tests.Failed)
}
t.Logf("\t%s\tUpdatePassword ok.", tests.Success)
// Verify that the user can be authenticated with the updated password.
_, err = Authenticate(ctx, test.MasterDB, tknGen, usrAcc.User.Email, newPass, time.Hour, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tAuthenticate failed.", tests.Failed)
}
t.Logf("\t%s\tAuthenticate ok.", tests.Success)
}
}
// TestUserResetPassword validates that reset password for a user works.
func TestUserResetPassword(t *testing.T) {
t.Log("Given the need ensure a user can reset their password.")
{
ctx := tests.Context()
now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC)
tknGen := &auth.MockTokenGenerator{}
// Create a new user for testing.
usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_User)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
}
t.Logf("\t%s\tCreate user account ok.", tests.Success)
// Mock the methods needed to make a password reset.
resetUrl := func(string) string {
return ""
}
notify := &notify.MockEmail{}
secretKey := "6368616e676520746869732070617373"
ttl := time.Hour
// Make the reset password request.
resetHash, err := user.ResetPassword(ctx, test.MasterDB, resetUrl, notify, user.UserResetPasswordRequest{
Email: usrAcc.User.Email,
TTL: ttl,
}, secretKey, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tResetPassword failed.", tests.Failed)
}
t.Logf("\t%s\tResetPassword ok.", tests.Success)
// Assuming we have received the email and clicked the link, we now can ensure confirm works.
newPass := uuid.NewRandom().String()
reset, err := user.ResetConfirm(ctx, test.MasterDB, user.UserResetConfirmRequest{
ResetHash: resetHash,
Password: newPass,
PasswordConfirm: newPass,
}, secretKey, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tResetConfirm failed.", tests.Failed)
} else if reset.ID != usrAcc.User.ID {
t.Logf("\t\tGot : %+v", reset.ID)
t.Logf("\t\tWant: %+v", usrAcc.User.ID)
t.Fatalf("\t%s\tResetConfirm failed.", tests.Failed)
}
t.Logf("\t%s\tResetConfirm ok.", tests.Success)
// Verify that the user can be authenticated with the updated password.
_, err = Authenticate(ctx, test.MasterDB, tknGen, usrAcc.User.Email, newPass, time.Hour, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tAuthenticate failed.", tests.Failed)
}
t.Logf("\t%s\tAuthenticate ok.", tests.Success)
}
}

View File

@ -0,0 +1,48 @@
package user_auth
import (
"time"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
)
// AuthenticateRequest defines what information is required to authenticate a user.
type AuthenticateRequest struct {
Email string `json:"email" validate:"required,email" example:"gabi.may@geeksinthewoods.com"`
Password string `json:"password" validate:"required" example:"NeverTellSecret"`
}
// Token is the payload we deliver to users when they authenticate.
type Token struct {
// AccessToken is the token that authorizes and authenticates
// the requests.
AccessToken string `json:"access_token"`
// TokenType is the type of token.
// The Type method returns either this or "Bearer", the default.
TokenType string `json:"token_type,omitempty"`
// Expiry is the optional expiration time of the access token.
//
// If zero, TokenSource implementations will reuse the same
// token forever and RefreshToken or equivalent
// mechanisms for that TokenSource will not be used.
Expiry time.Time `json:"expiry,omitempty"`
TTL time.Duration `json:"ttl,omitempty"`
// contains filtered or unexported fields
claims auth.Claims `json:"-"`
// UserId is the ID of the user authenticated.
UserID string `json:"user_id" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"`
// AccountID is the ID of the account for the user authenticated.
AccountID string `json:"account_id"example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
}
// AuthorizationHeader returns the header authorization value.
func (t Token) AuthorizationHeader() string {
return "Bearer " + t.AccessToken
}
// TokenGenerator is the behavior we need in our Authenticate to generate tokens for
// authenticated users.
type TokenGenerator interface {
GenerateToken(auth.Claims) (string, error)
ParseClaims(string) (auth.Claims, error)
}