From 48ae19bd6ab62abf78eff39d34871caae6c33ede Mon Sep 17 00:00:00 2001 From: Lee Brown Date: Wed, 26 Jun 2019 20:21:00 -0800 Subject: [PATCH] checkpoint --- .../cmd/web-api/handlers/account.go | 47 +- .../cmd/web-api/handlers/project.go | 118 +- .../cmd/web-api/handlers/routes.go | 15 +- .../cmd/web-api/handlers/signup.go | 20 +- example-project/cmd/web-api/handlers/user.go | 201 ++-- .../cmd/web-api/handlers/user_account.go | 368 ++++++ .../cmd/web-api/tests/account_test.go | 339 ++++++ .../cmd/web-api/tests/signup_test.go | 185 +++ .../cmd/web-api/tests/tests_test.go | 165 ++- .../cmd/web-api/tests/user_test.go | 1000 ++++++++--------- example-project/internal/account/account.go | 24 +- example-project/internal/mid/auth.go | 67 +- example-project/internal/mid/errors.go | 10 +- .../internal/platform/web/errors.go | 37 + .../internal/platform/web/request.go | 79 +- .../internal/platform/web/response.go | 7 +- example-project/internal/project/project.go | 26 +- example-project/internal/signup/models.go | 30 +- example-project/internal/signup/signup.go | 14 +- .../internal/signup/signup_test.go | 4 +- example-project/internal/user/auth.go | 5 +- example-project/internal/user/user.go | 25 +- .../internal/user_account/user_account.go | 25 +- 23 files changed, 1952 insertions(+), 859 deletions(-) create mode 100644 example-project/cmd/web-api/handlers/user_account.go create mode 100644 example-project/cmd/web-api/tests/account_test.go create mode 100644 example-project/cmd/web-api/tests/signup_test.go diff --git a/example-project/cmd/web-api/handlers/account.go b/example-project/cmd/web-api/handlers/account.go index f985dd3..9f9bda6 100644 --- a/example-project/cmd/web-api/handlers/account.go +++ b/example-project/cmd/web-api/handlers/account.go @@ -30,7 +30,6 @@ type Account struct { // @Param id path string true "Account ID" // @Success 200 {object} account.AccountResponse // @Failure 400 {object} web.ErrorResponse -// @Failure 403 {object} web.ErrorResponse // @Failure 404 {object} web.ErrorResponse // @Failure 500 {object} web.ErrorResponse // @Router /accounts/{id} [get] @@ -40,24 +39,23 @@ func (a *Account) Read(ctx context.Context, w http.ResponseWriter, r *http.Reque return errors.New("claims missing from context") } + // Handle included-archived query value if set. var includeArchived bool - if qv := r.URL.Query().Get("include-archived"); qv != "" { - var err error - includeArchived, err = strconv.ParseBool(qv) + if v := r.URL.Query().Get("included-archived"); v != "" { + b, err := strconv.ParseBool(v) if err != nil { - return errors.Wrapf(err, "Invalid value for include-archived : %s", qv) + err = errors.WithMessagef(err, "unable to parse %s as boolean for included-archived param", v) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusBadRequest)) } + includeArchived = b } res, err := account.Read(ctx, claims, a.MasterDB, params["id"], includeArchived) if err != nil { - switch err { - case account.ErrInvalidID: - return web.NewRequestError(err, http.StatusBadRequest) + cause := errors.Cause(err) + switch cause { case account.ErrNotFound: - return web.NewRequestError(err, http.StatusNotFound) - case account.ErrForbidden: - return web.NewRequestError(err, http.StatusForbidden) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusNotFound)) default: return errors.Wrapf(err, "ID: %s", params["id"]) } @@ -74,10 +72,9 @@ func (a *Account) Read(ctx context.Context, w http.ResponseWriter, r *http.Reque // @Produce json // @Security OAuth2Password // @Param data body account.AccountUpdateRequest true "Update fields" -// @Success 201 +// @Success 204 // @Failure 400 {object} web.ErrorResponse // @Failure 403 {object} web.ErrorResponse -// @Failure 404 {object} web.ErrorResponse // @Failure 500 {object} web.ErrorResponse // @Router /accounts [patch] func (a *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { @@ -93,31 +90,25 @@ func (a *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req var req account.AccountUpdateRequest if err := web.Decode(r, &req); err != nil { - err = errors.WithStack(err) - - _, ok := err.(validator.ValidationErrors) - if ok { - return web.NewRequestError(err, http.StatusBadRequest) + if _, ok := errors.Cause(err).(*web.Error); !ok { + err = web.NewRequestError(err, http.StatusBadRequest) } - return err + return web.RespondJsonError(ctx, w, err) } err := account.Update(ctx, claims, a.MasterDB, req, v.Now) if err != nil { - switch err { - case account.ErrInvalidID: - return web.NewRequestError(err, http.StatusBadRequest) - case account.ErrNotFound: - return web.NewRequestError(err, http.StatusNotFound) + cause := errors.Cause(err) + switch cause { case account.ErrForbidden: - return web.NewRequestError(err, http.StatusForbidden) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusForbidden)) default: - _, ok := err.(validator.ValidationErrors) + _, ok := cause.(validator.ValidationErrors) if ok { - return web.NewRequestError(err, http.StatusBadRequest) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusBadRequest)) } - return errors.Wrapf(err, "Id: %s Account: %+v", params["id"], &req) + return errors.Wrapf(err, "Id: %s Account: %+v", req.ID, &req) } } diff --git a/example-project/cmd/web-api/handlers/project.go b/example-project/cmd/web-api/handlers/project.go index 7abfa8c..bdd66ea 100644 --- a/example-project/cmd/web-api/handlers/project.go +++ b/example-project/cmd/web-api/handlers/project.go @@ -50,7 +50,7 @@ func (p *Project) Find(ctx context.Context, w http.ResponseWriter, r *http.Reque if v := r.URL.Query().Get("where"); v != "" { where, args, err := web.ExtractWhereArgs(v) if err != nil { - return web.NewRequestError(err, http.StatusBadRequest) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusBadRequest)) } req.Where = &where req.Args = args @@ -71,7 +71,7 @@ func (p *Project) Find(ctx context.Context, w http.ResponseWriter, r *http.Reque l, err := strconv.Atoi(v) if err != nil { err = errors.WithMessagef(err, "unable to parse %s as int for limit param", v) - return web.NewRequestError(err, http.StatusBadRequest) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusBadRequest)) } ul := uint(l) req.Limit = &ul @@ -82,22 +82,29 @@ func (p *Project) Find(ctx context.Context, w http.ResponseWriter, r *http.Reque l, err := strconv.Atoi(v) if err != nil { err = errors.WithMessagef(err, "unable to parse %s as int for offset param", v) - return web.NewRequestError(err, http.StatusBadRequest) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusBadRequest)) } ul := uint(l) req.Limit = &ul } - // Handle order query value if set. + // Handle include-archive query value if set. if v := r.URL.Query().Get("included-archived"); v != "" { b, err := strconv.ParseBool(v) if err != nil { err = errors.WithMessagef(err, "unable to parse %s as boolean for included-archived param", v) - return web.NewRequestError(err, http.StatusBadRequest) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusBadRequest)) } req.IncludedArchived = b } + //if err := web.Decode(r, &req); err != nil { + // if _, ok := errors.Cause(err).(*web.Error); !ok { + // err = web.NewRequestError(err, http.StatusBadRequest) + // } + // return web.RespondJsonError(ctx, w, err) + //} + res, err := project.Find(ctx, claims, p.MasterDB, req) if err != nil { return err @@ -121,7 +128,6 @@ func (p *Project) Find(ctx context.Context, w http.ResponseWriter, r *http.Reque // @Param id path string true "Project ID" // @Success 200 {object} project.ProjectResponse // @Failure 400 {object} web.ErrorResponse -// @Failure 403 {object} web.ErrorResponse // @Failure 404 {object} web.ErrorResponse // @Failure 500 {object} web.ErrorResponse // @Router /projects/{id} [get] @@ -131,24 +137,23 @@ func (p *Project) Read(ctx context.Context, w http.ResponseWriter, r *http.Reque return errors.New("claims missing from context") } + // Handle included-archived query value if set. var includeArchived bool - if qv := r.URL.Query().Get("include-archived"); qv != "" { - var err error - includeArchived, err = strconv.ParseBool(qv) + if v := r.URL.Query().Get("included-archived"); v != "" { + b, err := strconv.ParseBool(v) if err != nil { - return errors.Wrapf(err, "Invalid value for include-archived : %s", qv) + err = errors.WithMessagef(err, "unable to parse %s as boolean for included-archived param", v) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusBadRequest)) } + includeArchived = b } res, err := project.Read(ctx, claims, p.MasterDB, params["id"], includeArchived) if err != nil { - switch err { - case project.ErrInvalidID: - return web.NewRequestError(err, http.StatusBadRequest) + cause := errors.Cause(err) + switch cause { case project.ErrNotFound: - return web.NewRequestError(err, http.StatusNotFound) - case project.ErrForbidden: - return web.NewRequestError(err, http.StatusForbidden) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusNotFound)) default: return errors.Wrapf(err, "ID: %s", params["id"]) } @@ -165,7 +170,7 @@ func (p *Project) Read(ctx context.Context, w http.ResponseWriter, r *http.Reque // @Produce json // @Security OAuth2Password // @Param data body project.ProjectCreateRequest true "Project details" -// @Success 200 {object} project.ProjectResponse +// @Success 201 {object} project.ProjectResponse // @Failure 400 {object} web.ErrorResponse // @Failure 403 {object} web.ErrorResponse // @Failure 404 {object} web.ErrorResponse @@ -184,24 +189,22 @@ func (p *Project) Create(ctx context.Context, w http.ResponseWriter, r *http.Req var req project.ProjectCreateRequest if err := web.Decode(r, &req); err != nil { - err = errors.WithStack(err) - - _, ok := err.(validator.ValidationErrors) - if ok { - return web.NewRequestError(err, http.StatusBadRequest) + if _, ok := errors.Cause(err).(*web.Error); !ok { + err = web.NewRequestError(err, http.StatusBadRequest) } - return err + return web.RespondJsonError(ctx, w, err) } res, err := project.Create(ctx, claims, p.MasterDB, req, v.Now) if err != nil { - switch err { + cause := errors.Cause(err) + switch cause { case project.ErrForbidden: - return web.NewRequestError(err, http.StatusForbidden) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusForbidden)) default: - _, ok := err.(validator.ValidationErrors) + _, ok := cause.(validator.ValidationErrors) if ok { - return web.NewRequestError(err, http.StatusBadRequest) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusBadRequest)) } return errors.Wrapf(err, "Project: %+v", &req) } @@ -218,10 +221,9 @@ func (p *Project) Create(ctx context.Context, w http.ResponseWriter, r *http.Req // @Produce json // @Security OAuth2Password // @Param data body project.ProjectUpdateRequest true "Update fields" -// @Success 201 +// @Success 204 // @Failure 400 {object} web.ErrorResponse // @Failure 403 {object} web.ErrorResponse -// @Failure 404 {object} web.ErrorResponse // @Failure 500 {object} web.ErrorResponse // @Router /projects [patch] func (p *Project) Update(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { @@ -237,29 +239,24 @@ func (p *Project) Update(ctx context.Context, w http.ResponseWriter, r *http.Req var req project.ProjectUpdateRequest if err := web.Decode(r, &req); err != nil { - err = errors.WithStack(err) - - _, ok := err.(validator.ValidationErrors) - if ok { - return web.NewRequestError(err, http.StatusBadRequest) + if _, ok := errors.Cause(err).(*web.Error); !ok { + err = web.NewRequestError(err, http.StatusBadRequest) } - return err + return web.RespondJsonError(ctx, w, err) } err := project.Update(ctx, claims, p.MasterDB, req, v.Now) if err != nil { - switch err { - case project.ErrInvalidID: - return web.NewRequestError(err, http.StatusBadRequest) - case project.ErrNotFound: - return web.NewRequestError(err, http.StatusNotFound) + cause := errors.Cause(err) + switch cause { case project.ErrForbidden: - return web.NewRequestError(err, http.StatusForbidden) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusForbidden)) default: - _, ok := err.(validator.ValidationErrors) + _, ok := cause.(validator.ValidationErrors) if ok { - return web.NewRequestError(err, http.StatusBadRequest) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusBadRequest)) } + return errors.Wrapf(err, "ID: %s Update: %+v", req.ID, req) } } @@ -275,7 +272,7 @@ func (p *Project) Update(ctx context.Context, w http.ResponseWriter, r *http.Req // @Produce json // @Security OAuth2Password // @Param data body project.ProjectArchiveRequest true "Update fields" -// @Success 201 +// @Success 204 // @Failure 400 {object} web.ErrorResponse // @Failure 403 {object} web.ErrorResponse // @Failure 404 {object} web.ErrorResponse @@ -294,28 +291,24 @@ func (p *Project) Archive(ctx context.Context, w http.ResponseWriter, r *http.Re var req project.ProjectArchiveRequest if err := web.Decode(r, &req); err != nil { - err = errors.WithStack(err) - - _, ok := err.(validator.ValidationErrors) - if ok { - return web.NewRequestError(err, http.StatusBadRequest) + if _, ok := errors.Cause(err).(*web.Error); !ok { + err = web.NewRequestError(err, http.StatusBadRequest) } - return err + return web.RespondJsonError(ctx, w, err) } err := project.Archive(ctx, claims, p.MasterDB, req, v.Now) if err != nil { - switch err { - case project.ErrInvalidID: - return web.NewRequestError(err, http.StatusBadRequest) + cause := errors.Cause(err) + switch cause { case project.ErrNotFound: - return web.NewRequestError(err, http.StatusNotFound) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusNotFound)) case project.ErrForbidden: - return web.NewRequestError(err, http.StatusForbidden) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusForbidden)) default: - _, ok := err.(validator.ValidationErrors) + _, ok := cause.(validator.ValidationErrors) if ok { - return web.NewRequestError(err, http.StatusBadRequest) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusBadRequest)) } return errors.Wrapf(err, "Id: %s", req.ID) @@ -333,7 +326,7 @@ func (p *Project) Archive(ctx context.Context, w http.ResponseWriter, r *http.Re // @Produce json // @Security OAuth2Password // @Param id path string true "Project ID" -// @Success 201 +// @Success 204 // @Failure 400 {object} web.ErrorResponse // @Failure 403 {object} web.ErrorResponse // @Failure 404 {object} web.ErrorResponse @@ -347,13 +340,12 @@ func (p *Project) Delete(ctx context.Context, w http.ResponseWriter, r *http.Req err := project.Delete(ctx, claims, p.MasterDB, params["id"]) if err != nil { - switch err { - case project.ErrInvalidID: - return web.NewRequestError(err, http.StatusBadRequest) + cause := errors.Cause(err) + switch cause { case project.ErrNotFound: - return web.NewRequestError(err, http.StatusNotFound) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusNotFound)) case project.ErrForbidden: - return web.NewRequestError(err, http.StatusForbidden) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusForbidden)) default: return errors.Wrapf(err, "Id: %s", params["id"]) } diff --git a/example-project/cmd/web-api/handlers/routes.go b/example-project/cmd/web-api/handlers/routes.go index 38d18a5..a9119da 100644 --- a/example-project/cmd/web-api/handlers/routes.go +++ b/example-project/cmd/web-api/handlers/routes.go @@ -30,12 +30,11 @@ func API(shutdown chan os.Signal, log *log.Logger, masterDB *sqlx.DB, redis *red MasterDB: masterDB, TokenGenerator: authenticator, } - app.Handle("GET", "/v1/users", u.Find, mid.Authenticate(authenticator)) app.Handle("POST", "/v1/users", u.Create, mid.Authenticate(authenticator), mid.HasRole(auth.RoleAdmin)) app.Handle("GET", "/v1/users/:id", u.Read, mid.Authenticate(authenticator)) app.Handle("PATCH", "/v1/users", u.Update, mid.Authenticate(authenticator)) - app.Handle("PATCH", "/v1/users/password", u.UpdatePassword, mid.Authenticate(authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("PATCH", "/v1/users/password", u.UpdatePassword, mid.Authenticate(authenticator)) app.Handle("PATCH", "/v1/users/archive", u.Archive, mid.Authenticate(authenticator), mid.HasRole(auth.RoleAdmin)) app.Handle("DELETE", "/v1/users/:id", u.Delete, mid.Authenticate(authenticator), mid.HasRole(auth.RoleAdmin)) app.Handle("PATCH", "/v1/users/switch-account/:account_id", u.SwitchAccount, mid.Authenticate(authenticator)) @@ -43,6 +42,18 @@ func API(shutdown chan os.Signal, log *log.Logger, masterDB *sqlx.DB, redis *red // This route is not authenticated app.Handle("POST", "/v1/oauth/token", u.Token) + + // Register user account management endpoints. + ua := UserAccount{ + MasterDB: masterDB, + } + app.Handle("GET", "/v1/user_accounts", ua.Find, mid.Authenticate(authenticator)) + app.Handle("POST", "/v1/user_accounts", ua.Create, mid.Authenticate(authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("GET", "/v1/user_accounts/:id", ua.Read, mid.Authenticate(authenticator)) + app.Handle("PATCH", "/v1/user_accounts", ua.Update, mid.Authenticate(authenticator)) + app.Handle("PATCH", "/v1/user_accounts/archive", ua.Archive, mid.Authenticate(authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("DELETE", "/v1/user_accounts", ua.Delete, mid.Authenticate(authenticator), mid.HasRole(auth.RoleAdmin)) + // Register account endpoints. a := Account{ MasterDB: masterDB, diff --git a/example-project/cmd/web-api/handlers/signup.go b/example-project/cmd/web-api/handlers/signup.go index 5187676..5719f73 100644 --- a/example-project/cmd/web-api/handlers/signup.go +++ b/example-project/cmd/web-api/handlers/signup.go @@ -27,9 +27,8 @@ type Signup struct { // @Accept json // @Produce json // @Param data body signup.SignupRequest true "Signup details" -// @Success 200 {object} signup.SignupResponse +// @Success 201 {object} signup.SignupResponse // @Failure 400 {object} web.ErrorResponse -// @Failure 403 {object} web.ErrorResponse // @Failure 500 {object} web.ErrorResponse // @Router /signup [post] func (c *Signup) Signup(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { @@ -43,29 +42,26 @@ func (c *Signup) Signup(ctx context.Context, w http.ResponseWriter, r *http.Requ var req signup.SignupRequest if err := web.Decode(r, &req); err != nil { - err = errors.WithStack(err) - - _, ok := err.(validator.ValidationErrors) - if ok { - return web.NewRequestError(err, http.StatusBadRequest) + if _, ok := errors.Cause(err).(*web.Error); !ok { + err = web.NewRequestError(err, http.StatusBadRequest) } - return err + return web.RespondJsonError(ctx, w, err) } res, err := signup.Signup(ctx, claims, c.MasterDB, req, v.Now) if err != nil { - switch err { + switch errors.Cause(err) { case account.ErrForbidden: - return web.NewRequestError(err, http.StatusForbidden) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusForbidden)) default: _, ok := err.(validator.ValidationErrors) if ok { - return web.NewRequestError(err, http.StatusBadRequest) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusBadRequest)) } return errors.Wrapf(err, "Signup: %+v", &req) } } - return web.RespondJson(ctx, w, res, http.StatusCreated) + return web.RespondJson(ctx, w, res.Response(ctx), http.StatusCreated) } diff --git a/example-project/cmd/web-api/handlers/user.go b/example-project/cmd/web-api/handlers/user.go index 25eb145..7ca8e86 100644 --- a/example-project/cmd/web-api/handlers/user.go +++ b/example-project/cmd/web-api/handlers/user.go @@ -2,7 +2,6 @@ package handlers import ( "context" - "geeks-accelerator/oss/saas-starter-kit/example-project/internal/user_account" "net/http" "strconv" "strings" @@ -56,7 +55,7 @@ func (u *User) Find(ctx context.Context, w http.ResponseWriter, r *http.Request, if v := r.URL.Query().Get("where"); v != "" { where, args, err := web.ExtractWhereArgs(v) if err != nil { - return web.NewRequestError(err, http.StatusBadRequest) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusBadRequest)) } req.Where = &where req.Args = args @@ -77,7 +76,7 @@ func (u *User) Find(ctx context.Context, w http.ResponseWriter, r *http.Request, l, err := strconv.Atoi(v) if err != nil { err = errors.WithMessagef(err, "unable to parse %s as int for limit param", v) - return web.NewRequestError(err, http.StatusBadRequest) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusBadRequest)) } ul := uint(l) req.Limit = &ul @@ -88,31 +87,28 @@ func (u *User) Find(ctx context.Context, w http.ResponseWriter, r *http.Request, l, err := strconv.Atoi(v) if err != nil { err = errors.WithMessagef(err, "unable to parse %s as int for offset param", v) - return web.NewRequestError(err, http.StatusBadRequest) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusBadRequest)) } ul := uint(l) req.Limit = &ul } - // Handle order query value if set. + // Handle included-archived query value if set. if v := r.URL.Query().Get("included-archived"); v != "" { b, err := strconv.ParseBool(v) if err != nil { err = errors.WithMessagef(err, "unable to parse %s as boolean for included-archived param", v) - return web.NewRequestError(err, http.StatusBadRequest) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusBadRequest)) } req.IncludedArchived = b } - if err := web.Decode(r, &req); err != nil { - err = errors.WithStack(err) - - _, ok := err.(validator.ValidationErrors) - if ok { - return web.NewRequestError(err, http.StatusBadRequest) - } - return err - } + //if err := web.Decode(r, &req); err != nil { + // if _, ok := errors.Cause(err).(*web.Error); !ok { + // err = web.NewRequestError(err, http.StatusBadRequest) + // } + // return web.RespondJsonError(ctx, w, err) + //} res, err := user.Find(ctx, claims, u.MasterDB, req) if err != nil { @@ -137,7 +133,6 @@ func (u *User) Find(ctx context.Context, w http.ResponseWriter, r *http.Request, // @Param id path string true "User ID" // @Success 200 {object} user.UserResponse // @Failure 400 {object} web.ErrorResponse -// @Failure 403 {object} web.ErrorResponse // @Failure 404 {object} web.ErrorResponse // @Failure 500 {object} web.ErrorResponse // @Router /users/{id} [get] @@ -147,30 +142,24 @@ func (u *User) Read(ctx context.Context, w http.ResponseWriter, r *http.Request, return errors.New("claims missing from context") } + // Handle included-archived query value if set. var includeArchived bool - if qv := r.URL.Query().Get("include-archived"); qv != "" { - var err error - includeArchived, err = strconv.ParseBool(qv) + if v := r.URL.Query().Get("included-archived"); v != "" { + b, err := strconv.ParseBool(v) if err != nil { - return errors.Wrapf(err, "Invalid value for include-archived : %s", qv) + err = errors.WithMessagef(err, "unable to parse %s as boolean for included-archived param", v) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusBadRequest)) } + includeArchived = b } res, err := user.Read(ctx, claims, u.MasterDB, params["id"], includeArchived) if err != nil { - switch err { - case user.ErrInvalidID: - return web.NewRequestError(err, http.StatusBadRequest) + cause := errors.Cause(err) + switch cause { case user.ErrNotFound: - return web.NewRequestError(err, http.StatusNotFound) - case user.ErrForbidden: - return web.NewRequestError(err, http.StatusForbidden) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusNotFound)) default: - _, ok := err.(validator.ValidationErrors) - if ok { - return web.NewRequestError(err, http.StatusBadRequest) - } - return errors.Wrapf(err, "ID: %s", params["id"]) } } @@ -186,10 +175,9 @@ func (u *User) Read(ctx context.Context, w http.ResponseWriter, r *http.Request, // @Produce json // @Security OAuth2Password // @Param data body user.UserCreateRequest true "User details" -// @Success 200 {object} user.UserResponse +// @Success 201 {object} user.UserResponse // @Failure 400 {object} web.ErrorResponse // @Failure 403 {object} web.ErrorResponse -// @Failure 404 {object} web.ErrorResponse // @Failure 500 {object} web.ErrorResponse // @Router /users [post] func (u *User) Create(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { @@ -205,53 +193,28 @@ func (u *User) Create(ctx context.Context, w http.ResponseWriter, r *http.Reques var req user.UserCreateRequest if err := web.Decode(r, &req); err != nil { - err = errors.WithStack(err) - - _, ok := err.(validator.ValidationErrors) - if ok { - return web.NewRequestError(err, http.StatusBadRequest) + if _, ok := errors.Cause(err).(*web.Error); !ok { + err = web.NewRequestError(err, http.StatusBadRequest) } - return err + return web.RespondJsonError(ctx, w, err) } res, err := user.Create(ctx, claims, u.MasterDB, req, v.Now) if err != nil { - switch err { + cause := errors.Cause(err) + switch cause { case user.ErrForbidden: - return web.NewRequestError(err, http.StatusForbidden) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusForbidden)) default: - _, ok := err.(validator.ValidationErrors) + _, ok := cause.(validator.ValidationErrors) if ok { - return web.NewRequestError(err, http.StatusBadRequest) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusBadRequest)) } return errors.Wrapf(err, "User: %+v", &req) } } - if claims.Audience != "" { - uaReq := user_account.UserAccountCreateRequest{ - UserID: resp.User.ID, - AccountID: resp.Account.ID, - Roles: []user_account.UserAccountRole{user_account.UserAccountRole_Admin}, - //Status: Use default value - } - _, err = user_account.Create(ctx, claims, u.MasterDB, uaReq, v.Now) - if err != nil { - switch err { - case user.ErrForbidden: - return web.NewRequestError(err, http.StatusForbidden) - default: - _, ok := err.(validator.ValidationErrors) - if ok { - return web.NewRequestError(err, http.StatusBadRequest) - } - - return errors.Wrapf(err, "User account: %+v", &req) - } - } - } - return web.RespondJson(ctx, w, res.Response(ctx), http.StatusCreated) } @@ -263,10 +226,9 @@ func (u *User) Create(ctx context.Context, w http.ResponseWriter, r *http.Reques // @Produce json // @Security OAuth2Password // @Param data body user.UserUpdateRequest true "Update fields" -// @Success 201 +// @Success 204 // @Failure 400 {object} web.ErrorResponse // @Failure 403 {object} web.ErrorResponse -// @Failure 404 {object} web.ErrorResponse // @Failure 500 {object} web.ErrorResponse // @Router /users [patch] func (u *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { @@ -282,31 +244,25 @@ func (u *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Reques var req user.UserUpdateRequest if err := web.Decode(r, &req); err != nil { - err = errors.WithStack(err) - - _, ok := err.(validator.ValidationErrors) - if ok { - return web.NewRequestError(err, http.StatusBadRequest) + if _, ok := errors.Cause(err).(*web.Error); !ok { + err = web.NewRequestError(err, http.StatusBadRequest) } - return err + return web.RespondJsonError(ctx, w, err) } err := user.Update(ctx, claims, u.MasterDB, req, v.Now) if err != nil { - switch err { - case user.ErrInvalidID: - return web.NewRequestError(err, http.StatusBadRequest) - case user.ErrNotFound: - return web.NewRequestError(err, http.StatusNotFound) + cause := errors.Cause(err) + switch cause { case user.ErrForbidden: - return web.NewRequestError(err, http.StatusForbidden) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusForbidden)) default: - _, ok := err.(validator.ValidationErrors) + _, ok := cause.(validator.ValidationErrors) if ok { - return web.NewRequestError(err, http.StatusBadRequest) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusBadRequest)) } - return errors.Wrapf(err, "Id: %s User: %+v", req.ID, &req) + return errors.Wrapf(err, "Id: %s User: %+v", req.ID, &req) } } @@ -321,10 +277,9 @@ func (u *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Reques // @Produce json // @Security OAuth2Password // @Param data body user.UserUpdatePasswordRequest true "Update fields" -// @Success 201 +// @Success 204 // @Failure 400 {object} web.ErrorResponse // @Failure 403 {object} web.ErrorResponse -// @Failure 404 {object} web.ErrorResponse // @Failure 500 {object} web.ErrorResponse // @Router /users/password [patch] func (u *User) UpdatePassword(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { @@ -340,28 +295,24 @@ func (u *User) UpdatePassword(ctx context.Context, w http.ResponseWriter, r *htt var req user.UserUpdatePasswordRequest if err := web.Decode(r, &req); err != nil { - err = errors.WithStack(err) - - _, ok := err.(validator.ValidationErrors) - if ok { - return web.NewRequestError(err, http.StatusBadRequest) + if _, ok := errors.Cause(err).(*web.Error); !ok { + err = web.NewRequestError(err, http.StatusBadRequest) } - return err + return web.RespondJsonError(ctx, w, err) } err := user.UpdatePassword(ctx, claims, u.MasterDB, req, v.Now) if err != nil { - switch err { - case user.ErrInvalidID: - return web.NewRequestError(err, http.StatusBadRequest) + cause := errors.Cause(err) + switch cause { case user.ErrNotFound: - return web.NewRequestError(err, http.StatusNotFound) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusNotFound)) case user.ErrForbidden: - return web.NewRequestError(err, http.StatusForbidden) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusForbidden)) default: - _, ok := err.(validator.ValidationErrors) + _, ok := cause.(validator.ValidationErrors) if ok { - return web.NewRequestError(err, http.StatusBadRequest) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusBadRequest)) } return errors.Wrapf(err, "Id: %s User: %+v", req.ID, &req) @@ -379,7 +330,7 @@ func (u *User) UpdatePassword(ctx context.Context, w http.ResponseWriter, r *htt // @Produce json // @Security OAuth2Password // @Param data body user.UserArchiveRequest true "Update fields" -// @Success 201 +// @Success 204 // @Failure 400 {object} web.ErrorResponse // @Failure 403 {object} web.ErrorResponse // @Failure 404 {object} web.ErrorResponse @@ -398,28 +349,24 @@ func (u *User) Archive(ctx context.Context, w http.ResponseWriter, r *http.Reque var req user.UserArchiveRequest if err := web.Decode(r, &req); err != nil { - err = errors.WithStack(err) - - _, ok := err.(validator.ValidationErrors) - if ok { - return web.NewRequestError(err, http.StatusBadRequest) + if _, ok := errors.Cause(err).(*web.Error); !ok { + err = web.NewRequestError(err, http.StatusBadRequest) } - return err + return web.RespondJsonError(ctx, w, err) } err := user.Archive(ctx, claims, u.MasterDB, req, v.Now) if err != nil { - switch err { - case user.ErrInvalidID: - return web.NewRequestError(err, http.StatusBadRequest) + cause := errors.Cause(err) + switch cause { case user.ErrNotFound: - return web.NewRequestError(err, http.StatusNotFound) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusNotFound)) case user.ErrForbidden: - return web.NewRequestError(err, http.StatusForbidden) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusForbidden)) default: - _, ok := err.(validator.ValidationErrors) + _, ok := cause.(validator.ValidationErrors) if ok { - return web.NewRequestError(err, http.StatusBadRequest) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusBadRequest)) } return errors.Wrapf(err, "Id: %s", req.ID) @@ -437,7 +384,7 @@ func (u *User) Archive(ctx context.Context, w http.ResponseWriter, r *http.Reque // @Produce json // @Security OAuth2Password // @Param id path string true "User ID" -// @Success 201 +// @Success 204 // @Failure 400 {object} web.ErrorResponse // @Failure 403 {object} web.ErrorResponse // @Failure 404 {object} web.ErrorResponse @@ -451,15 +398,14 @@ func (u *User) Delete(ctx context.Context, w http.ResponseWriter, r *http.Reques err := user.Delete(ctx, claims, u.MasterDB, params["id"]) if err != nil { - switch err { - case user.ErrInvalidID: - return web.NewRequestError(err, http.StatusBadRequest) + cause := errors.Cause(err) + switch cause { case user.ErrNotFound: - return web.NewRequestError(err, http.StatusNotFound) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusNotFound)) case user.ErrForbidden: - return web.NewRequestError(err, http.StatusForbidden) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusForbidden)) default: - return errors.Wrapf(err, "Id: %s", params["id"]) + return errors.Wrapf(cause, "Id: %s", params["id"]) } } @@ -493,13 +439,14 @@ func (u *User) SwitchAccount(ctx context.Context, w http.ResponseWriter, r *http tkn, err := user.SwitchAccount(ctx, u.MasterDB, u.TokenGenerator, claims, params["account_id"], sessionTtl, v.Now) if err != nil { - switch err { + cause := errors.Cause(err) + switch cause { case user.ErrAuthenticationFailure: - return web.NewRequestError(err, http.StatusUnauthorized) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusUnauthorized)) default: - _, ok := err.(validator.ValidationErrors) + _, ok := cause.(validator.ValidationErrors) if ok { - return web.NewRequestError(err, http.StatusBadRequest) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusBadRequest)) } return errors.Wrap(err, "switch account") @@ -521,7 +468,6 @@ func (u *User) SwitchAccount(ctx context.Context, w http.ResponseWriter, r *http // @Header 200 {string} Token "qwerty" // @Failure 400 {object} web.Error // @Failure 403 {object} web.Error -// @Failure 404 {object} web.Error // @Router /oauth/token [post] func (u *User) Token(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { v, ok := ctx.Value(web.KeyValues).(*web.Values) @@ -532,7 +478,7 @@ func (u *User) Token(ctx context.Context, w http.ResponseWriter, r *http.Request email, pass, ok := r.BasicAuth() if !ok { err := errors.New("must provide email and password in Basic auth") - return web.NewRequestError(err, http.StatusUnauthorized) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusUnauthorized)) } // Optional to include scope. @@ -540,9 +486,10 @@ func (u *User) Token(ctx context.Context, w http.ResponseWriter, r *http.Request tkn, err := user.Authenticate(ctx, u.MasterDB, u.TokenGenerator, email, pass, sessionTtl, v.Now, scope) if err != nil { - switch err { + cause := errors.Cause(err) + switch cause { case user.ErrAuthenticationFailure: - return web.NewRequestError(err, http.StatusUnauthorized) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusUnauthorized)) default: return errors.Wrap(err, "authenticating") } diff --git a/example-project/cmd/web-api/handlers/user_account.go b/example-project/cmd/web-api/handlers/user_account.go new file mode 100644 index 0000000..3fb182b --- /dev/null +++ b/example-project/cmd/web-api/handlers/user_account.go @@ -0,0 +1,368 @@ +package handlers + +import ( + "context" + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth" + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web" + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/user_account" + "github.com/jmoiron/sqlx" + "github.com/pkg/errors" + "gopkg.in/go-playground/validator.v9" + "net/http" + "strconv" + "strings" +) + +// UserAccount represents the UserAccount API method handler set. +type UserAccount struct { + MasterDB *sqlx.DB + + // ADD OTHER STATE LIKE THE LOGGER AND CONFIG HERE. +} + +// Find godoc +// @Summary List user accounts +// @Description Find returns the existing user accounts in the system. +// @Tags user_account +// @Accept json +// @Produce json +// @Security OAuth2Password +// @Param where query string false "Filter string, example: account_id = 'c4653bf9-5978-48b7-89c5-95704aebb7e2'" +// @Param order query string false "Order columns separated by comma, example: created_at desc" +// @Param limit query integer false "Limit, example: 10" +// @Param offset query integer false "Offset, example: 20" +// @Param included-archived query boolean false "Included Archived, example: false" +// @Success 200 {array} user_account.UserAccountResponse +// @Failure 400 {object} web.ErrorResponse +// @Failure 403 {object} web.ErrorResponse +// @Failure 500 {object} web.ErrorResponse +// @Router /user_accounts [get] +func (u *UserAccount) Find(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { + claims, ok := ctx.Value(auth.Key).(auth.Claims) + if !ok { + return errors.New("claims missing from context") + } + + var req user_account.UserAccountFindRequest + + // Handle where query value if set. + if v := r.URL.Query().Get("where"); v != "" { + where, args, err := web.ExtractWhereArgs(v) + if err != nil { + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusBadRequest)) + } + req.Where = &where + req.Args = args + } + + // Handle order query value if set. + if v := r.URL.Query().Get("order"); v != "" { + for _, o := range strings.Split(v, ",") { + o = strings.TrimSpace(o) + if o != "" { + req.Order = append(req.Order, o) + } + } + } + + // Handle limit query value if set. + if v := r.URL.Query().Get("limit"); v != "" { + l, err := strconv.Atoi(v) + if err != nil { + err = errors.WithMessagef(err, "unable to parse %s as int for limit param", v) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusBadRequest)) + } + ul := uint(l) + req.Limit = &ul + } + + // Handle offset query value if set. + if v := r.URL.Query().Get("offset"); v != "" { + l, err := strconv.Atoi(v) + if err != nil { + err = errors.WithMessagef(err, "unable to parse %s as int for offset param", v) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusBadRequest)) + } + ul := uint(l) + req.Limit = &ul + } + + // Handle order query value if set. + if v := r.URL.Query().Get("included-archived"); v != "" { + b, err := strconv.ParseBool(v) + if err != nil { + err = errors.WithMessagef(err, "unable to parse %s as boolean for included-archived param", v) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusBadRequest)) + } + req.IncludedArchived = b + } + + //if err := web.Decode(r, &req); err != nil { + // if _, ok := errors.Cause(err).(*web.Error); !ok { + // err = web.NewRequestError(err, http.StatusBadRequest) + // } + // return web.RespondJsonError(ctx, w, err) + //} + + res, err := user_account.Find(ctx, claims, u.MasterDB, req) + if err != nil { + return err + } + + var resp []*user_account.UserAccountResponse + for _, m := range res { + resp = append(resp, m.Response(ctx)) + } + + return web.RespondJson(ctx, w, resp, http.StatusOK) +} + +// Read godoc +// @Summary Get user account by ID +// @Description Read returns the specified user account from the system. +// @Tags user_account +// @Accept json +// @Produce json +// @Security OAuth2Password +// @Param id path string true "UserAccount ID" +// @Success 200 {object} user_account.UserAccountResponse +// @Failure 400 {object} web.ErrorResponse +// @Failure 404 {object} web.ErrorResponse +// @Failure 500 {object} web.ErrorResponse +// @Router /user_accounts/{id} [get] +func (u *UserAccount) Read(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { + claims, ok := ctx.Value(auth.Key).(auth.Claims) + if !ok { + return errors.New("claims missing from context") + } + + // Handle included-archived query value if set. + var includeArchived bool + if v := r.URL.Query().Get("included-archived"); v != "" { + b, err := strconv.ParseBool(v) + if err != nil { + err = errors.WithMessagef(err, "unable to parse %s as boolean for included-archived param", v) + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusBadRequest)) + } + includeArchived = b + } + + res, err := user_account.Read(ctx, claims, u.MasterDB, params["id"], includeArchived) + if err != nil { + cause := errors.Cause(err) + switch cause { + case user_account.ErrNotFound: + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusNotFound)) + default: + return errors.Wrapf(err, "ID: %s", params["id"]) + } + } + + return web.RespondJson(ctx, w, res.Response(ctx), http.StatusOK) +} + +// Create godoc +// @Summary Create new user account. +// @Description Create inserts a new user account into the system. +// @Tags user_account +// @Accept json +// @Produce json +// @Security OAuth2Password +// @Param data body user_account.UserAccountCreateRequest true "User Account details" +// @Success 201 {object} user_account.UserAccountResponse +// @Failure 400 {object} web.ErrorResponse +// @Failure 403 {object} web.ErrorResponse +// @Failure 404 {object} web.ErrorResponse +// @Failure 500 {object} web.ErrorResponse +// @Router /user_accounts [post] +func (u *UserAccount) Create(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { + v, ok := ctx.Value(web.KeyValues).(*web.Values) + if !ok { + return web.NewShutdownError("web value missing from context") + } + + claims, ok := ctx.Value(auth.Key).(auth.Claims) + if !ok { + return errors.New("claims missing from context") + } + + var req user_account.UserAccountCreateRequest + if err := web.Decode(r, &req); err != nil { + if _, ok := errors.Cause(err).(*web.Error); !ok { + err = web.NewRequestError(err, http.StatusBadRequest) + } + return web.RespondJsonError(ctx, w, err) + } + + res, err := user_account.Create(ctx, claims, u.MasterDB, req, v.Now) + if err != nil { + cause := errors.Cause(err) + switch cause { + case user_account.ErrForbidden: + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusForbidden)) + default: + _, ok := cause.(validator.ValidationErrors) + if ok { + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusBadRequest)) + } + + return errors.Wrapf(err, "User Account: %+v", &req) + } + } + + return web.RespondJson(ctx, w, res.Response(ctx), http.StatusCreated) +} + +// Read godoc +// @Summary Update user account by user ID and account ID +// @Description Update updates the specified user account in the system. +// @Tags user +// @Accept json +// @Produce json +// @Security OAuth2Password +// @Param data body user_account.UserAccountUpdateRequest true "Update fields" +// @Success 204 +// @Failure 400 {object} web.ErrorResponse +// @Failure 403 {object} web.ErrorResponse +// @Failure 500 {object} web.ErrorResponse +// @Router /user_accounts [patch] +func (u *UserAccount) Update(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { + v, ok := ctx.Value(web.KeyValues).(*web.Values) + if !ok { + return web.NewShutdownError("web value missing from context") + } + + claims, ok := ctx.Value(auth.Key).(auth.Claims) + if !ok { + return errors.New("claims missing from context") + } + + var req user_account.UserAccountUpdateRequest + if err := web.Decode(r, &req); err != nil { + if _, ok := errors.Cause(err).(*web.Error); !ok { + err = web.NewRequestError(err, http.StatusBadRequest) + } + return web.RespondJsonError(ctx, w, err) + } + + err := user_account.Update(ctx, claims, u.MasterDB, req, v.Now) + if err != nil { + cause := errors.Cause(err) + switch cause { + case user_account.ErrForbidden: + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusForbidden)) + default: + _, ok := cause.(validator.ValidationErrors) + if ok { + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusBadRequest)) + } + + return errors.Wrapf(err, "UserID: %s AccountID: %s User Account: %+v", req.UserID, req.AccountID, &req) + } + } + + return web.RespondJson(ctx, w, nil, http.StatusNoContent) +} + +// Read godoc +// @Summary Archive user account by user ID and account ID +// @Description Archive soft-deletes the specified user account from the system. +// @Tags user +// @Accept json +// @Produce json +// @Security OAuth2Password +// @Param data body user_account.UserAccountArchiveRequest true "Update fields" +// @Success 204 +// @Failure 400 {object} web.ErrorResponse +// @Failure 403 {object} web.ErrorResponse +// @Failure 404 {object} web.ErrorResponse +// @Failure 500 {object} web.ErrorResponse +// @Router /user_accounts/archive [patch] +func (u *UserAccount) Archive(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { + v, ok := ctx.Value(web.KeyValues).(*web.Values) + if !ok { + return web.NewShutdownError("web value missing from context") + } + + claims, ok := ctx.Value(auth.Key).(auth.Claims) + if !ok { + return errors.New("claims missing from context") + } + + var req user_account.UserAccountArchiveRequest + if err := web.Decode(r, &req); err != nil { + if _, ok := errors.Cause(err).(*web.Error); !ok { + err = web.NewRequestError(err, http.StatusBadRequest) + } + return web.RespondJsonError(ctx, w, err) + } + + err := user_account.Archive(ctx, claims, u.MasterDB, req, v.Now) + if err != nil { + cause := errors.Cause(err) + switch cause { + case user_account.ErrNotFound: + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusNotFound)) + case user_account.ErrForbidden: + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusForbidden)) + default: + _, ok := cause.(validator.ValidationErrors) + if ok { + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusBadRequest)) + } + + return errors.Wrapf(err, "UserID: %s AccountID: %s User Account: %+v", req.UserID, req.AccountID, &req) + } + } + + return web.RespondJson(ctx, w, nil, http.StatusNoContent) +} + +// Delete godoc +// @Summary Delete user account by user ID and account ID +// @Description Delete removes the specified user account from the system. +// @Tags user +// @Accept json +// @Produce json +// @Security OAuth2Password +// @Param id path string true "UserAccount ID" +// @Success 204 +// @Failure 400 {object} web.ErrorResponse +// @Failure 403 {object} web.ErrorResponse +// @Failure 404 {object} web.ErrorResponse +// @Failure 500 {object} web.ErrorResponse +// @Router /user_accounts [delete] +func (u *UserAccount) Delete(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { + claims, ok := ctx.Value(auth.Key).(auth.Claims) + if !ok { + return errors.New("claims missing from context") + } + + var req user_account.UserAccountDeleteRequest + if err := web.Decode(r, &req); err != nil { + if _, ok := errors.Cause(err).(*web.Error); !ok { + err = web.NewRequestError(err, http.StatusBadRequest) + } + return web.RespondJsonError(ctx, w, err) + } + + err := user_account.Delete(ctx, claims, u.MasterDB, req) + if err != nil { + cause := errors.Cause(err) + switch cause { + case user_account.ErrNotFound: + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusNotFound)) + case user_account.ErrForbidden: + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusForbidden)) + default: + _, ok := cause.(validator.ValidationErrors) + if ok { + return web.RespondJsonError(ctx, w, web.NewRequestError(err, http.StatusBadRequest)) + } + + return errors.Wrapf(err, "UserID: %s AccountID: %s User Account: %+v", req.UserID, req.AccountID, &req) + } + } + + return web.RespondJson(ctx, w, nil, http.StatusNoContent) +} diff --git a/example-project/cmd/web-api/tests/account_test.go b/example-project/cmd/web-api/tests/account_test.go new file mode 100644 index 0000000..7209902 --- /dev/null +++ b/example-project/cmd/web-api/tests/account_test.go @@ -0,0 +1,339 @@ +package tests + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strconv" + "testing" + "time" + + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/mid" + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/account" + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth" + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/tests" + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web" + "github.com/google/go-cmp/cmp" + "github.com/pborman/uuid" +) + +func mockAccount() *account.Account { + req := account.AccountCreateRequest{ + Name: uuid.NewRandom().String(), + Address1: "103 East Main St", + Address2: "Unit 546", + City: "Valdez", + Region: "AK", + Country: "USA", + Zipcode: "99686", + } + + a, err := account.Create(tests.Context(), auth.Claims{}, test.MasterDB, req, time.Now().UTC().AddDate(-1, -1, -1)) + if err != nil { + panic(err) + } + return a +} + +// TestAccount is the entry point for the account endpoints. +func TestAccount(t *testing.T) { + defer tests.Recover(t) + + t.Run("getAccount", getAccount) + t.Run("patchAccount", patchAccount) +} + +// getAccount validates get account by ID endpoint. +func getAccount(t *testing.T) { + + var rtests []requestTest + + forbiddenAccount := mockAccount() + + // Both roles should be able to read the account. + for rn, tr := range roleTests { + acc := tr.SignupResult.Account + + // Test 200. + rtests = append(rtests, requestTest{ + fmt.Sprintf("Role %s 200", rn), + http.MethodGet, + fmt.Sprintf("/v1/accounts/%s", acc.ID), + nil, + tr.Token, + tr.Claims, + http.StatusOK, + nil, + func(treq requestTest, body []byte) bool { + var actual account.AccountResponse + if err := json.Unmarshal(body, &actual); err != nil { + t.Logf("\t\tGot error : %+v", err) + return false + } + + // Add claims to the context so they can be retrieved later. + ctx := context.WithValue(tests.Context(), auth.Key, tr.Claims) + + expectedMap := map[string]interface{}{ + "updated_at": web.NewTimeResponse(ctx, acc.UpdatedAt), + "id": acc.ID, + "address2": acc.Address2, + "region": acc.Region, + "zipcode": acc.Zipcode, + "timezone": acc.Timezone, + "created_at": web.NewTimeResponse(ctx, acc.CreatedAt), + "country": acc.Country, + "billing_user_id": &acc.BillingUserID, + "name": acc.Name, + "address1": acc.Address1, + "city": acc.City, + "status": map[string]interface{}{ + "value": "active", + "title": "Active", + "options": []map[string]interface{}{{"selected": false, "title": "[Active Pending Disabled]", "value": "[active pending disabled]"}}, + }, + "signup_user_id": &acc.SignupUserID, + } + expectedJson, err := json.Marshal(expectedMap) + if err != nil { + t.Logf("\t\tGot error : %+v", err) + return false + } + + var expected account.AccountResponse + if err := json.Unmarshal([]byte(expectedJson), &expected); err != nil { + t.Logf("\t\tGot error : %+v", err) + printResultMap(ctx, body) + return false + } + + if diff := cmp.Diff(actual, expected); diff != "" { + actualJSON, err := json.MarshalIndent(actual, "", " ") + if err != nil { + t.Logf("\t\tGot error : %+v", err) + return false + } + t.Logf("\t\tGot : %s\n", actualJSON) + + expectedJSON, err := json.MarshalIndent(expected, "", " ") + if err != nil { + t.Logf("\t\tGot error : %+v", err) + return false + } + t.Logf("\t\tExpected : %s\n", expectedJSON) + + t.Logf("\t\tDiff : %s\n", diff) + + if len(expectedMap) == 0 { + printResultMap(ctx, body) + } + + return false + } + + return true + }, + }) + + // Test 404. + invalidID := uuid.NewRandom().String() + rtests = append(rtests, requestTest{ + fmt.Sprintf("Role %s 404 w/invalid ID", rn), + http.MethodGet, + fmt.Sprintf("/v1/accounts/%s", invalidID), + nil, + tr.Token, + tr.Claims, + http.StatusNotFound, + web.ErrorResponse{ + Error: fmt.Sprintf("account %s not found: Entity not found", invalidID), + }, + func(treq requestTest, body []byte) bool { + return true + }, + }) + + // Test 404 - Account exists but not allowed. + rtests = append(rtests, requestTest{ + fmt.Sprintf("Role %s 404 w/random account ID", rn), + http.MethodGet, + fmt.Sprintf("/v1/accounts/%s", forbiddenAccount.ID), + nil, + tr.Token, + tr.Claims, + http.StatusNotFound, + web.ErrorResponse{ + Error: fmt.Sprintf("account %s not found: Entity not found", forbiddenAccount.ID), + }, + func(treq requestTest, body []byte) bool { + return true + }, + }) + } + + runRequestTests(t, rtests) +} + +// patchAccount validates update account by ID endpoint. +func patchAccount(t *testing.T) { + + var rtests []requestTest + + // Test update an account + // Admin role: 204 + // User role 403 + for rn, tr := range roleTests { + var expectedStatus int + var expectedErr interface{} + + // Test 204. + if rn == auth.RoleAdmin { + expectedStatus = http.StatusNoContent + } else { + expectedStatus = http.StatusForbidden + expectedErr = web.ErrorResponse{ + Error: mid.ErrForbidden.Error(), + } + } + + newName := rn + uuid.NewRandom().String() + strconv.Itoa(len(rtests)) + rtests = append(rtests, requestTest{ + fmt.Sprintf("Role %s %d", rn, expectedStatus), + http.MethodPatch, + "/v1/accounts", + account.AccountUpdateRequest{ + ID: tr.SignupResult.Account.ID, + Name: &newName, + }, + tr.Token, + tr.Claims, + expectedStatus, + expectedErr, + func(treq requestTest, body []byte) bool { + return true + }, + }) + } + + // Test update an account with invalid data. + // Admin role: 400 + // User role 400 + for rn, tr := range roleTests { + var expectedStatus int + var expectedErr interface{} + + if rn == auth.RoleAdmin { + expectedStatus = http.StatusBadRequest + expectedErr = web.ErrorResponse{ + Error: "field validation error", + Fields: []web.FieldError{ + {Field: "status", Error: "Key: 'AccountUpdateRequest.status' Error:Field validation for 'status' failed on the 'oneof' tag"}, + }, + } + } else { + expectedStatus = http.StatusForbidden + expectedErr = web.ErrorResponse{ + Error: mid.ErrForbidden.Error(), + } + } + + invalidStatus := account.AccountStatus("invalid status") + rtests = append(rtests, requestTest{ + fmt.Sprintf("Role %s %d w/invalid data", rn, expectedStatus), + http.MethodPatch, + "/v1/accounts", + account.AccountUpdateRequest{ + ID: tr.SignupResult.User.ID, + Status: &invalidStatus, + }, + tr.Token, + tr.Claims, + expectedStatus, + expectedErr, + func(treq requestTest, body []byte) bool { + return true + }, + }) + } + + // Test update an account for with an invalid ID. + // Admin role: 403 + // User role 403 + for rn, tr := range roleTests { + var expectedStatus int + var expectedErr interface{} + + // Test 403. + if rn == auth.RoleAdmin { + expectedStatus = http.StatusForbidden + expectedErr = web.ErrorResponse{ + Error: account.ErrForbidden.Error(), + } + } else { + expectedStatus = http.StatusForbidden + expectedErr = web.ErrorResponse{ + Error: mid.ErrForbidden.Error(), + } + } + newName := rn + uuid.NewRandom().String() + strconv.Itoa(len(rtests)) + invalidID := uuid.NewRandom().String() + rtests = append(rtests, requestTest{ + fmt.Sprintf("Role %s %d w/invalid ID", rn, expectedStatus), + http.MethodPatch, + "/v1/accounts", + account.AccountUpdateRequest{ + ID: invalidID, + Name: &newName, + }, + tr.Token, + tr.Claims, + expectedStatus, + expectedErr, + func(treq requestTest, body []byte) bool { + return true + }, + }) + } + + // Test update an account for with random account ID. + // Admin role: 403 + // User role 403 + forbiddenAccount := mockAccount() + for rn, tr := range roleTests { + var expectedStatus int + var expectedErr interface{} + + // Test 403. + if rn == auth.RoleAdmin { + expectedStatus = http.StatusForbidden + expectedErr = web.ErrorResponse{ + Error: account.ErrForbidden.Error(), + } + } else { + expectedStatus = http.StatusForbidden + expectedErr = web.ErrorResponse{ + Error: mid.ErrForbidden.Error(), + } + } + newName := rn+uuid.NewRandom().String()+strconv.Itoa(len(rtests)) + rtests = append(rtests, requestTest{ + fmt.Sprintf("Role %s %d w/random account ID", rn, expectedStatus), + http.MethodPatch, + "/v1/accounts", + account.AccountUpdateRequest{ + ID: forbiddenAccount.ID, + Name: &newName, + }, + tr.Token, + tr.Claims, + expectedStatus, + expectedErr, + func(treq requestTest, body []byte) bool { + return true + }, + }) + } + + runRequestTests(t, rtests) +} diff --git a/example-project/cmd/web-api/tests/signup_test.go b/example-project/cmd/web-api/tests/signup_test.go new file mode 100644 index 0000000..1977057 --- /dev/null +++ b/example-project/cmd/web-api/tests/signup_test.go @@ -0,0 +1,185 @@ +package tests + +import ( + "encoding/json" + "net/http" + "testing" + + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth" + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/tests" + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web" + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/signup" + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/user" + "github.com/google/go-cmp/cmp" + "github.com/pborman/uuid" +) + +func mockSignupRequest() signup.SignupRequest { + return signup.SignupRequest{ + Account: signup.SignupAccount{ + Name: uuid.NewRandom().String(), + Address1: "103 East Main St", + Address2: "Unit 546", + City: "Valdez", + Region: "AK", + Country: "USA", + Zipcode: "99686", + }, + User: signup.SignupUser{ + Name: "Lee Brown", + Email: uuid.NewRandom().String() + "@geeksinthewoods.com", + Password: "akTechFr0n!ier", + PasswordConfirm: "akTechFr0n!ier", + }, + } +} + +// TestSignup is the entry point for the signup +func TestSignup(t *testing.T) { + defer tests.Recover(t) + + t.Run("postSigup", postSigup) +} + +// postSigup validates the signup endpoint. +func postSigup(t *testing.T) { + + var rtests []requestTest + + // Test 201. + // Signup does not require auth, so empty token and claims should result in success. + req1 := mockSignupRequest() + rtests = append(rtests, requestTest{ + "No Authorization Valid", + http.MethodPost, + "/v1/signup", + req1, + user.Token{}, + auth.Claims{}, + http.StatusCreated, + nil, + func(treq requestTest, body []byte) bool { + var actual signup.SignupResponse + if err := json.Unmarshal(body, &actual); err != nil { + t.Logf("\t\tGot error : %+v", err) + return false + } + + ctx := tests.Context() + + req := treq.request.(signup.SignupRequest ) + + expectedMap := map[string]interface{}{ + "user": map[string]interface{}{ + "id": actual.User.ID, + "name": req.User.Name, + "email": req.User.Email, + "timezone": actual.User.Timezone, + "created_at": web.NewTimeResponse(ctx, actual.User.CreatedAt.Value), + "updated_at": web.NewTimeResponse(ctx, actual.User.UpdatedAt.Value), + }, + "account": map[string]interface{}{ + "updated_at": web.NewTimeResponse(ctx, actual.Account.UpdatedAt.Value), + "id": actual.Account.ID, + "address2": req.Account.Address2, + "region": req.Account.Region, + "zipcode": req.Account.Zipcode, + "timezone": actual.Account.Timezone, + "created_at": web.NewTimeResponse(ctx, actual.Account.CreatedAt.Value), + "country": req.Account.Country, + "billing_user_id": &actual.Account.BillingUserID, + "name": req.Account.Name, + "address1": req.Account.Address1, + "city": req.Account.City, + "status": map[string]interface{}{ + "value": "active", + "title": "Active", + "options": []map[string]interface{}{{"selected":false,"title":"[Active Pending Disabled]","value":"[active pending disabled]"}}, + }, + "signup_user_id": &actual.Account.SignupUserID, + }, + } + expectedJson, err := json.Marshal(expectedMap) + if err != nil { + t.Logf("\t\tGot error : %+v", err) + return false + } + + var expected signup.SignupResponse + if err := json.Unmarshal([]byte(expectedJson), &expected); err != nil { + t.Logf("\t\tGot error : %+v", err) + printResultMap(ctx, body) + return false + } + + if diff := cmp.Diff(actual, expected); diff != "" { + actualJSON, err := json.MarshalIndent(actual, "", " ") + if err != nil { + t.Logf("\t\tGot error : %+v", err) + return false + } + t.Logf("\t\tGot : %s\n", actualJSON) + + expectedJSON, err := json.MarshalIndent(expected, "", " ") + if err != nil { + t.Logf("\t\tGot error : %+v", err) + return false + } + t.Logf("\t\tExpected : %s\n", expectedJSON) + + t.Logf("\t\tDiff : %s\n", diff) + + if len(expectedMap) == 0 { + printResultMap(ctx, body) + } + + return false + } + + return true + }, + }) + + // Test 404 w/empty request. + rtests = append(rtests, requestTest{ + "Empty request", + http.MethodPost, + "/v1/signup", + nil, + user.Token{}, + auth.Claims{}, + http.StatusBadRequest, + web.ErrorResponse{ + Error: "decode request body failed: EOF", + }, + func(req requestTest, body []byte) bool { + return true + }, + }) + + // Test 404 w/validation errors. + invalidReq := mockSignupRequest() + invalidReq.User.Email = "" + invalidReq.Account.Name = "" + rtests = append(rtests, requestTest{ + "Invalid request", + http.MethodPost, + "/v1/signup", + invalidReq, + user.Token{}, + auth.Claims{}, + http.StatusBadRequest, + web.ErrorResponse{ + Error: "field validation error", + Fields: []web.FieldError{ + {Field: "name", Error: "Key: 'SignupRequest.account.name' Error:Field validation for 'name' failed on the 'required' tag"}, + {Field: "email", Error: "Key: 'SignupRequest.user.email' Error:Field validation for 'email' failed on the 'required' tag"}, + }, + }, + func(req requestTest, body []byte) bool { + return true + }, + }) + + runRequestTests(t, rtests) +} diff --git a/example-project/cmd/web-api/tests/tests_test.go b/example-project/cmd/web-api/tests/tests_test.go index a83b96b..fd136b6 100644 --- a/example-project/cmd/web-api/tests/tests_test.go +++ b/example-project/cmd/web-api/tests/tests_test.go @@ -1,14 +1,25 @@ package tests import ( - "geeks-accelerator/oss/saas-starter-kit/example-project/internal/account" - "geeks-accelerator/oss/saas-starter-kit/example-project/internal/signup" - "github.com/pborman/uuid" + "bytes" + "context" + "encoding/json" + "fmt" + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web" + "github.com/google/go-cmp/cmp" + "io" "net/http" + "net/http/httptest" "os" + "strings" "testing" "time" + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/account" + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/signup" + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/user_account" + "github.com/iancoleman/strcase" + "github.com/pborman/uuid" "geeks-accelerator/oss/saas-starter-kit/example-project/cmd/web-api/handlers" "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/tests" @@ -23,11 +34,23 @@ type roleTest struct { Token user.Token Claims auth.Claims SignupRequest *signup.SignupRequest - SignupResponse *signup.SignupResponse + SignupResult *signup.SignupResult User *user.User Account *account.Account } +type requestTest struct { + name string + method string + url string + request interface{} + token user.Token + claims auth.Claims + statusCode int + error interface{} + expected func(req requestTest, result []byte) bool +} + var roleTests map[string]roleTest func init() { @@ -92,7 +115,7 @@ func testMain(m *testing.M) int { Token: adminTkn, Claims: adminClaims, SignupRequest: &signupReq, - SignupResponse: signup, + SignupResult: signup, User: signup.User, Account: signup.Account, } @@ -109,6 +132,16 @@ func testMain(m *testing.M) int { panic(err) } + _, err = user_account.Create(tests.Context(), adminClaims, test.MasterDB, user_account.UserAccountCreateRequest{ + UserID: usr.ID, + AccountID: signup.Account.ID, + Roles: []user_account.UserAccountRole{user_account.UserAccountRole_User}, + // Status: use default value + }, now) + if err != nil { + panic(err) + } + userTkn, err := user.Authenticate(tests.Context(), test.MasterDB, authenticator, usr.Email, userReq.Password, expires, now) if err != nil { panic(err) @@ -123,10 +156,130 @@ func testMain(m *testing.M) int { Token: userTkn, Claims: userClaims, SignupRequest: &signupReq, - SignupResponse: signup, + SignupResult: signup, Account: signup.Account, User: usr, } return m.Run() } + + +// runRequestTests helper function for testing endpoints. +func runRequestTests(t *testing.T, rtests []requestTest ) { + + for i, tt := range rtests { + t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name) + { + var req []byte + var rr io.Reader + if tt.request != nil { + var ok bool + req, ok = tt.request.([]byte) + if !ok { + var err error + req, err = json.Marshal(tt.request) + if err != nil { + t.Logf("\t\tGot err : %+v", err) + t.Fatalf("\t%s\tEncode request failed.", tests.Failed) + } + } + rr = bytes.NewReader(req) + } + + r := httptest.NewRequest(tt.method, tt.url , rr) + w := httptest.NewRecorder() + + r.Header.Set("Content-Type", web.MIMEApplicationJSONCharsetUTF8) + if tt.token.AccessToken != "" { + r.Header.Set("Authorization", tt.token.AuthorizationHeader()) + } + + a.ServeHTTP(w, r) + + if w.Code != tt.statusCode { + t.Logf("\t\tRequest : %s\n", string(req)) + t.Logf("\t\tBody : %s\n", w.Body.String()) + t.Fatalf("\t%s\tShould receive a status code of %d for the response : %v", tests.Failed, tt.statusCode, w.Code) + } + t.Logf("\t%s\tReceived valid status code of %d.", tests.Success, w.Code) + + if tt.error != nil { + + + + var actual web.ErrorResponse + if err := json.Unmarshal(w.Body.Bytes(), &actual); err != nil { + t.Logf("\t\tBody : %s\n", w.Body.String()) + t.Logf("\t\tGot error : %+v", err) + t.Fatalf("\t%s\tShould get the expected error.", tests.Failed) + } + + if diff := cmp.Diff(actual, tt.error); diff != "" { + t.Logf("\t\tDiff : %s\n", diff) + t.Fatalf("\t%s\tShould get the expected error.", tests.Failed) + } + } + + if ok := tt.expected(tt, w.Body.Bytes()); !ok { + t.Fatalf("\t%s\tShould get the expected result.", tests.Failed) + } + t.Logf("\t%s\tReceived expected result.", tests.Success) + } + } +} + + +func printResultMap(ctx context.Context, result []byte) { + var m map[string]interface{} + if err := json.Unmarshal(result, &m); err != nil { + panic(err) + } + + fmt.Println(`map[string]interface{}{`) + printResultMapKeys(ctx, m, 1) + fmt.Println(`}`) +} + +func printResultMapKeys(ctx context.Context, m map[string]interface{}, depth int) { + var isEnum bool + if m["value"] != nil && m["title"] != nil && m["options"] != nil { + isEnum = true + } + + for k, kv := range m { + fn := strcase.ToCamel(k) + + switch k { + case "created_at", "updated_at", "archived_at": + pv := fmt.Sprintf("web.NewTimeResponse(ctx, actual.%s)", fn) + fmt.Printf("%s\"%s\": %s,\n", strings.Repeat("\t", depth), k, pv) + continue + } + + if sm, ok := kv.([]map[string]interface{}); ok { + fmt.Printf("%s\"%s\": []map[string]interface{}{\n", strings.Repeat("\t", depth), k) + + for _, smv := range sm { + printResultMapKeys(ctx, smv, depth +1) + } + + fmt.Printf("%s},\n", strings.Repeat("\t", depth)) + } else if sm, ok := kv.(map[string]interface{}); ok { + fmt.Printf("%s\"%s\": map[string]interface{}{\n", strings.Repeat("\t", depth), k) + printResultMapKeys(ctx, sm, depth +1) + fmt.Printf("%s},\n", strings.Repeat("\t", depth)) + } else { + var pv string + if isEnum { + jv, _ := json.Marshal(kv) + pv = string(jv) + } else { + pv = fmt.Sprintf("req.%s", fn) + } + + fmt.Printf("%s\"%s\": %s,\n", strings.Repeat("\t", depth), k, pv) + } + } +} + diff --git a/example-project/cmd/web-api/tests/user_test.go b/example-project/cmd/web-api/tests/user_test.go index 07116d1..7d007be 100644 --- a/example-project/cmd/web-api/tests/user_test.go +++ b/example-project/cmd/web-api/tests/user_test.go @@ -1,578 +1,562 @@ package tests -/* import ( - "bytes" + "context" "encoding/json" + "fmt" + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/mid" "net/http" - "net/http/httptest" - "strings" + "strconv" "testing" + "time" + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/user" "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/tests" "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web" - "geeks-accelerator/oss/saas-starter-kit/example-project/internal/user" "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - "gopkg.in/mgo.v2/bson" + "github.com/pborman/uuid" ) -// TestUsers is the entry point for testing user management functions. -func TestUsers(t *testing.T) { +func mockUser() *user.User { + req := user.UserCreateRequest{ + Name: "Lee Brown", + Email: uuid.NewRandom().String() + "@geeksinthewoods.com", + Password: "akTechFr0n!ier", + PasswordConfirm: "akTechFr0n!ier", + } + + a, err := user.Create(tests.Context(), auth.Claims{}, test.MasterDB, req, time.Now().UTC().AddDate(-1, -1, -1)) + if err != nil { + panic(err) + } + return a +} + +// TestUser is the entry point for the user endpoints. +func TestUser(t *testing.T) { defer tests.Recover(t) - t.Run("getToken401", getToken401) - t.Run("getToken200", getToken200) - t.Run("postUser400", postUser400) - t.Run("postUser401", postUser401) - t.Run("postUser403", postUser403) - t.Run("getUser400", getUser400) - t.Run("getUser403", getUser403) - t.Run("getUser404", getUser404) - t.Run("deleteUser404", deleteUser404) - t.Run("putUser404", putUser404) - t.Run("crudUsers", crudUser) + t.Run("getUser", getUser) + t.Run("createUser", createUser) + t.Run("patchUser", patchUser) + t.Run("patchUserPassword", patchUserPassword) } -// getToken401 ensures an unknown user can't generate a token. -func getToken401(t *testing.T) { - r := httptest.NewRequest("GET", "/v1/users/token", nil) - w := httptest.NewRecorder() +// getUser validates get user by ID endpoint. +func getUser(t *testing.T) { - r.SetBasicAuth("unknown@example.com", "some-password") + var rtests []requestTest - a.ServeHTTP(w, r) + forbiddenUser := mockUser() - t.Log("Given the need to deny tokens to unknown users.") - { - t.Log("\tTest 0:\tWhen fetching a token with an unrecognized email.") - { - if w.Code != http.StatusUnauthorized { - t.Fatalf("\t%s\tShould receive a status code of 401 for the response : %v", tests.Failed, w.Code) + // Both roles should be able to read the user. + for rn, tr := range roleTests { + usr := tr.SignupResult.User + + // Test 200. + rtests = append(rtests, requestTest{ + fmt.Sprintf("Role %s 200", rn), + http.MethodGet, + fmt.Sprintf("/v1/users/%s", usr.ID), + nil, + tr.Token, + tr.Claims, + http.StatusOK, + nil, + func(treq requestTest, body []byte) bool { + var actual user.UserResponse + if err := json.Unmarshal(body, &actual); err != nil { + t.Logf("\t\tGot error : %+v", err) + return false + } + + // Add claims to the context so they can be retrieved later. + ctx := context.WithValue(tests.Context(), auth.Key, tr.Claims) + + expectedMap := map[string]interface{}{ + "updated_at": web.NewTimeResponse(ctx, usr.UpdatedAt), + "id": usr.ID, + "email": usr.Email, + "timezone": usr.Timezone, + "created_at": web.NewTimeResponse(ctx, usr.CreatedAt), + "name": usr.Name, + } + expectedJson, err := json.Marshal(expectedMap) + if err != nil { + t.Logf("\t\tGot error : %+v", err) + return false + } + + var expected user.UserResponse + if err := json.Unmarshal([]byte(expectedJson), &expected); err != nil { + t.Logf("\t\tGot error : %+v", err) + printResultMap(ctx, body) + return false + } + + if diff := cmp.Diff(actual, expected); diff != "" { + actualJSON, err := json.MarshalIndent(actual, "", " ") + if err != nil { + t.Logf("\t\tGot error : %+v", err) + return false + } + t.Logf("\t\tGot : %s\n", actualJSON) + + expectedJSON, err := json.MarshalIndent(expected, "", " ") + if err != nil { + t.Logf("\t\tGot error : %+v", err) + return false + } + t.Logf("\t\tExpected : %s\n", expectedJSON) + + t.Logf("\t\tDiff : %s\n", diff) + + if len(expectedMap) == 0 { + printResultMap(ctx, body) + } + + return false + } + + return true + }, + }) + + // Test 404. + invalidID := uuid.NewRandom().String() + rtests = append(rtests, requestTest{ + fmt.Sprintf("Role %s 404 w/invalid ID", rn), + http.MethodGet, + fmt.Sprintf("/v1/users/%s", invalidID), + nil, + tr.Token, + tr.Claims, + http.StatusNotFound, + web.ErrorResponse{ + Error: fmt.Sprintf("user %s not found: Entity not found", invalidID), + }, + func(treq requestTest, body []byte) bool { + return true + }, + }) + + // Test 404 - User exists but not allowed. + rtests = append(rtests, requestTest{ + fmt.Sprintf("Role %s 404 w/random user ID", rn), + http.MethodGet, + fmt.Sprintf("/v1/users/%s", forbiddenUser.ID), + nil, + tr.Token, + tr.Claims, + http.StatusNotFound, + web.ErrorResponse{ + Error: fmt.Sprintf("user %s not found: Entity not found", forbiddenUser.ID), + }, + func(treq requestTest, body []byte) bool { + return true + }, + }) + } + + runRequestTests(t, rtests) +} + +// createUser validates create user endpoint. +func createUser(t *testing.T) { + + var rtests []requestTest + + // Test create user. + // Admin role: 201 + // User role 403 + for rn, tr := range roleTests { + var expectedStatus int + var expectedErr interface{} + + // Test 201. + if rn == auth.RoleAdmin { + expectedStatus = http.StatusCreated + } else { + expectedStatus = http.StatusForbidden + expectedErr = web.ErrorResponse{ + Error: mid.ErrForbidden.Error(), } - t.Logf("\t%s\tShould receive a status code of 401 for the response.", tests.Success) } - } -} -// getToken200 -func getToken200(t *testing.T) { + rtests = append(rtests, requestTest{ + fmt.Sprintf("Role %s %d", rn, expectedStatus), + http.MethodPost, + "/v1/users", + user.UserCreateRequest{ + Name: "Lee Brown", + Email: uuid.NewRandom().String() + rn + strconv.Itoa(len(rtests))+ "@geeksinthewoods.com", + Password: "akTechFr0n!ier", + PasswordConfirm: "akTechFr0n!ier", + }, + tr.Token, + tr.Claims, + expectedStatus, + expectedErr, + func(treq requestTest, body []byte) bool { + if treq.error != nil { + return true + } - r := httptest.NewRequest("GET", "/v1/users/token", nil) - w := httptest.NewRecorder() + var actual user.UserResponse + if err := json.Unmarshal(body, &actual); err != nil { + t.Logf("\t\tGot error : %+v", err) + return false + } - r.SetBasicAuth("admin@ardanlabs.com", "gophers") + // Add claims to the context so they can be retrieved later. + ctx := context.WithValue(tests.Context(), auth.Key, tr.Claims) - a.ServeHTTP(w, r) + req := treq.request.(user.UserCreateRequest) - t.Log("Given the need to issues tokens to known users.") - { - t.Log("\tTest 0:\tWhen fetching a token with valid credentials.") - { - if w.Code != http.StatusOK { - t.Fatalf("\t%s\tShould receive a status code of 200 for the response : %v", tests.Failed, w.Code) - } - t.Logf("\t%s\tShould receive a status code of 200 for the response.", tests.Success) + expectedMap := map[string]interface{}{ + "updated_at": web.NewTimeResponse(ctx, actual.UpdatedAt.Value), + "id": actual.ID, + "email": req.Email, + "timezone": actual.Timezone, + "created_at": web.NewTimeResponse(ctx, actual.CreatedAt.Value), + "name": req.Name, + } + expectedJson, err := json.Marshal(expectedMap) + if err != nil { + t.Logf("\t\tGot error : %+v", err) + return false + } - var got user.Token - if err := json.NewDecoder(w.Body).Decode(&got); err != nil { - t.Fatalf("\t%s\tShould be able to unmarshal the response : %v", tests.Failed, err) - } - t.Logf("\t%s\tShould be able to unmarshal the response.", tests.Success) + var expected user.UserResponse + if err := json.Unmarshal([]byte(expectedJson), &expected); err != nil { + t.Logf("\t\tGot error : %+v", err) + printResultMap(ctx, body) + return false + } - // TODO(jlw) Should we ensure the token is valid? - } - } -} + if diff := cmp.Diff(actual, expected); diff != "" { + actualJSON, err := json.MarshalIndent(actual, "", " ") + if err != nil { + t.Logf("\t\tGot error : %+v", err) + return false + } + t.Logf("\t\tGot : %s\n", actualJSON) -// postUser400 validates a user can't be created with the endpoint -// unless a valid user document is submitted. -func postUser400(t *testing.T) { - body, err := json.Marshal(&user.NewUser{}) - if err != nil { - t.Fatal(err) + expectedJSON, err := json.MarshalIndent(expected, "", " ") + if err != nil { + t.Logf("\t\tGot error : %+v", err) + return false + } + t.Logf("\t\tExpected : %s\n", expectedJSON) + + t.Logf("\t\tDiff : %s\n", diff) + + if len(expectedMap) == 0 { + printResultMap(ctx, body) + } + + return false + } + + return true + }, + }) } - r := httptest.NewRequest("POST", "/v1/users", bytes.NewBuffer(body)) - w := httptest.NewRecorder() + // Test update a user with invalid data. + // Admin role: 400 + // User role 403 + for rn, tr := range roleTests { + var expectedStatus int + var expectedErr interface{} - r.Header.Set("Authorization", adminAuthorization) - - a.ServeHTTP(w, r) - - t.Log("Given the need to validate a new user can't be created with an invalid document.") - { - t.Log("\tTest 0:\tWhen using an incomplete user value.") - { - if w.Code != http.StatusBadRequest { - t.Fatalf("\t%s\tShould receive a status code of 400 for the response : %v", tests.Failed, w.Code) - } - t.Logf("\t%s\tShould receive a status code of 400 for the response.", tests.Success) - - // Inspect the response. - var got web.ErrorResponse - if err := json.NewDecoder(w.Body).Decode(&got); err != nil { - t.Fatalf("\t%s\tShould be able to unmarshal the response to an error type : %v", tests.Failed, err) - } - t.Logf("\t%s\tShould be able to unmarshal the response to an error type.", tests.Success) - - // Define what we want to see. - want := web.ErrorResponse{ + // Test 201. + if rn == auth.RoleAdmin { + expectedStatus = http.StatusBadRequest + expectedErr = web.ErrorResponse{ Error: "field validation error", Fields: []web.FieldError{ - {Field: "name", Error: "name is a required field"}, - {Field: "email", Error: "email is a required field"}, - {Field: "roles", Error: "roles is a required field"}, - {Field: "password", Error: "password is a required field"}, + {Field: "email", Error: "Key: 'UserCreateRequest.email' Error:Field validation for 'email' failed on the 'email' tag"}, }, } - - // We can't rely on the order of the field errors so they have to be - // sorted. Tell the cmp package how to sort them. - sorter := cmpopts.SortSlices(func(a, b web.FieldError) bool { - return a.Field < b.Field - }) - - if diff := cmp.Diff(want, got, sorter); diff != "" { - t.Fatalf("\t%s\tShould get the expected result. Diff:\n%s", tests.Failed, diff) + } else { + expectedStatus = http.StatusForbidden + expectedErr = web.ErrorResponse{ + Error: mid.ErrForbidden.Error(), } - t.Logf("\t%s\tShould get the expected result.", tests.Success) - } - } -} - -// postUser401 validates a user can't be created unless the calling user is -// authenticated. -func postUser401(t *testing.T) { - body, err := json.Marshal(&user.User{}) - if err != nil { - t.Fatal(err) - } - - r := httptest.NewRequest("POST", "/v1/users", bytes.NewBuffer(body)) - w := httptest.NewRecorder() - - r.Header.Set("Authorization", userAuthorization) - - a.ServeHTTP(w, r) - - t.Log("Given the need to validate a new user can't be created with an invalid document.") - { - t.Log("\tTest 0:\tWhen using an incomplete user value.") - { - if w.Code != http.StatusForbidden { - t.Fatalf("\t%s\tShould receive a status code of 403 for the response : %v", tests.Failed, w.Code) - } - t.Logf("\t%s\tShould receive a status code of 403 for the response.", tests.Success) - } - } -} - -// postUser403 validates a user can't be created unless the calling user is -// an admin user. Regular users can't do this. -func postUser403(t *testing.T) { - body, err := json.Marshal(&user.User{}) - if err != nil { - t.Fatal(err) - } - - r := httptest.NewRequest("POST", "/v1/users", bytes.NewBuffer(body)) - w := httptest.NewRecorder() - - // Not setting the Authorization header - - a.ServeHTTP(w, r) - - t.Log("Given the need to validate a new user can't be created with an invalid document.") - { - t.Log("\tTest 0:\tWhen using an incomplete user value.") - { - if w.Code != http.StatusUnauthorized { - t.Fatalf("\t%s\tShould receive a status code of 401 for the response : %v", tests.Failed, w.Code) - } - t.Logf("\t%s\tShould receive a status code of 401 for the response.", tests.Success) - } - } -} - -// getUser400 validates a user request for a malformed userid. -func getUser400(t *testing.T) { - id := "12345" - - r := httptest.NewRequest("GET", "/v1/users/"+id, nil) - w := httptest.NewRecorder() - - r.Header.Set("Authorization", adminAuthorization) - - a.ServeHTTP(w, r) - - t.Log("Given the need to validate getting a user with a malformed userid.") - { - t.Logf("\tTest 0:\tWhen using the new user %s.", id) - { - if w.Code != http.StatusBadRequest { - t.Fatalf("\t%s\tShould receive a status code of 400 for the response : %v", tests.Failed, w.Code) - } - t.Logf("\t%s\tShould receive a status code of 400 for the response.", tests.Success) - - recv := w.Body.String() - resp := `{"error":"ID is not in its proper form"}` - if resp != recv { - t.Log("Got :", recv) - t.Log("Want:", resp) - t.Fatalf("\t%s\tShould get the expected result.", tests.Failed) - } - t.Logf("\t%s\tShould get the expected result.", tests.Success) - } - } -} - -// getUser403 validates a regular user can't fetch anyone but themselves -func getUser403(t *testing.T) { - t.Log("Given the need to validate regular users can't fetch other users.") - { - t.Logf("\tTest 0:\tWhen fetching the admin user as a regular user.") - { - r := httptest.NewRequest("GET", "/v1/users/"+adminID, nil) - w := httptest.NewRecorder() - - r.Header.Set("Authorization", userAuthorization) - - a.ServeHTTP(w, r) - - if w.Code != http.StatusForbidden { - t.Fatalf("\t%s\tShould receive a status code of 403 for the response : %v", tests.Failed, w.Code) - } - t.Logf("\t%s\tShould receive a status code of 403 for the response.", tests.Success) - - recv := w.Body.String() - resp := `{"error":"Attempted action is not allowed"}` - if resp != recv { - t.Log("Got :", recv) - t.Log("Want:", resp) - t.Fatalf("\t%s\tShould get the expected result.", tests.Failed) - } - t.Logf("\t%s\tShould get the expected result.", tests.Success) } - t.Logf("\tTest 1:\tWhen fetching the user as a themselves.") - { + rtests = append(rtests, requestTest{ + fmt.Sprintf("Role %s %d w/invalid data", rn, expectedStatus), + http.MethodPost, + "/v1/users", + user.UserCreateRequest{ + Name: "Lee Brown", + Email: "invalid email address", + Password: "akTechFr0n!ier", + PasswordConfirm: "akTechFr0n!ier", + }, + tr.Token, + tr.Claims, + expectedStatus, + expectedErr, + func(treq requestTest, body []byte) bool { + return true + }, + }) + } - r := httptest.NewRequest("GET", "/v1/users/"+userID, nil) - w := httptest.NewRecorder() + runRequestTests(t, rtests) +} - r.Header.Set("Authorization", userAuthorization) +// patchUser validates update user by ID endpoint. +func patchUser(t *testing.T) { - a.ServeHTTP(w, r) - if w.Code != http.StatusOK { - t.Fatalf("\t%s\tShould receive a status code of 200 for the response : %v", tests.Failed, w.Code) - } - t.Logf("\t%s\tShould receive a status code of 200 for the response.", tests.Success) + var rtests []requestTest + + // Test update a user + // Admin role: 204 + // User role 204 - user ID matches claims so OK + for rn, tr := range roleTests { + expectedStatus := http.StatusNoContent + newName := rn + uuid.NewRandom().String() + strconv.Itoa(len(rtests)) + rtests = append(rtests, requestTest{ + fmt.Sprintf("Role %s %d", rn, expectedStatus), + http.MethodPatch, + "/v1/users", + user.UserUpdateRequest{ + ID: tr.SignupResult.User.ID, + Name: &newName, + }, + tr.Token, + tr.Claims, + expectedStatus, + nil, + func(treq requestTest, body []byte) bool { + return true + }, + }) + } + + // Test update a user with invalid data. + // Admin role: 400 + // User role 400 + for rn, tr := range roleTests { + expectedStatus := http.StatusBadRequest + expectedErr := web.ErrorResponse{ + Error: "field validation error", + Fields: []web.FieldError{ + {Field: "email", Error: "Key: 'UserUpdateRequest.email' Error:Field validation for 'email' failed on the 'email' tag"}, + }, } + + invalidEmail := "invalid email address" + rtests = append(rtests, requestTest{ + fmt.Sprintf("Role %s %d w/invalid data", rn, expectedStatus), + http.MethodPatch, + "/v1/users", + user.UserUpdateRequest{ + ID: tr.SignupResult.User.ID, + Email: &invalidEmail, + }, + tr.Token, + tr.Claims, + expectedStatus, + expectedErr, + func(treq requestTest, body []byte) bool { + return true + }, + }) } -} -// getUser404 validates a user request for a user that does not exist with the endpoint. -func getUser404(t *testing.T) { - id := bson.NewObjectId().Hex() + // Test update a user for with an invalid ID. + // Admin role: 403 + // User role 403 + for rn, tr := range roleTests { - r := httptest.NewRequest("GET", "/v1/users/"+id, nil) - w := httptest.NewRecorder() - - r.Header.Set("Authorization", adminAuthorization) - - a.ServeHTTP(w, r) - - t.Log("Given the need to validate getting a user with an unknown id.") - { - t.Logf("\tTest 0:\tWhen using the new user %s.", id) - { - if w.Code != http.StatusNotFound { - t.Fatalf("\t%s\tShould receive a status code of 404 for the response : %v", tests.Failed, w.Code) - } - t.Logf("\t%s\tShould receive a status code of 404 for the response.", tests.Success) - - recv := w.Body.String() - resp := "Entity not found" - if !strings.Contains(recv, resp) { - t.Log("Got :", recv) - t.Log("Want:", resp) - t.Fatalf("\t%s\tShould get the expected result.", tests.Failed) - } - t.Logf("\t%s\tShould get the expected result.", tests.Success) + expectedStatus := http.StatusForbidden + expectedErr := web.ErrorResponse{ + Error: user.ErrForbidden.Error(), } + + newName := rn + uuid.NewRandom().String() + strconv.Itoa(len(rtests)) + invalidID := uuid.NewRandom().String() + rtests = append(rtests, requestTest{ + fmt.Sprintf("Role %s %d w/invalid ID", rn, expectedStatus), + http.MethodPatch, + "/v1/users", + user.UserUpdateRequest{ + ID: invalidID, + Name: &newName, + }, + tr.Token, + tr.Claims, + expectedStatus, + expectedErr, + func(treq requestTest, body []byte) bool { + return true + }, + }) } -} -// deleteUser404 validates deleting a user that does not exist. -func deleteUser404(t *testing.T) { - id := bson.NewObjectId().Hex() + // Test update a user for with random user ID. + // Admin role: 403 + // User role 403 + forbiddenUser := mockUser() + for rn, tr := range roleTests { - r := httptest.NewRequest("DELETE", "/v1/users/"+id, nil) - w := httptest.NewRecorder() - - r.Header.Set("Authorization", adminAuthorization) - - a.ServeHTTP(w, r) - - t.Log("Given the need to validate deleting a user that does not exist.") - { - t.Logf("\tTest 0:\tWhen using the new user %s.", id) - { - if w.Code != http.StatusNotFound { - t.Fatalf("\t%s\tShould receive a status code of 404 for the response : %v", tests.Failed, w.Code) - } - t.Logf("\t%s\tShould receive a status code of 404 for the response.", tests.Success) - - recv := w.Body.String() - resp := "Entity not found" - if !strings.Contains(recv, resp) { - t.Log("Got :", recv) - t.Log("Want:", resp) - t.Fatalf("\t%s\tShould get the expected result.", tests.Failed) - } - t.Logf("\t%s\tShould get the expected result.", tests.Success) + expectedStatus := http.StatusForbidden + expectedErr := web.ErrorResponse{ + Error: user.ErrForbidden.Error(), } + + newName := rn+uuid.NewRandom().String()+strconv.Itoa(len(rtests)) + rtests = append(rtests, requestTest{ + fmt.Sprintf("Role %s %d w/random user ID", rn, expectedStatus), + http.MethodPatch, + "/v1/users", + user.UserUpdateRequest{ + ID: forbiddenUser.ID, + Name: &newName, + }, + tr.Token, + tr.Claims, + expectedStatus, + expectedErr, + func(treq requestTest, body []byte) bool { + return true + }, + }) } + + runRequestTests(t, rtests) } -// putUser404 validates updating a user that does not exist. -func putUser404(t *testing.T) { - u := user.UpdateUser{ - Name: tests.StringPointer("Doesn't Exist"), +// patchUserPassword validates update user password by ID endpoint. +func patchUserPassword(t *testing.T) { + + var rtests []requestTest + + // Test update a user + // Admin role: 204 + // User role 204 - user ID matches claims so OK + for rn, tr := range roleTests { + expectedStatus := http.StatusNoContent + newPass := uuid.NewRandom().String() + rtests = append(rtests, requestTest{ + fmt.Sprintf("Role %s %d", rn, expectedStatus), + http.MethodPatch, + "/v1/users/password", + user.UserUpdatePasswordRequest{ + ID: tr.SignupResult.User.ID, + Password: newPass, + PasswordConfirm: newPass, + }, + tr.Token, + tr.Claims, + expectedStatus, + nil, + func(treq requestTest, body []byte) bool { + return true + }, + }) } - id := bson.NewObjectId().Hex() - - body, err := json.Marshal(&u) - if err != nil { - t.Fatal(err) - } - - r := httptest.NewRequest("PUT", "/v1/users/"+id, bytes.NewBuffer(body)) - w := httptest.NewRecorder() - - r.Header.Set("Authorization", adminAuthorization) - - a.ServeHTTP(w, r) - - t.Log("Given the need to validate updating a user that does not exist.") - { - t.Logf("\tTest 0:\tWhen using the new user %s.", id) - { - if w.Code != http.StatusNotFound { - t.Fatalf("\t%s\tShould receive a status code of 404 for the response : %v", tests.Failed, w.Code) - } - t.Logf("\t%s\tShould receive a status code of 404 for the response.", tests.Success) - - recv := w.Body.String() - resp := "Entity not found" - if !strings.Contains(recv, resp) { - t.Log("Got :", recv) - t.Log("Want:", resp) - t.Fatalf("\t%s\tShould get the expected result.", tests.Failed) - } - t.Logf("\t%s\tShould get the expected result.", tests.Success) + // Test update a user password with invalid data. + // Admin role: 400 + // User role 400 + for rn, tr := range roleTests { + expectedStatus := http.StatusBadRequest + expectedErr := web.ErrorResponse{ + Error: "field validation error", + Fields: []web.FieldError{ + {Field: "password_confirm", Error: "Key: 'UserUpdatePasswordRequest.password_confirm' Error:Field validation for 'password_confirm' failed on the 'eqfield' tag"}, + }, } - } -} -// crudUser performs a complete test of CRUD against the api. -func crudUser(t *testing.T) { - nu := postUser201(t) - defer deleteUser204(t, nu.ID.Hex()) - - getUser200(t, nu.ID.Hex()) - putUser204(t, nu.ID.Hex()) - putUser403(t, nu.ID.Hex()) -} - -// postUser201 validates a user can be created with the endpoint. -func postUser201(t *testing.T) user.User { - nu := user.NewUser{ - Name: "Bill Kennedy", - Email: "bill@ardanlabs.com", - Roles: []string{auth.RoleAdmin}, - Password: "gophers", - PasswordConfirm: "gophers", + newPass := uuid.NewRandom().String() + rtests = append(rtests, requestTest{ + fmt.Sprintf("Role %s %d w/invalid data", rn, expectedStatus), + http.MethodPatch, + "/v1/users/password", + user.UserUpdatePasswordRequest{ + ID: tr.SignupResult.User.ID, + Password: newPass, + PasswordConfirm: "different", + }, + tr.Token, + tr.Claims, + expectedStatus, + expectedErr, + func(treq requestTest, body []byte) bool { + return true + }, + }) } - body, err := json.Marshal(&nu) - if err != nil { - t.Fatal(err) - } + // Test update a user password for with an invalid ID. + // Admin role: 403 + // User role 403 + for rn, tr := range roleTests { - r := httptest.NewRequest("POST", "/v1/users", bytes.NewBuffer(body)) - w := httptest.NewRecorder() - - r.Header.Set("Authorization", adminAuthorization) - - a.ServeHTTP(w, r) - - // u is the value we will return. - var u user.User - - t.Log("Given the need to create a new user with the users endpoint.") - { - t.Log("\tTest 0:\tWhen using the declared user value.") - { - if w.Code != http.StatusCreated { - t.Fatalf("\t%s\tShould receive a status code of 201 for the response : %v", tests.Failed, w.Code) - } - t.Logf("\t%s\tShould receive a status code of 201 for the response.", tests.Success) - - if err := json.NewDecoder(w.Body).Decode(&u); err != nil { - t.Fatalf("\t%s\tShould be able to unmarshal the response : %v", tests.Failed, err) - } - - // Define what we wanted to receive. We will just trust the generated - // fields like ID and Dates so we copy u. - want := u - want.Name = "Bill Kennedy" - want.Email = "bill@ardanlabs.com" - want.Roles = []string{auth.RoleAdmin} - - if diff := cmp.Diff(want, u); diff != "" { - t.Fatalf("\t%s\tShould get the expected result. Diff:\n%s", tests.Failed, diff) - } - t.Logf("\t%s\tShould get the expected result.", tests.Success) + expectedStatus := http.StatusForbidden + expectedErr := web.ErrorResponse{ + Error: user.ErrForbidden.Error(), } + + newPass := uuid.NewRandom().String() + invalidID := uuid.NewRandom().String() + rtests = append(rtests, requestTest{ + fmt.Sprintf("Role %s %d w/invalid ID", rn, expectedStatus), + http.MethodPatch, + "/v1/users/password", + user.UserUpdatePasswordRequest{ + ID: invalidID, + Password: newPass, + PasswordConfirm: newPass, + }, + tr.Token, + tr.Claims, + expectedStatus, + expectedErr, + func(treq requestTest, body []byte) bool { + return true + }, + }) } - return u -} + // Test update a user password for with random user ID. + // Admin role: 403 + // User role 403 + forbiddenUser := mockUser() + for rn, tr := range roleTests { -// deleteUser200 validates deleting a user that does exist. -func deleteUser204(t *testing.T, id string) { - r := httptest.NewRequest("DELETE", "/v1/users/"+id, nil) - w := httptest.NewRecorder() - - r.Header.Set("Authorization", adminAuthorization) - - a.ServeHTTP(w, r) - - t.Log("Given the need to validate deleting a user that does exist.") - { - t.Logf("\tTest 0:\tWhen using the new user %s.", id) - { - if w.Code != http.StatusNoContent { - t.Fatalf("\t%s\tShould receive a status code of 204 for the response : %v", tests.Failed, w.Code) - } - t.Logf("\t%s\tShould receive a status code of 204 for the response.", tests.Success) + expectedStatus := http.StatusForbidden + expectedErr := web.ErrorResponse{ + Error: user.ErrForbidden.Error(), } + + newPass := uuid.NewRandom().String() + rtests = append(rtests, requestTest{ + fmt.Sprintf("Role %s %d w/random user ID", rn, expectedStatus), + http.MethodPatch, + "/v1/users/password", + user.UserUpdatePasswordRequest{ + ID: forbiddenUser.ID, + Password: newPass, + PasswordConfirm: newPass, + }, + tr.Token, + tr.Claims, + expectedStatus, + expectedErr, + func(treq requestTest, body []byte) bool { + return true + }, + }) } + + runRequestTests(t, rtests) } -// getUser200 validates a user request for an existing userid. -func getUser200(t *testing.T, id string) { - r := httptest.NewRequest("GET", "/v1/users/"+id, nil) - w := httptest.NewRecorder() - - r.Header.Set("Authorization", adminAuthorization) - - a.ServeHTTP(w, r) - - t.Log("Given the need to validate getting a user that exsits.") - { - t.Logf("\tTest 0:\tWhen using the new user %s.", id) - { - if w.Code != http.StatusOK { - t.Fatalf("\t%s\tShould receive a status code of 200 for the response : %v", tests.Failed, w.Code) - } - t.Logf("\t%s\tShould receive a status code of 200 for the response.", tests.Success) - - var u user.User - if err := json.NewDecoder(w.Body).Decode(&u); err != nil { - t.Fatalf("\t%s\tShould be able to unmarshal the response : %v", tests.Failed, err) - } - - // Define what we wanted to receive. We will just trust the generated - // fields like Dates so we copy p. - want := u - want.ID = bson.ObjectIdHex(id) - want.Name = "Bill Kennedy" - want.Email = "bill@ardanlabs.com" - want.Roles = []string{auth.RoleAdmin} - - if diff := cmp.Diff(want, u); diff != "" { - t.Fatalf("\t%s\tShould get the expected result. Diff:\n%s", tests.Failed, diff) - } - t.Logf("\t%s\tShould get the expected result.", tests.Success) - } - } -} - -// putUser204 validates updating a user that does exist. -func putUser204(t *testing.T, id string) { - body := `{"name": "Jacob Walker"}` - - r := httptest.NewRequest("PUT", "/v1/users/"+id, strings.NewReader(body)) - w := httptest.NewRecorder() - - r.Header.Set("Authorization", adminAuthorization) - - a.ServeHTTP(w, r) - - t.Log("Given the need to update a user with the users endpoint.") - { - t.Log("\tTest 0:\tWhen using the modified user value.") - { - if w.Code != http.StatusNoContent { - t.Fatalf("\t%s\tShould receive a status code of 204 for the response : %v", tests.Failed, w.Code) - } - t.Logf("\t%s\tShould receive a status code of 204 for the response.", tests.Success) - - r = httptest.NewRequest("GET", "/v1/users/"+id, nil) - w = httptest.NewRecorder() - - r.Header.Set("Authorization", adminAuthorization) - - a.ServeHTTP(w, r) - - if w.Code != http.StatusOK { - t.Fatalf("\t%s\tShould receive a status code of 200 for the retrieve : %v", tests.Failed, w.Code) - } - t.Logf("\t%s\tShould receive a status code of 200 for the retrieve.", tests.Success) - - var ru user.User - if err := json.NewDecoder(w.Body).Decode(&ru); err != nil { - t.Fatalf("\t%s\tShould be able to unmarshal the response : %v", tests.Failed, err) - } - - if ru.Name != "Jacob Walker" { - t.Fatalf("\t%s\tShould see an updated Name : got %q want %q", tests.Failed, ru.Name, "Jacob Walker") - } - t.Logf("\t%s\tShould see an updated Name.", tests.Success) - - if ru.Email != "bill@ardanlabs.com" { - t.Fatalf("\t%s\tShould not affect other fields like Email : got %q want %q", tests.Failed, ru.Email, "bill@ardanlabs.com") - } - t.Logf("\t%s\tShould not affect other fields like Email.", tests.Success) - } - } -} - -// putUser403 validates that a user can't modify users unless they are an admin. -func putUser403(t *testing.T, id string) { - body := `{"name": "Anna Walker"}` - - r := httptest.NewRequest("PUT", "/v1/users/"+id, strings.NewReader(body)) - w := httptest.NewRecorder() - - r.Header.Set("Authorization", userAuthorization) - - a.ServeHTTP(w, r) - - t.Log("Given the need to update a user with the users endpoint.") - { - t.Log("\tTest 0:\tWhen a non-admin user makes a request") - { - if w.Code != http.StatusForbidden { - t.Fatalf("\t%s\tShould receive a status code of 403 for the response : %v", tests.Failed, w.Code) - } - t.Logf("\t%s\tShould receive a status code of 403 for the response.", tests.Success) - } - } -} -*/ diff --git a/example-project/internal/account/account.go b/example-project/internal/account/account.go index a1bc83c..6f8b510 100644 --- a/example-project/internal/account/account.go +++ b/example-project/internal/account/account.go @@ -3,6 +3,7 @@ package account import ( "context" "database/sql" + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web" "time" "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth" @@ -25,9 +26,6 @@ var ( // ErrNotFound abstracts the mgo not found error. ErrNotFound = errors.New("Entity not found") - // ErrInvalidID occurs when an ID is not in a valid form. - ErrInvalidID = errors.New("ID is not in its proper form") - // ErrForbidden occurs when a user tries to do something that is forbidden to them according to our access control policies. ErrForbidden = errors.New("Attempted action is not allowed") ) @@ -243,6 +241,7 @@ func UniqueName(ctx context.Context, dbConn *sqlx.DB, name, accountId string) (b var existingId string err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&existingId) + if err != nil && err != sql.ErrNoRows { err = errors.Wrapf(err, "query - %s", query.String()) return false, err @@ -261,8 +260,6 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Create") defer span.Finish() - v := validator.New() - // Validation email address is unique in the database. uniq, err := UniqueName(ctx, dbConn, req.Name, "") if err != nil { @@ -274,6 +271,8 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun } return uniq } + + v := web.NewValidator() v.RegisterValidation("unique", f) // Validate the request. @@ -352,11 +351,11 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string, i query.Where(query.Equal("id", id)) res, err := find(ctx, claims, dbConn, query, []interface{}{}, includedArchived) - if err != nil { - return nil, err - } else if res == nil || len(res) == 0 { + if res == nil || len(res) == 0 { err = errors.WithMessagef(ErrNotFound, "account %s not found", id) return nil, err + } else if err != nil { + return nil, err } u := res[0] @@ -368,7 +367,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Update") defer span.Finish() - v := validator.New() + v := web.NewValidator() // Validation name is unique in the database. if req.Name != nil { @@ -380,6 +379,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun if fl.Field().String() == "invalid" { return false } + return uniq } v.RegisterValidation("unique", f) @@ -495,7 +495,8 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accou defer span.Finish() // Validate the request. - err := validator.New().Struct(req) + v := web.NewValidator() + err := v.Struct(req) if err != nil { return err } @@ -573,7 +574,8 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, accountID } // Validate the request. - err := validator.New().Struct(req) + v := web.NewValidator() + err := v.Struct(req) if err != nil { return err } diff --git a/example-project/internal/mid/auth.go b/example-project/internal/mid/auth.go index 1b2d215..538d59a 100644 --- a/example-project/internal/mid/auth.go +++ b/example-project/internal/mid/auth.go @@ -29,25 +29,36 @@ func Authenticate(authenticator *auth.Authenticator) web.Middleware { span, ctx := tracer.StartSpanFromContext(ctx, "internal.mid.Authenticate") defer span.Finish() - authHdr := r.Header.Get("Authorization") - if authHdr == "" { - err := errors.New("missing Authorization header") - return web.NewRequestError(err, http.StatusUnauthorized) + m := func() error { + authHdr := r.Header.Get("Authorization") + if authHdr == "" { + err := errors.New("missing Authorization header") + return web.NewRequestError(err, http.StatusUnauthorized) + } + + tknStr, err := parseAuthHeader(authHdr) + if err != nil { + return web.NewRequestError(err, http.StatusUnauthorized) + } + + claims, err := authenticator.ParseClaims(tknStr) + if err != nil { + return web.NewRequestError(err, http.StatusUnauthorized) + } + + // Add claims to the context so they can be retrieved later. + ctx = context.WithValue(ctx, auth.Key, claims) + + return nil } - tknStr, err := parseAuthHeader(authHdr) - if err != nil { - return web.NewRequestError(err, http.StatusUnauthorized) + if err := m(); err != nil { + if web.RequestIsJson(r) { + return web.RespondJsonError(ctx, w, err) + } + return err } - claims, err := authenticator.ParseClaims(tknStr) - if err != nil { - return web.NewRequestError(err, http.StatusUnauthorized) - } - - // Add claims to the context so they can be retrieved later. - ctx = context.WithValue(ctx, auth.Key, claims) - return after(ctx, w, r, params) } @@ -68,14 +79,25 @@ func HasRole(roles ...string) web.Middleware { span, ctx := tracer.StartSpanFromContext(ctx, "internal.mid.HasRole") defer span.Finish() - claims, ok := ctx.Value(auth.Key).(auth.Claims) - if !ok { - // TODO(jlw) should this be a web.Shutdown? - return errors.New("claims missing from context: HasRole called without/before Authenticate") + m := func() error { + claims, ok := ctx.Value(auth.Key).(auth.Claims) + if !ok { + // TODO(jlw) should this be a web.Shutdown? + return errors.New("claims missing from context: HasRole called without/before Authenticate") + } + + if !claims.HasRole(roles...) { + return ErrForbidden + } + + return nil } - if !claims.HasRole(roles...) { - return ErrForbidden + if err := m(); err != nil { + if web.RequestIsJson(r) { + return web.RespondJsonError(ctx, w, err) + } + return err } return after(ctx, w, r, params) @@ -97,3 +119,6 @@ func parseAuthHeader(bearerStr string) (string, error) { return split[1], nil } + + + diff --git a/example-project/internal/mid/errors.go b/example-project/internal/mid/errors.go index d572b35..ed7cd18 100644 --- a/example-project/internal/mid/errors.go +++ b/example-project/internal/mid/errors.go @@ -28,8 +28,14 @@ func Errors(log *log.Logger) web.Middleware { log.Printf("%d : ERROR : %+v", span.Context().TraceID(), err) // Respond to the error. - if err := web.RespondError(ctx, w, err); err != nil { - return err + if web.RequestIsJson(r) { + if err := web.RespondJsonError(ctx, w, err); err != nil { + return err + } + } else { + if err := web.RespondError(ctx, w, err); err != nil { + return err + } } // If we receive the shutdown err we need to return it diff --git a/example-project/internal/platform/web/errors.go b/example-project/internal/platform/web/errors.go index 08044c6..23ffa51 100644 --- a/example-project/internal/platform/web/errors.go +++ b/example-project/internal/platform/web/errors.go @@ -2,6 +2,8 @@ package web import ( "github.com/pkg/errors" + "gopkg.in/go-playground/validator.v9" + "net/http" ) // FieldError is used to indicate an error with a specific request field. @@ -27,6 +29,12 @@ type Error struct { // NewRequestError wraps a provided error with an HTTP status code. This // function should be used when handlers encounter expected errors. func NewRequestError(err error, status int) error { + + // if its a validation error then + if verr, ok := NewValidationError(err); ok { + return verr + } + return &Error{err, status, nil} } @@ -52,6 +60,35 @@ func NewShutdownError(message string) error { return &shutdown{message} } +// NewValidationError checks the error for validation errors and formats the correct response. +func NewValidationError(err error) (error, bool) { + + // Use a type assertion to get the real error value. + verrors, ok := errors.Cause(err).(validator.ValidationErrors) + if !ok { + return err, false + } + + // lang controls the language of the error messages. You could look at the + // Accept-Language header if you intend to support multiple languages. + lang, _ := translator.GetTranslator("en") + + var fields []FieldError + for _, verror := range verrors { + field := FieldError{ + Field: verror.Field(), + Error: verror.Translate(lang), + } + fields = append(fields, field) + } + + return &Error{ + Err: errors.New("field validation error"), + Status: http.StatusBadRequest, + Fields: fields, + }, true +} + // IsShutdown checks to see if the shutdown error is contained // in the specified error value. func IsShutdown(err error) bool { diff --git a/example-project/internal/platform/web/request.go b/example-project/internal/platform/web/request.go index 633ca07..41fb91d 100644 --- a/example-project/internal/platform/web/request.go +++ b/example-project/internal/platform/web/request.go @@ -36,7 +36,21 @@ func init() { en_translations.RegisterDefaultTranslations(validate, lang) // Use JSON tag names for errors instead of Go struct names. - validate.RegisterTagNameFunc(func(fld reflect.StructField) string { + validate = NewValidator() + + // Empty method that can be overwritten in business logic packages to prevent web.Decode from failing. + f := func(fl validator.FieldLevel) bool { + return true + } + validate.RegisterValidation("unique", f) +} + +// NewValidator inits a new validator with custom settings. +func NewValidator() *validator.Validate { + var v = validator.New() + + // Use JSON tag names for errors instead of Go struct names. + v.RegisterTagNameFunc(func(fld reflect.StructField) string { name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0] if name == "-" { return "" @@ -44,10 +58,7 @@ func init() { return name }) - f := func(fl validator.FieldLevel) bool { - return true - } - validate.RegisterValidation("unique", f) + return v } // Decode reads the body of an HTTP request looking for a JSON document. The @@ -56,45 +67,24 @@ func init() { // If the provided value is a struct then it is checked for validation tags. func Decode(r *http.Request, val interface{}) error { - if r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodDelete { + if r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodPatch || r.Method == http.MethodDelete { decoder := json.NewDecoder(r.Body) decoder.DisallowUnknownFields() if err := decoder.Decode(val); err != nil { + err = errors.Wrap(err, "decode request body failed") return NewRequestError(err, http.StatusBadRequest) } } else { decoder := schema.NewDecoder() if err := decoder.Decode(val, r.URL.Query()); err != nil { + err = errors.Wrap(err, "decode request query failed") return NewRequestError(err, http.StatusBadRequest) } } if err := validate.Struct(val); err != nil { - - // Use a type assertion to get the real error value. - verrors, ok := err.(validator.ValidationErrors) - if !ok { - return err - } - - // lang controls the language of the error messages. You could look at the - // Accept-Language header if you intend to support multiple languages. - lang, _ := translator.GetTranslator("en") - - var fields []FieldError - for _, verror := range verrors { - field := FieldError{ - Field: verror.Field(), - Error: verror.Translate(lang), - } - fields = append(fields, field) - } - - return &Error{ - Err: errors.New("field validation error"), - Status: http.StatusBadRequest, - Fields: fields, - } + verr, _ := NewValidationError(err) + return verr } return nil @@ -139,3 +129,30 @@ func ExtractWhereArgs(where string) (string, []interface{}, error) { return where, vals, nil } + + +func RequestIsJson(r *http.Request) bool { + if r == nil { + return false + } + if v := r.Header.Get("Content-type"); v != "" { + for _, hv := range strings.Split(v, ";") { + if strings.ToLower(hv) == "application/json" { + return true + } + } + } + + if v := r.URL.Query().Get("ResponseFormat"); v != "" { + if strings.ToLower(v) == "json" { + return true + } + } + + if strings.HasSuffix(r.URL.Path, ".json") { + return true + } + + return false +} + diff --git a/example-project/internal/platform/web/response.go b/example-project/internal/platform/web/response.go index 77862cc..8964c42 100644 --- a/example-project/internal/platform/web/response.go +++ b/example-project/internal/platform/web/response.go @@ -32,7 +32,12 @@ func RespondJsonError(ctx context.Context, w http.ResponseWriter, err error) err // If the error was of the type *Error, the handler has // a specific status code and error to return. - if webErr, ok := errors.Cause(err).(*Error); ok { + webErr, ok := errors.Cause(err).(*Error) + if !ok { + webErr, ok = err.(*Error) + } + + if ok { er := ErrorResponse{ Error: webErr.Err.Error(), Fields: webErr.Fields, diff --git a/example-project/internal/project/project.go b/example-project/internal/project/project.go index c918b21..22603de 100644 --- a/example-project/internal/project/project.go +++ b/example-project/internal/project/project.go @@ -4,12 +4,12 @@ import ( "context" "database/sql" "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth" + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web" "github.com/huandu/go-sqlbuilder" "github.com/jmoiron/sqlx" "github.com/pborman/uuid" "github.com/pkg/errors" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" - "gopkg.in/go-playground/validator.v9" "time" ) @@ -21,10 +21,9 @@ const ( var ( // ErrNotFound abstracts the postgres not found error. ErrNotFound = errors.New("Entity not found") + // ErrForbidden occurs when a user tries to do something that is forbidden to them according to our access control policies. ErrForbidden = errors.New("Attempted action is not allowed") - // ErrInvalidID occurs when an ID is not in a valid form. - ErrInvalidID = errors.New("ID is not in its proper form") ) // projectMapColumns is the list of columns needed for mapRowsToProject @@ -193,14 +192,16 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string, includedArchived bool) (*Project, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Read") defer span.Finish() + // Filter base select query by id query := selectQuery() query.Where(query.Equal("id", id)) + res, err := find(ctx, claims, dbConn, query, []interface{}{}, includedArchived) - if err != nil { + if res == nil || len(res) == 0 { + err = errors.WithMessagef(ErrNotFound, "account %s not found", id) return nil, err - } else if res == nil || len(res) == 0 { - err = errors.WithMessagef(ErrNotFound, "project %s not found", id) + } else if err != nil { return nil, err } @@ -231,8 +232,8 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Projec } - v := validator.New() // Validate the request. + v := web.NewValidator() err := v.Struct(req) if err != nil { return nil, err @@ -301,8 +302,9 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Projec func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectUpdateRequest, now time.Time) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Update") defer span.Finish() - v := validator.New() + // Validate the request. + v := web.NewValidator() err := v.Struct(req) if err != nil { return err @@ -372,7 +374,8 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Proje defer span.Finish() // Validate the request. - err := validator.New().Struct(req) + v := web.NewValidator() + err := v.Struct(req) if err != nil { return err } @@ -418,12 +421,15 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Proje func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Delete") defer span.Finish() + // Defines the struct to apply validation req := struct { ID string `validate:"required,uuid"` }{} + // Validate the request. - err := validator.New().Struct(req) + v := web.NewValidator() + err := v.Struct(req) if err != nil { return err } diff --git a/example-project/internal/signup/models.go b/example-project/internal/signup/models.go index 27ba402..1315425 100644 --- a/example-project/internal/signup/models.go +++ b/example-project/internal/signup/models.go @@ -1,6 +1,7 @@ package signup import ( + "context" "geeks-accelerator/oss/saas-starter-kit/example-project/internal/account" "geeks-accelerator/oss/saas-starter-kit/example-project/internal/user" ) @@ -31,8 +32,33 @@ type SignupUser struct { PasswordConfirm string `json:"password_confirm" validate:"eqfield=Password" example:"SecretString"` } -// SignupResponse response signup with created account and user. -type SignupResponse struct { +// SignupResult response signup with created account and user. +type SignupResult struct { Account *account.Account `json:"account"` User *user.User `json:"user"` } + +// SignupResponse represents the user and account created for signup that is returned for display. +type SignupResponse struct { + Account *account.AccountResponse `json:"account"` + User *user.UserResponse `json:"user"` +} + +// Response transforms SignupResult to SignupResponse that is used for display. +// Additional filtering by context values or translations could be applied. +func (m *SignupResult) Response(ctx context.Context) *SignupResponse { + if m == nil { + return nil + } + + r := &SignupResponse{} + if m.Account != nil { + r.Account = m.Account.Response(ctx) + } + if m.User != nil { + r.User = m.User.Response(ctx) + } + + return r +} + diff --git a/example-project/internal/signup/signup.go b/example-project/internal/signup/signup.go index 963ce9b..d9ada05 100644 --- a/example-project/internal/signup/signup.go +++ b/example-project/internal/signup/signup.go @@ -2,6 +2,7 @@ package signup import ( "context" + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web" "time" "geeks-accelerator/oss/saas-starter-kit/example-project/internal/account" @@ -15,12 +16,10 @@ import ( // Signup performs the steps needed to create a new account, new user and then associate // both records with a new user_account entry. -func Signup(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req SignupRequest, now time.Time) (*SignupResponse, error) { +func Signup(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req SignupRequest, now time.Time) (*SignupResult, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.signup.Signup") defer span.Finish() - v := validator.New() - // Validate the user email address is unique in the database. uniqEmail, err := user.UniqueEmail(ctx, dbConn, req.User.Email, "") if err != nil { @@ -33,6 +32,7 @@ func Signup(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Signup return nil, err } + f := func(fl validator.FieldLevel) bool { if fl.Field().String() == "invalid" { return false @@ -40,14 +40,16 @@ func Signup(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Signup var uniq bool switch fl.FieldName() { - case "Name": + case "Name", "name": uniq = uniqName - case "Email": + case "Email", "email": uniq = uniqEmail } return uniq } + + v := web.NewValidator() v.RegisterValidation("unique", f) // Validate the request. @@ -56,7 +58,7 @@ func Signup(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Signup return nil, err } - var resp SignupResponse + var resp SignupResult // UserCreateRequest contains information needed to create a new User. userReq := user.UserCreateRequest{ diff --git a/example-project/internal/signup/signup_test.go b/example-project/internal/signup/signup_test.go index 56c878f..a975520 100644 --- a/example-project/internal/signup/signup_test.go +++ b/example-project/internal/signup/signup_test.go @@ -32,12 +32,12 @@ func TestSignupValidation(t *testing.T) { var userTests = []struct { name string req SignupRequest - expected func(req SignupRequest, res *SignupResponse) *SignupResponse + expected func(req SignupRequest, res *SignupResult) *SignupResult error error }{ {"Required Fields", SignupRequest{}, - func(req SignupRequest, res *SignupResponse) *SignupResponse { + func(req SignupRequest, res *SignupResult) *SignupResult { return nil }, errors.New("Key: 'SignupRequest.Account.Name' Error:Field validation for 'Name' failed on the 'required' tag\n" + diff --git a/example-project/internal/user/auth.go b/example-project/internal/user/auth.go index 62e5fe2..a280973 100644 --- a/example-project/internal/user/auth.go +++ b/example-project/internal/user/auth.go @@ -4,6 +4,7 @@ import ( "context" "crypto/rsa" "database/sql" + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web" "github.com/dgrijalva/jwt-go" "strings" "time" @@ -15,7 +16,6 @@ import ( "github.com/pkg/errors" "golang.org/x/crypto/bcrypt" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" - "gopkg.in/go-playground/validator.v9" ) // TokenGenerator is the behavior we need in our Authenticate to generate tokens for @@ -80,7 +80,8 @@ func SwitchAccount(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, } // Validate the request. - err := validator.New().Struct(req) + v := web.NewValidator() + err := v.Struct(req) if err != nil { return Token{}, err } diff --git a/example-project/internal/user/user.go b/example-project/internal/user/user.go index 2ec3907..fae4329 100644 --- a/example-project/internal/user/user.go +++ b/example-project/internal/user/user.go @@ -3,6 +3,7 @@ package user import ( "context" "database/sql" + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web" "time" "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth" @@ -28,9 +29,6 @@ var ( // ErrNotFound abstracts the mgo not found error. ErrNotFound = errors.New("Entity not found") - // ErrInvalidID occurs when an ID is not in a valid form. - ErrInvalidID = errors.New("ID is not in its proper form") - // ErrForbidden occurs when a user tries to do something that is forbidden to them according to our access control policies. ErrForbidden = errors.New("Attempted action is not allowed") @@ -261,8 +259,6 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserCr span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Create") defer span.Finish() - v := validator.New() - // Validation email address is unique in the database. uniq, err := UniqueEmail(ctx, dbConn, req.Email, "") if err != nil { @@ -274,6 +270,8 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserCr } return uniq } + + v := web.NewValidator() v.RegisterValidation("unique", f) // Validate the request. @@ -356,11 +354,11 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string, i query.Where(query.Equal("id", id)) res, err := find(ctx, claims, dbConn, query, []interface{}{}, includedArchived) - if err != nil { - return nil, err - } else if res == nil || len(res) == 0 { + if res == nil || len(res) == 0 { err = errors.WithMessagef(ErrNotFound, "user %s not found", id) return nil, err + } else if err != nil { + return nil, err } u := res[0] @@ -372,7 +370,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserUp span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Update") defer span.Finish() - v := validator.New() + v := web.NewValidator() // Validation email address is unique in the database. if req.Email != nil { @@ -458,7 +456,8 @@ func UpdatePassword(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, re defer span.Finish() // Validate the request. - err := validator.New().Struct(req) + v := web.NewValidator() + err := v.Struct(req) if err != nil { return err } @@ -526,7 +525,8 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserA defer span.Finish() // Validate the request. - err := validator.New().Struct(req) + v := web.NewValidator() + err := v.Struct(req) if err != nil { return err } @@ -604,7 +604,8 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID str } // Validate the request. - err := validator.New().Struct(req) + v := web.NewValidator() + err := v.Struct(req) if err != nil { return err } diff --git a/example-project/internal/user_account/user_account.go b/example-project/internal/user_account/user_account.go index b6c4f48..be91a02 100644 --- a/example-project/internal/user_account/user_account.go +++ b/example-project/internal/user_account/user_account.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "geeks-accelerator/oss/saas-starter-kit/example-project/internal/account" + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web" "time" "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth" @@ -12,16 +13,12 @@ import ( "github.com/pborman/uuid" "github.com/pkg/errors" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" - "gopkg.in/go-playground/validator.v9" ) var ( // ErrNotFound abstracts the mgo not found error. ErrNotFound = errors.New("Entity not found") - // ErrInvalidID occurs when an ID is not in a valid form. - ErrInvalidID = errors.New("ID is not in its proper form") - // ErrForbidden occurs when a user tries to do something that is forbidden to them according to our access control policies. ErrForbidden = errors.New("Attempted action is not allowed") ) @@ -64,8 +61,6 @@ func mapAccountError(err error) error { switch errors.Cause(err) { case account.ErrNotFound: err = ErrNotFound - case account.ErrInvalidID: - err = ErrInvalidID case account.ErrForbidden: err = ErrForbidden } @@ -206,7 +201,8 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc defer span.Finish() // Validate the request. - err := validator.New().Struct(req) + v := web.NewValidator() + err := v.Struct(req) if err != nil { return nil, err } @@ -303,10 +299,10 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string, i query.Where(query.Equal("id", id)) res, err := find(ctx, claims, dbConn, query, []interface{}{}, includedArchived) - if err != nil { + if res == nil || len(res) == 0 { + err = errors.WithMessagef(ErrNotFound, "account %s not found", id) return nil, err - } else if res == nil || len(res) == 0 { - err = errors.WithMessagef(ErrNotFound, "user account %s not found", id) + } else if err != nil { return nil, err } u := res[0] @@ -320,7 +316,8 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc defer span.Finish() // Validate the request. - err := validator.New().Struct(req) + v := web.NewValidator() + err := v.Struct(req) if err != nil { return err } @@ -392,7 +389,8 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserA defer span.Finish() // Validate the request. - err := validator.New().Struct(req) + v := web.NewValidator() + err := v.Struct(req) if err != nil { return err } @@ -443,7 +441,8 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc defer span.Finish() // Validate the request. - err := validator.New().Struct(req) + v := web.NewValidator() + err := v.Struct(req) if err != nil { return err }