From 4c25d50c761b5dbc08bc04d89301fcdf99f98f60 Mon Sep 17 00:00:00 2001 From: Lee Brown Date: Mon, 5 Aug 2019 17:12:28 -0800 Subject: [PATCH] fix where, auth use request arg --- cmd/web-api/handlers/project.go | 2 +- cmd/web-api/handlers/signup.go | 4 + cmd/web-api/handlers/user.go | 39 ++---- cmd/web-api/handlers/user_account.go | 2 +- cmd/web-api/tests/account_test.go | 20 ++- cmd/web-api/tests/project_test.go | 2 +- cmd/web-api/tests/signup_test.go | 15 ++- cmd/web-api/tests/tests_test.go | 10 +- cmd/web-api/tests/user_account_test.go | 9 +- cmd/web-api/tests/user_test.go | 19 ++- cmd/web-app/handlers/projects.go | 3 +- cmd/web-app/handlers/user.go | 71 ++++++---- cmd/web-app/handlers/users.go | 74 +++++++++-- .../content/user-reset-confirm.gohtml | 1 - .../templates/content/users-create.gohtml | 75 +++++------ .../templates/content/users-update.gohtml | 121 +++++++++--------- .../templates/content/users-view.gohtml | 10 +- internal/account/account.go | 4 +- .../account_preference/account_preference.go | 4 +- .../account_preference_test.go | 11 +- internal/account/account_preference/models.go | 2 +- internal/account/account_test.go | 13 +- internal/account/models.go | 2 +- internal/platform/web/request.go | 5 +- internal/project/models.go | 2 +- internal/project/project.go | 4 +- internal/project/project_test.go | 4 +- internal/user/models.go | 73 ++++++++++- internal/user/user.go | 58 ++------- internal/user/user_test.go | 23 ++-- internal/user_account/invite/invite.go | 9 +- internal/user_account/invite/invite_test.go | 31 ++++- internal/user_account/invite/models.go | 3 - internal/user_account/models.go | 56 ++++---- internal/user_account/user_account.go | 4 +- internal/user_account/user_account_test.go | 25 ++-- internal/user_auth/auth.go | 15 ++- internal/user_auth/auth_test.go | 49 +++++-- internal/user_auth/models.go | 5 +- 39 files changed, 532 insertions(+), 347 deletions(-) diff --git a/cmd/web-api/handlers/project.go b/cmd/web-api/handlers/project.go index 7b49b35..4a0e09f 100644 --- a/cmd/web-api/handlers/project.go +++ b/cmd/web-api/handlers/project.go @@ -55,7 +55,7 @@ func (p *Project) Find(ctx context.Context, w http.ResponseWriter, r *http.Reque if err != nil { return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusBadRequest)) } - req.Where = &where + req.Where = where req.Args = args } diff --git a/cmd/web-api/handlers/signup.go b/cmd/web-api/handlers/signup.go index 49c3571..208ec43 100644 --- a/cmd/web-api/handlers/signup.go +++ b/cmd/web-api/handlers/signup.go @@ -42,6 +42,10 @@ func (c *Signup) Signup(ctx context.Context, w http.ResponseWriter, r *http.Requ // Claims are optional as authentication is not required ATM for this method. claims, _ := auth.ClaimsFromContext(ctx) + // Hack to allow custom validation to be handled by business logic package. + ctx = context.WithValue(ctx, signup.KeyTagUniqueEmail, true) + ctx = context.WithValue(ctx, signup.KeyTagUniqueName, true) + var req signup.SignupRequest if err := web.Decode(ctx, r, &req); err != nil { if _, ok := errors.Cause(err).(*weberror.Error); !ok { diff --git a/cmd/web-api/handlers/user.go b/cmd/web-api/handlers/user.go index 8457b1a..a695b9c 100644 --- a/cmd/web-api/handlers/user.go +++ b/cmd/web-api/handlers/user.go @@ -60,7 +60,7 @@ func (u *User) Find(ctx context.Context, w http.ResponseWriter, r *http.Request, if err != nil { return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusBadRequest)) } - req.Where = &where + req.Where = where req.Args = args } @@ -442,7 +442,9 @@ func (u *User) SwitchAccount(ctx context.Context, w http.ResponseWriter, r *http return err } - tkn, err := user_auth.SwitchAccount(ctx, u.MasterDB, u.TokenGenerator, claims, params["account_id"], sessionTtl, v.Now) + tkn, err := user_auth.SwitchAccount(ctx, u.MasterDB, u.TokenGenerator, claims, user_auth.SwitchAccountRequest{ + AccountID: params["account_id"], + }, sessionTtl, v.Now) if err != nil { cause := errors.Cause(err) switch cause { @@ -486,10 +488,16 @@ func (u *User) Token(ctx context.Context, w http.ResponseWriter, r *http.Request return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusUnauthorized)) } + accountID := r.URL.Query().Get("account_id") + // Optional to include scope. scope := r.URL.Query().Get("scope") - tkn, err := user_auth.Authenticate(ctx, u.MasterDB, u.TokenGenerator, email, pass, sessionTtl, v.Now, scope) + tkn, err := user_auth.Authenticate(ctx, u.MasterDB, u.TokenGenerator, user_auth.AuthenticateRequest{ + Email: email, + Password: pass, + AccountID: accountID, + }, sessionTtl, v.Now, scope) if err != nil { cause := errors.Cause(err) switch cause { @@ -505,30 +513,5 @@ func (u *User) Token(ctx context.Context, w http.ResponseWriter, r *http.Request } } - accountID := r.URL.Query().Get("account_id") - if accountID != "" && accountID != tkn.AccountID { - - claims, err := u.TokenGenerator.ParseClaims(tkn.AccessToken) - if err != nil { - return err - } - - tkn, err = user_auth.SwitchAccount(ctx, u.MasterDB, u.TokenGenerator, claims, accountID, sessionTtl, v.Now) - if err != nil { - cause := errors.Cause(err) - switch cause { - case user_auth.ErrAuthenticationFailure: - return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusUnauthorized)) - default: - _, ok := cause.(validator.ValidationErrors) - if ok { - return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusBadRequest)) - } - - return errors.Wrap(err, "switch account") - } - } - } - return web.RespondJson(ctx, w, tkn, http.StatusOK) } diff --git a/cmd/web-api/handlers/user_account.go b/cmd/web-api/handlers/user_account.go index 849a52a..091fc2c 100644 --- a/cmd/web-api/handlers/user_account.go +++ b/cmd/web-api/handlers/user_account.go @@ -55,7 +55,7 @@ func (u *UserAccount) Find(ctx context.Context, w http.ResponseWriter, r *http.R if err != nil { return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusBadRequest)) } - req.Where = &where + req.Where = where req.Args = args } diff --git a/cmd/web-api/tests/account_test.go b/cmd/web-api/tests/account_test.go index ffde8ed..918aa6c 100644 --- a/cmd/web-api/tests/account_test.go +++ b/cmd/web-api/tests/account_test.go @@ -102,9 +102,13 @@ func TestAccountCRUDAdmin(t *testing.T) { "address1": tr.Account.Address1, "city": tr.Account.City, "status": map[string]interface{}{ - "value": "active", - "title": "Active", - "options": []map[string]interface{}{{"selected": false, "title": "[Active Pending Disabled]", "value": "[active pending disabled]"}}, + "value": "active", + "title": "Active", + "options": []map[string]interface{}{ + {"selected": true, "title": "Active", "value": "active"}, + {"selected": false, "title": "Pending", "value": "pending"}, + {"selected": false, "title": "Disabled", "value": "disabled"}, + }, }, "signup_user_id": &tr.Account.SignupUserID.String, } @@ -322,9 +326,13 @@ func TestAccountCRUDUser(t *testing.T) { "address1": tr.Account.Address1, "city": tr.Account.City, "status": map[string]interface{}{ - "value": "active", - "title": "Active", - "options": []map[string]interface{}{{"selected": false, "title": "[Active Pending Disabled]", "value": "[active pending disabled]"}}, + "value": "active", + "title": "Active", + "options": []map[string]interface{}{ + {"selected": true, "title": "Active", "value": "active"}, + {"selected": false, "title": "Pending", "value": "pending"}, + {"selected": false, "title": "Disabled", "value": "disabled"}, + }, }, "signup_user_id": &tr.Account.SignupUserID.String, } diff --git a/cmd/web-api/tests/project_test.go b/cmd/web-api/tests/project_test.go index a42c749..fbefebd 100644 --- a/cmd/web-api/tests/project_test.go +++ b/cmd/web-api/tests/project_test.go @@ -79,7 +79,7 @@ func TestProjectCRUDAdmin(t *testing.T) { "updated_at": web.NewTimeResponse(ctx, actual.UpdatedAt.Value), "id": actual.ID, "account_id": req.AccountID, - "status": web.NewEnumResponse(ctx, "active", project.ProjectStatus_Values), + "status": web.NewEnumResponse(ctx, "active", project.ProjectStatus_ValuesInterface()...), "created_at": web.NewTimeResponse(ctx, actual.CreatedAt.Value), "name": req.Name, } diff --git a/cmd/web-api/tests/signup_test.go b/cmd/web-api/tests/signup_test.go index 2456a03..686f2a7 100644 --- a/cmd/web-api/tests/signup_test.go +++ b/cmd/web-api/tests/signup_test.go @@ -56,7 +56,10 @@ func newMockSignup() mockSignup { } expires := time.Now().UTC().Sub(s.User.CreatedAt) + time.Hour - tkn, err := user_auth.Authenticate(tests.Context(), test.MasterDB, authenticator, req.User.Email, req.User.Password, expires, now) + tkn, err := user_auth.Authenticate(tests.Context(), test.MasterDB, authenticator, user_auth.AuthenticateRequest{ + Email: req.User.Email, + Password: req.User.Password, + }, expires, now) if err != nil { panic(err) } @@ -139,9 +142,13 @@ func TestSignup(t *testing.T) { "address1": req.Account.Address1, "city": req.Account.City, "status": map[string]interface{}{ - "value": "active", - "title": "Active", - "options": []map[string]interface{}{{"selected": false, "title": "[Active Pending Disabled]", "value": "[active pending disabled]"}}, + "value": "active", + "title": "Active", + "options": []map[string]interface{}{ + {"selected": true, "title": "Active", "value": "active"}, + {"selected": false, "title": "Pending", "value": "pending"}, + {"selected": false, "title": "Disabled", "value": "disabled"}, + }, }, "signup_user_id": &actual.Account.SignupUserID, }, diff --git a/cmd/web-api/tests/tests_test.go b/cmd/web-api/tests/tests_test.go index 0edec3c..83608f5 100644 --- a/cmd/web-api/tests/tests_test.go +++ b/cmd/web-api/tests/tests_test.go @@ -95,7 +95,10 @@ func testMain(m *testing.M) int { } expires := time.Now().UTC().Sub(signup1.User.CreatedAt) + time.Hour - adminTkn, err := user_auth.Authenticate(tests.Context(), test.MasterDB, authenticator, signupReq1.User.Email, signupReq1.User.Password, expires, now) + adminTkn, err := user_auth.Authenticate(tests.Context(), test.MasterDB, authenticator, user_auth.AuthenticateRequest{ + Email: signupReq1.User.Email, + Password: signupReq1.User.Password, + }, expires, now) if err != nil { panic(err) } @@ -146,7 +149,10 @@ func testMain(m *testing.M) int { panic(err) } - userTkn, err := user_auth.Authenticate(tests.Context(), test.MasterDB, authenticator, usr.Email, userReq.Password, expires, now) + userTkn, err := user_auth.Authenticate(tests.Context(), test.MasterDB, authenticator, user_auth.AuthenticateRequest{ + Email: usr.Email, + Password: userReq.Password, + }, expires, now) if err != nil { panic(err) } diff --git a/cmd/web-api/tests/user_account_test.go b/cmd/web-api/tests/user_account_test.go index 00f4ebb..e1c650c 100644 --- a/cmd/web-api/tests/user_account_test.go +++ b/cmd/web-api/tests/user_account_test.go @@ -89,13 +89,18 @@ func TestUserAccountCRUDAdmin(t *testing.T) { } created = actual + var roles []interface{} + for _, r := range req.Roles { + roles = append(roles, r) + } + expectedMap := map[string]interface{}{ "updated_at": web.NewTimeResponse(ctx, actual.UpdatedAt.Value), //"id": actual.ID, "account_id": req.AccountID, "user_id": req.UserID, - "status": web.NewEnumResponse(ctx, "active", user_account.UserAccountStatus_Values), - "roles": req.Roles, + "status": web.NewEnumResponse(ctx, "active", user_account.UserAccountStatus_ValuesInterface()...), + "roles": web.NewEnumMultiResponse(ctx, roles, user_account.UserAccountRole_ValuesInterface()...), "created_at": web.NewTimeResponse(ctx, actual.CreatedAt.Value), } diff --git a/cmd/web-api/tests/user_test.go b/cmd/web-api/tests/user_test.go index 19c3902..d2778e6 100644 --- a/cmd/web-api/tests/user_test.go +++ b/cmd/web-api/tests/user_test.go @@ -1419,7 +1419,7 @@ func TestUserToken(t *testing.T) { // Test user token with invalid email. { - expectedStatus := http.StatusUnauthorized + expectedStatus := http.StatusBadRequest rt := requestTest{ fmt.Sprintf("Token %d using invalid email", expectedStatus), @@ -1434,7 +1434,9 @@ func TestUserToken(t *testing.T) { t.Logf("\tTest: %s - %s %s", rt.name, rt.method, rt.url) r := httptest.NewRequest(rt.method, rt.url, nil) - r.SetBasicAuth("invalid email.com", "some random password") + + invalidEmail := "invalid email.com" + r.SetBasicAuth(invalidEmail, "some random password") w := httptest.NewRecorder() r.Header.Set("Content-Type", web.MIMEApplicationJSONCharsetUTF8) @@ -1456,8 +1458,17 @@ func TestUserToken(t *testing.T) { expected := weberror.ErrorResponse{ StatusCode: expectedStatus, - Error: http.StatusText(expectedStatus), - Details: user_auth.ErrAuthenticationFailure.Error(), + Error: "Field validation error", + Fields: []weberror.FieldError{ + { + Field: "email", + Value: invalidEmail, + Tag: "email", + Error: "email must be a valid email address", + Display: "email must be a valid email address", + }, + }, + Details: actual.Details, StackTrace: actual.StackTrace, } diff --git a/cmd/web-app/handlers/projects.go b/cmd/web-app/handlers/projects.go index 630465d..1d8ad82 100644 --- a/cmd/web-app/handlers/projects.go +++ b/cmd/web-app/handlers/projects.go @@ -110,9 +110,8 @@ func (h *Projects) Index(ctx context.Context, w http.ResponseWriter, r *http.Req } loadFunc := func(ctx context.Context, sorting string, fields []datatable.DisplayField) (resp [][]datatable.ColumnValue, err error) { - whereFilter := "account_id = ?" res, err := project.Find(ctx, claims, h.MasterDB, project.ProjectFindRequest{ - Where: &whereFilter, + Where: "account_id = ?", Args: []interface{}{claims.Audience}, Order: strings.Split(sorting, ","), }) diff --git a/cmd/web-app/handlers/user.go b/cmd/web-app/handlers/user.go index b8ff5e4..618586c 100644 --- a/cmd/web-app/handlers/user.go +++ b/cmd/web-app/handlers/user.go @@ -209,6 +209,8 @@ func (h *User) ResetPassword(ctx context.Context, w http.ResponseWriter, r *http // ResetConfirm handles changing a users password after they have clicked on the link emailed. func (h *User) ResetConfirm(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { + resetHash := params["hash"] + ctxValues, err := webcontext.ContextValues(ctx) if err != nil { return err @@ -217,31 +219,36 @@ func (h *User) ResetConfirm(ctx context.Context, w http.ResponseWriter, r *http. // req := new(user.UserResetConfirmRequest) data := make(map[string]interface{}) - f := func() error { + f := func() (bool, error) { if r.Method == http.MethodPost { err := r.ParseForm() if err != nil { - return err + return false, err } decoder := schema.NewDecoder() if err := decoder.Decode(req, r.PostForm); err != nil { - return err + return false, err } // Append the query param value to the request. - req.ResetHash = params["hash"] + req.ResetHash = resetHash u, err := user.ResetConfirm(ctx, h.MasterDB, *req, h.SecretKey, ctxValues.Now) if err != nil { switch errors.Cause(err) { + case user.ErrResetExpired: + webcontext.SessionFlashError(ctx, + "Reset Expired", + "The reset has expired.") + return false, nil default: if verr, ok := weberror.NewValidationError(ctx, err); ok { data["validationErrors"] = verr.(*weberror.Error) - return nil + return false, nil } else { - return err + return false, err } } } @@ -249,34 +256,51 @@ func (h *User) ResetConfirm(ctx context.Context, w http.ResponseWriter, r *http. // Authenticated the user. Probably should use the default session TTL from UserLogin. token, err := user_auth.Authenticate(ctx, h.MasterDB, h.Authenticator, u.Email, req.Password, time.Hour, ctxValues.Now) if err != nil { - switch errors.Cause(err) { - case account.ErrForbidden: - return web.RespondError(ctx, w, weberror.NewError(ctx, err, http.StatusForbidden)) - default: - if verr, ok := weberror.NewValidationError(ctx, err); ok { - data["validationErrors"] = verr.(*weberror.Error) - return nil - } else { - return err - } + if verr, ok := weberror.NewValidationError(ctx, err); ok { + data["validationErrors"] = verr.(*weberror.Error) + return false, nil + } else { + return false, err } } // Add the token to the users session. err = handleSessionToken(ctx, h.MasterDB, w, r, token) if err != nil { - return err + return false, err } // Redirect the user to the dashboard. http.Redirect(w, r, "/", http.StatusFound) + return true, nil } - return nil + _, err = user.ParseResetHash(ctx, h.SecretKey, resetHash, ctxValues.Now) + if err != nil { + switch errors.Cause(err) { + case user.ErrResetExpired: + webcontext.SessionFlashError(ctx, + "Reset Expired", + "The reset has expired.") + return false, nil + default: + if verr, ok := weberror.NewValidationError(ctx, err); ok { + data["validationErrors"] = verr.(*weberror.Error) + return false, nil + } else { + return false, err + } + } + } + + return false, nil } - if err := f(); err != nil { + end, err := f() + if err != nil { return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8) + } else if end { + return nil } data["form"] = req @@ -572,9 +596,8 @@ func (h *User) VirtualLogin(ctx context.Context, w http.ResponseWriter, r *http. return nil } - usrAccFilter := "account_id = ?" usrAccs, err := user_account.Find(ctx, claims, h.MasterDB, user_account.UserAccountFindRequest{ - Where: &usrAccFilter, + Where: "account_id = ?", Args: []interface{}{claims.Audience}, }) if err != nil { @@ -597,10 +620,10 @@ func (h *User) VirtualLogin(ctx context.Context, w http.ResponseWriter, r *http. userPhs = append(userPhs, "?") } - usrFilter := fmt.Sprintf("id IN (%s)", strings.Join(userPhs, ", ")) users, err := user.Find(ctx, claims, h.MasterDB, user.UserFindRequest{ - Where: &usrFilter, - Args: userIDs, + Where: fmt.Sprintf("id IN (%s)", + strings.Join(userPhs, ", ")), + Args: userIDs, }) if err != nil { return err diff --git a/cmd/web-app/handlers/users.go b/cmd/web-app/handlers/users.go index 9be188b..418a49f 100644 --- a/cmd/web-app/handlers/users.go +++ b/cmd/web-app/handlers/users.go @@ -52,12 +52,18 @@ func urlUsersUpdate(userID string) string { return fmt.Sprintf("/users/%s/update", userID) } -// UserLoginRequest extends the AuthenicateRequest with the RememberMe flag. +// UserCreateRequest extends the UserCreateRequest with a list of roles. type UserCreateRequest struct { user.UserCreateRequest Roles user_account.UserAccountRoles `json:"roles" validate:"required,dive,oneof=admin user" enums:"admin,user" swaggertype:"array,string" example:"admin"` } +// UserUpdateRequest extends the UserUpdateRequest with a list of roles. +type UserUpdateRequest struct { + user.UserUpdateRequest + Roles user_account.UserAccountRoles `json:"roles" validate:"required,dive,oneof=admin user" enums:"admin,user" swaggertype:"array,string" example:"admin"` +} + // Index handles listing all the users for the current account. func (h *Users) Index(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { @@ -198,8 +204,6 @@ func (h *Users) Create(ctx context.Context, w http.ResponseWriter, r *http.Reque } decoder := schema.NewDecoder() - decoder.IgnoreUnknownKeys(true) - if err := decoder.Decode(req, r.PostForm); err != nil { return false, err } @@ -279,11 +283,11 @@ func (h *Users) Create(ctx context.Context, w http.ResponseWriter, r *http.Reque return err } - var roleValues []interface{} - for _, v := range user_account.UserAccountRole_Values { - roleValues = append(roleValues, string(v)) + var selectedRoles []interface{} + for _, r := range req.Roles { + selectedRoles = append(selectedRoles, r.String()) } - data["roles"] = web.NewEnumResponse(ctx, nil, roleValues...) + data["roles"] = web.NewEnumMultiResponse(ctx, selectedRoles, user_account.UserAccountRole_ValuesInterface()...) data["form"] = req @@ -389,7 +393,7 @@ func (h *Users) Update(ctx context.Context, w http.ResponseWriter, r *http.Reque } // - req := new(user.UserUpdateRequest) + req := new(UserUpdateRequest) data := make(map[string]interface{}) f := func() (bool, error) { if r.Method == http.MethodPost { @@ -400,13 +404,27 @@ func (h *Users) Update(ctx context.Context, w http.ResponseWriter, r *http.Reque decoder := schema.NewDecoder() decoder.IgnoreUnknownKeys(true) - if err := decoder.Decode(req, r.PostForm); err != nil { return false, err } req.ID = userID - err = user.Update(ctx, claims, h.MasterDB, *req, ctxValues.Now) + // Bypass the uniq check on email here for the moment, it will be caught before the user_account is + // created by user.Create. + ctx = context.WithValue(ctx, webcontext.KeyTagUnique, true) + + // Validate the request. + err = webcontext.Validator().StructCtx(ctx, req) + if err != nil { + if verr, ok := weberror.NewValidationError(ctx, err); ok { + data["validationErrors"] = verr.(*weberror.Error) + return false, nil + } else { + return false, err + } + } + + err = user.Update(ctx, claims, h.MasterDB, req.UserUpdateRequest, ctxValues.Now) if err != nil { switch errors.Cause(err) { default: @@ -419,6 +437,25 @@ func (h *Users) Update(ctx context.Context, w http.ResponseWriter, r *http.Reque } } + if req.Roles != nil { + err = user_account.Update(ctx, claims, h.MasterDB, user_account.UserAccountUpdateRequest{ + UserID: userID, + AccountID: claims.Audience, + Roles: &req.Roles, + }, ctxValues.Now) + if err != nil { + switch errors.Cause(err) { + default: + if verr, ok := weberror.NewValidationError(ctx, err); ok { + data["validationErrors"] = verr.(*weberror.Error) + return false, nil + } else { + return false, err + } + } + } + } + if r.PostForm.Get("Password") != "" { pwdReq := new(user.UserUpdatePasswordRequest) @@ -469,11 +506,20 @@ func (h *Users) Update(ctx context.Context, w http.ResponseWriter, r *http.Reque return err } + usrAcc, err := user_account.Read(ctx, claims, h.MasterDB, user_account.UserAccountReadRequest{ + UserID: userID, + AccountID: claims.Audience, + }) + if err != nil { + return err + } + if req.ID == "" { req.FirstName = &usr.FirstName req.LastName = &usr.LastName req.Email = &usr.Email req.Timezone = usr.Timezone + req.Roles = usrAcc.Roles } data["user"] = usr.Response(ctx) @@ -483,9 +529,15 @@ func (h *Users) Update(ctx context.Context, w http.ResponseWriter, r *http.Reque return err } + var selectedRoles []interface{} + for _, r := range req.Roles { + selectedRoles = append(selectedRoles, r.String()) + } + data["roles"] = web.NewEnumMultiResponse(ctx, selectedRoles, user_account.UserAccountRole_ValuesInterface()...) + data["form"] = req - if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(user.UserUpdateRequest{})); ok { + if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(UserUpdateRequest{})); ok { data["userValidationDefaults"] = verr.(*weberror.Error) } diff --git a/cmd/web-app/templates/content/user-reset-confirm.gohtml b/cmd/web-app/templates/content/user-reset-confirm.gohtml index 3d869df..9ae5560 100644 --- a/cmd/web-app/templates/content/user-reset-confirm.gohtml +++ b/cmd/web-app/templates/content/user-reset-confirm.gohtml @@ -28,7 +28,6 @@ {{ template "validation-error" . }}
-
diff --git a/cmd/web-app/templates/content/users-create.gohtml b/cmd/web-app/templates/content/users-create.gohtml index 298866b..b132caa 100644 --- a/cmd/web-app/templates/content/users-create.gohtml +++ b/cmd/web-app/templates/content/users-create.gohtml @@ -9,90 +9,83 @@
-
-
- - {{template "invalid-feedback" dict "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors "fieldName" "UserCreateRequest.FirstName" }} + + {{template "invalid-feedback" dict "fieldName" "UserCreateRequest.FirstName" "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors }}
- - {{template "invalid-feedback" dict "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors "fieldName" "UserCreateRequest.LastName" }} + + {{template "invalid-feedback" dict "fieldName" "UserCreateRequest.LastName" "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors }}
- - {{template "invalid-feedback" dict "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors "fieldName" "UserCreateRequest.Email" }} + + {{template "invalid-feedback" dict "fieldName" "UserCreateRequest.Email" "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors }}
- {{ range $idx, $t := .timezones }} {{ end }} - {{template "invalid-feedback" dict "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors "fieldName" "UserCreateRequest.Timezone" }} + {{template "invalid-feedback" dict "fieldName" "UserCreateRequest.Timezone" "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors }}
- - Generate random password - {{template "invalid-feedback" dict "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors "fieldName" "UserCreateRequest.Password" }} + + + + Generate random password + + {{template "invalid-feedback" dict "fieldName" "UserCreateRequest.Password" "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors }}
- - {{template "invalid-feedback" dict "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors "fieldName" "UserCreateRequest.PasswordConfirm" }} + + {{template "invalid-feedback" dict "fieldName" "UserCreateRequest.PasswordConfirm" "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors }}
-
- Select at least one role. - - {{ range $r := .roles.Options }} - {{ $selectRole := false }} - {{ range $fr := $.form.Roles }} - {{ if eq $r.Value $fr }}{{ $selectRole = true }}{{ end }} - {{ end }} - + {{ range $r := .roles }}
- -
{{ end }} - - - - - {{template "invalid-feedback" dict "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors "fieldName" "Roles" }} + {{template "invalid-feedback" dict "fieldName" "Roles" "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors }}
-
-
- + Cancel
diff --git a/cmd/web-app/templates/content/users-update.gohtml b/cmd/web-app/templates/content/users-update.gohtml index b58ab35..e9d8faf 100644 --- a/cmd/web-app/templates/content/users-update.gohtml +++ b/cmd/web-app/templates/content/users-update.gohtml @@ -8,93 +8,99 @@

Update User

- -
-
-

User Details

-
-
-
- - - {{template "invalid-feedback" dict "validationDefaults" $.userValidationDefaults "validationErrors" $.validationErrors "fieldName" "FirstName" }} +
+
+ + + {{template "invalid-feedback" dict "fieldName" "UserUpdateRequest.FirstName" "validationDefaults" $.userValidationDefaults "validationErrors" $.validationErrors }} +
+
+ + + {{template "invalid-feedback" dict "fieldName" "UserUpdateRequest.LastName" "validationDefaults" $.userValidationDefaults "validationErrors" $.validationErrors }} +
+
+ + + {{template "invalid-feedback" dict "fieldName" "UserUpdateRequest.Email" "validationDefaults" $.userValidationDefaults "validationErrors" $.validationErrors }} +
+
+ + - Select at least one role. + {{ range $r := .roles }} +
+ + +
+ {{ end }} + {{template "invalid-feedback" dict "fieldName" "Roles" "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors }} +
+
+ + + {{template "invalid-feedback" dict "fieldName" "UserUpdateRequest.Timezone" "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors }} +
+
-
- - - {{template "invalid-feedback" dict "validationDefaults" $.userValidationDefaults "validationErrors" $.validationErrors "fieldName" "LastName" }} -
-
- - - {{template "invalid-feedback" dict "validationDefaults" $.userValidationDefaults "validationErrors" $.validationErrors "fieldName" "Email" }} -
-
- - - {{template "invalid-feedback" dict "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors "fieldName" "Timezone" }} -
- -
-
-
- + Cancel
- -
-
- -
-
-

Change Password

Optional. You can change the users' password by specifying a new one below. Otherwise leave the fields empty.

-
-
-
-
- - - Generate random password - {{template "invalid-feedback" dict "validationDefaults" $.passwordValidationDefaults "validationErrors" $.validationErrors "fieldName" "Password" }} +
+
+ + + Generate random password + {{template "invalid-feedback" dict "validationDefaults" $.passwordValidationDefaults "validationErrors" $.validationErrors "fieldName" "Password" }} +
+
+ + + {{template "invalid-feedback" dict "validationDefaults" $.passwordValidationDefaults "validationErrors" $.validationErrors "fieldName" "PasswordConfirm" }} +
+
-
- - - {{template "invalid-feedback" dict "validationDefaults" $.passwordValidationDefaults "validationErrors" $.validationErrors "fieldName" "PasswordConfirm" }} -
-
-
@@ -102,7 +108,6 @@
- {{end}} {{define "js"}} diff --git a/cmd/web-app/templates/content/users-view.gohtml b/cmd/web-app/templates/content/users-view.gohtml index c01854a..b04a392 100644 --- a/cmd/web-app/templates/content/users-view.gohtml +++ b/cmd/web-app/templates/content/users-view.gohtml @@ -63,13 +63,13 @@ Role
{{ if .userAccount }} - {{ range $r := .userAccount.Roles }} - {{ if eq $r "admin" }} - {{ $r }} + {{ range $r := .userAccount.Roles }}{{ if $r.Selected }} + {{ if eq $r.Value "admin" }} + {{ $r.Title }} {{else}} - {{ $r }} + {{ $r.Title }} {{end}} - {{ end }} + {{ end }}{{ end }} {{ end }}

diff --git a/internal/account/account.go b/internal/account/account.go index 9e21ac0..5a8dc7a 100644 --- a/internal/account/account.go +++ b/internal/account/account.go @@ -153,8 +153,8 @@ func selectQuery() *sqlbuilder.SelectBuilder { func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountFindRequest) (Accounts, error) { query := selectQuery() - if req.Where != nil { - query.Where(query.And(*req.Where)) + if req.Where != "" { + query.Where(query.And(req.Where)) } if len(req.Order) > 0 { query.OrderBy(req.Order...) diff --git a/internal/account/account_preference/account_preference.go b/internal/account/account_preference/account_preference.go index 7dd3c46..d995f75 100644 --- a/internal/account/account_preference/account_preference.go +++ b/internal/account/account_preference/account_preference.go @@ -65,8 +65,8 @@ func applyClaimsSelect(ctx context.Context, claims auth.Claims, query *sqlbuilde // TODO: Need to figure out why can't parse the args when appending the where to the query. func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPreferenceFindRequest) ([]*AccountPreference, error) { query := sqlbuilder.NewSelectBuilder() - if req.Where != nil { - query.Where(query.And(*req.Where)) + if req.Where != "" { + query.Where(query.And(req.Where)) } if len(req.Order) > 0 { query.OrderBy(req.Order...) diff --git a/internal/account/account_preference/account_preference_test.go b/internal/account/account_preference/account_preference_test.go index 88df781..e8886f9 100644 --- a/internal/account/account_preference/account_preference_test.go +++ b/internal/account/account_preference/account_preference_test.go @@ -397,7 +397,7 @@ func TestFind(t *testing.T) { // Test sort accounts. prefTests = append(prefTests, accountTest{"Find all order by created_at asc", AccountPreferenceFindRequest{ - Where: &createdFilter, + Where: createdFilter, Args: []interface{}{startTime, endTime}, Order: []string{"created_at"}, }, @@ -412,7 +412,7 @@ func TestFind(t *testing.T) { } prefTests = append(prefTests, accountTest{"Find all order by created_at desc", AccountPreferenceFindRequest{ - Where: &createdFilter, + Where: createdFilter, Args: []interface{}{startTime, endTime}, Order: []string{"created_at desc"}, }, @@ -424,7 +424,7 @@ func TestFind(t *testing.T) { var limit uint = 2 prefTests = append(prefTests, accountTest{"Find limit", AccountPreferenceFindRequest{ - Where: &createdFilter, + Where: createdFilter, Args: []interface{}{startTime, endTime}, Order: []string{"created_at"}, Limit: &limit, @@ -437,7 +437,7 @@ func TestFind(t *testing.T) { var offset uint = 1 prefTests = append(prefTests, accountTest{"Find limit, offset", AccountPreferenceFindRequest{ - Where: &createdFilter, + Where: createdFilter, Args: []interface{}{startTime, endTime}, Order: []string{"created_at"}, Limit: &limit, @@ -462,10 +462,9 @@ func TestFind(t *testing.T) { expected = append(expected, &u) } - where := createdFilter + " AND (" + strings.Join(whereParts, " OR ") + ")" prefTests = append(prefTests, accountTest{"Find where", AccountPreferenceFindRequest{ - Where: &where, + Where: createdFilter + " AND (" + strings.Join(whereParts, " OR ") + ")", Args: whereArgs, Order: []string{"created_at"}, }, diff --git a/internal/account/account_preference/models.go b/internal/account/account_preference/models.go index 8df3c80..869da27 100644 --- a/internal/account/account_preference/models.go +++ b/internal/account/account_preference/models.go @@ -84,7 +84,7 @@ type AccountPreferenceDeleteRequest struct { // AccountPreferenceFindRequest defines the possible options to search for accounts. By default // archived accounts will be excluded from response. type AccountPreferenceFindRequest struct { - Where *string `json:"where" example:"name = ?"` + Where string `json:"where" example:"name = ?"` Args []interface{} `json:"args" swaggertype:"array,string" example:"Company Name,active"` Order []string `json:"order" example:"created_at desc"` Limit *uint `json:"limit" example:"10"` diff --git a/internal/account/account_test.go b/internal/account/account_test.go index 457287d..628fc54 100644 --- a/internal/account/account_test.go +++ b/internal/account/account_test.go @@ -842,7 +842,7 @@ func TestFind(t *testing.T) { type accountTest struct { name string req AccountFindRequest - expected []*Account + expected Accounts error error } @@ -853,7 +853,7 @@ func TestFind(t *testing.T) { // Test sort accounts. accountTests = append(accountTests, accountTest{"Find all order by created_at asc", AccountFindRequest{ - Where: &createdFilter, + Where: createdFilter, Args: []interface{}{startTime, endTime}, Order: []string{"created_at"}, }, @@ -868,7 +868,7 @@ func TestFind(t *testing.T) { } accountTests = append(accountTests, accountTest{"Find all order by created_at desc", AccountFindRequest{ - Where: &createdFilter, + Where: createdFilter, Args: []interface{}{startTime, endTime}, Order: []string{"created_at desc"}, }, @@ -880,7 +880,7 @@ func TestFind(t *testing.T) { var limit uint = 2 accountTests = append(accountTests, accountTest{"Find limit", AccountFindRequest{ - Where: &createdFilter, + Where: createdFilter, Args: []interface{}{startTime, endTime}, Order: []string{"created_at"}, Limit: &limit, @@ -893,7 +893,7 @@ func TestFind(t *testing.T) { var offset uint = 3 accountTests = append(accountTests, accountTest{"Find limit, offset", AccountFindRequest{ - Where: &createdFilter, + Where: createdFilter, Args: []interface{}{startTime, endTime}, Order: []string{"created_at"}, Limit: &limit, @@ -918,10 +918,9 @@ func TestFind(t *testing.T) { expected = append(expected, &u) } - where := createdFilter + " AND (" + strings.Join(whereParts, " OR ") + ")" accountTests = append(accountTests, accountTest{"Find where", AccountFindRequest{ - Where: &where, + Where: createdFilter + " AND (" + strings.Join(whereParts, " OR ") + ")", Args: whereArgs, Order: []string{"created_at"}, }, diff --git a/internal/account/models.go b/internal/account/models.go index aa1374b..0907b36 100644 --- a/internal/account/models.go +++ b/internal/account/models.go @@ -170,7 +170,7 @@ type AccountDeleteRequest struct { // AccountFindRequest defines the possible options to search for accounts. By default // archived accounts will be excluded from response. type AccountFindRequest struct { - Where *string `json:"where" example:"name = ? and status = ?"` + Where string `json:"where" example:"name = ? and status = ?"` Args []interface{} `json:"args" swaggertype:"array,string" example:"Company Name,active"` Order []string `json:"order" example:"created_at desc"` Limit *uint `json:"limit" example:"10"` diff --git a/internal/platform/web/request.go b/internal/platform/web/request.go index 05b31e6..35ce5ad 100644 --- a/internal/platform/web/request.go +++ b/internal/platform/web/request.go @@ -51,7 +51,10 @@ func Decode(ctx context.Context, r *http.Request, val interface{}) error { } } - if err := webcontext.Validator().Struct(val); err != nil { + // Hack since we have no DB connection. + ctx = context.WithValue(ctx, webcontext.KeyTagUnique, true) + + if err := webcontext.Validator().StructCtx(ctx, val); err != nil { verr, _ := weberror.NewValidationError(ctx, err) return verr } diff --git a/internal/project/models.go b/internal/project/models.go index a84a974..ba52589 100644 --- a/internal/project/models.go +++ b/internal/project/models.go @@ -108,7 +108,7 @@ type ProjectDeleteRequest struct { // ProjectFindRequest defines the possible options to search for projects. By default // archived project will be excluded from response. type ProjectFindRequest struct { - Where *string `json:"where" example:"name = ? and status = ?"` + Where string `json:"where" example:"name = ? and status = ?"` Args []interface{} `json:"args" swaggertype:"array,string" example:"Moon Launch,active"` Order []string `json:"order" example:"created_at desc"` Limit *uint `json:"limit" example:"10"` diff --git a/internal/project/project.go b/internal/project/project.go index 7e1c53d..990fdd8 100644 --- a/internal/project/project.go +++ b/internal/project/project.go @@ -104,8 +104,8 @@ func selectQuery() *sqlbuilder.SelectBuilder { func findRequestQuery(req ProjectFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) { query := selectQuery() - if req.Where != nil { - query.Where(query.And(*req.Where)) + if req.Where != "" { + query.Where(query.And(req.Where)) } if len(req.Order) > 0 { diff --git a/internal/project/project_test.go b/internal/project/project_test.go index 29bb9b0..f097c9e 100644 --- a/internal/project/project_test.go +++ b/internal/project/project_test.go @@ -24,14 +24,14 @@ func testMain(m *testing.M) int { // TestFindRequestQuery validates findRequestQuery func TestFindRequestQuery(t *testing.T) { - where := "field1 = ? or field2 = ?" + var ( limit uint = 12 offset uint = 34 ) req := ProjectFindRequest{ - Where: &where, + Where: "field1 = ? or field2 = ?", Args: []interface{}{ "lee brown", "103 East Main St.", diff --git a/internal/user/models.go b/internal/user/models.go index fe64b74..2aca963 100644 --- a/internal/user/models.go +++ b/internal/user/models.go @@ -5,6 +5,11 @@ import ( "database/sql" "encoding/json" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" + "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" + "github.com/pkg/errors" + "github.com/sudo-suhas/symcrypto" + "strconv" + "strings" "time" "github.com/lib/pq" @@ -141,7 +146,8 @@ type UserUpdatePasswordRequest struct { // UserArchiveRequest defines the information needed to archive an user. This will archive (soft-delete) the // existing database entry. type UserArchiveRequest struct { - ID string `json:"id" validate:"required,uuid" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"` + ID string `json:"id" validate:"required,uuid" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"` + force bool } // UserRestoreRequest defines the information needed to restore an user. @@ -151,13 +157,14 @@ type UserRestoreRequest struct { // UserDeleteRequest defines the information needed to delete a user. type UserDeleteRequest struct { - ID string `json:"id" validate:"required,uuid" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"` + ID string `json:"id" validate:"required,uuid" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"` + force bool } // UserFindRequest defines the possible options to search for users. By default // archived users will be excluded from response. type UserFindRequest struct { - Where *string `json:"where" example:"name = ? and email = ?"` + Where string `json:"where" example:"name = ? and email = ?"` Args []interface{} `json:"args" swaggertype:"array,string" example:"Company Name,gabi.may@geeksinthewoods.com"` Order []string `json:"order" example:"created_at desc"` Limit *uint `json:"limit" example:"10"` @@ -185,3 +192,63 @@ type UserResetConfirmRequest struct { Password string `json:"password" validate:"required" example:"SecretString"` PasswordConfirm string `json:"password_confirm" validate:"required,eqfield=Password" example:"SecretString"` } + +// NewResetHash generates a new encrypt reset hash that is web safe for use in URLs. +func NewResetHash(ctx context.Context, secretKey, resetId, requestIp string, ttl time.Duration, now time.Time) (string, error) { + + // Generate a string that embeds additional information. + hashPts := []string{ + resetId, + strconv.Itoa(int(now.UTC().Unix())), + strconv.Itoa(int(now.UTC().Add(ttl).Unix())), + requestIp, + } + hashStr := strings.Join(hashPts, "|") + + // This returns the nonce appended with the encrypted string. + crypto, err := symcrypto.New(secretKey) + if err != nil { + return "", errors.WithStack(err) + } + encrypted, err := crypto.Encrypt(hashStr) + if err != nil { + return "", errors.WithStack(err) + } + + return encrypted, nil +} + +// ParseResetHash extracts the details encrypted in the hash string. +func ParseResetHash(ctx context.Context, secretKey string, str string, now time.Time) (*ResetHash, error) { + + crypto, err := symcrypto.New(secretKey) + if err != nil { + return nil, errors.WithStack(err) + } + hashStr, err := crypto.Decrypt(str) + if err != nil { + return nil, errors.WithStack(err) + } + hashPts := strings.Split(hashStr, "|") + + var hash ResetHash + if len(hashPts) == 4 { + hash.ResetID = hashPts[0] + hash.CreatedAt, _ = strconv.Atoi(hashPts[1]) + hash.ExpiresAt, _ = strconv.Atoi(hashPts[2]) + hash.RequestIP = hashPts[3] + } + + // Validate the hash. + err = webcontext.Validator().StructCtx(ctx, hash) + if err != nil { + return nil, err + } + + if int64(hash.ExpiresAt) < now.UTC().Unix() { + err = errors.WithMessage(ErrResetExpired, "Password reset has expired.") + return nil, err + } + + return &hash, nil +} diff --git a/internal/user/user.go b/internal/user/user.go index 96fc430..6354c19 100644 --- a/internal/user/user.go +++ b/internal/user/user.go @@ -3,9 +3,6 @@ package user import ( "context" "database/sql" - "github.com/sudo-suhas/symcrypto" - "strconv" - "strings" "time" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" @@ -185,8 +182,8 @@ func selectQuery() *sqlbuilder.SelectBuilder { // to the query. func findRequestQuery(req UserFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) { query := selectQuery() - if req.Where != nil { - query.Where(query.And(*req.Where)) + if req.Where != "" { + query.Where(query.And(req.Where)) } if len(req.Order) > 0 { query.OrderBy(req.Order...) @@ -490,8 +487,6 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserUp span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Update") defer span.Finish() - v := webcontext.Validator() - // Validation email address is unique in the database. if req.Email != nil { // Validation email address is unique in the database. @@ -505,6 +500,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserUp } // Validate the request. + v := webcontext.Validator() err := v.StructCtx(ctx, req) if err != nil { return err @@ -647,7 +643,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserA err = CanModifyUser(ctx, claims, dbConn, req.ID) if err != nil { return err - } else if claims.Subject != "" && claims.Subject == req.ID { + } else if claims.Subject != "" && claims.Subject == req.ID && !req.force { return errors.WithStack(ErrForbidden) } @@ -772,7 +768,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserDe err = CanModifyUser(ctx, claims, dbConn, req.ID) if err != nil { return err - } else if claims.Subject != "" && claims.Subject == req.ID { + } else if claims.Subject != "" && claims.Subject == req.ID && !req.force { return errors.WithStack(ErrForbidden) } @@ -899,23 +895,9 @@ func ResetPassword(ctx context.Context, dbConn *sqlx.DB, resetUrl func(string) s requestIp = vals.RequestIP } - // Generate a string that embeds additional information. - hashPts := []string{ - resetId, - strconv.Itoa(int(now.UTC().Unix())), - strconv.Itoa(int(now.UTC().Add(req.TTL).Unix())), - requestIp, - } - hashStr := strings.Join(hashPts, "|") - - // This returns the nonce appended with the encrypted string for "hello world". - crypto, err := symcrypto.New(secretKey) + encrypted, err := NewResetHash(ctx, secretKey, resetId, requestIp, req.TTL, now) if err != nil { - return "", errors.WithStack(err) - } - encrypted, err := crypto.Encrypt(hashStr) - if err != nil { - return "", errors.WithStack(err) + return "", err } data := map[string]interface{}{ @@ -946,32 +928,8 @@ func ResetConfirm(ctx context.Context, dbConn *sqlx.DB, req UserResetConfirmRequ return nil, err } - crypto, err := symcrypto.New(secretKey) + hash, err := ParseResetHash(ctx, secretKey, req.ResetHash, now) if err != nil { - return nil, errors.WithStack(err) - } - hashStr, err := crypto.Decrypt(req.ResetHash) - if err != nil { - return nil, errors.WithStack(err) - } - hashPts := strings.Split(hashStr, "|") - - var hash ResetHash - if len(hashPts) == 4 { - hash.ResetID = hashPts[0] - hash.CreatedAt, _ = strconv.Atoi(hashPts[1]) - hash.ExpiresAt, _ = strconv.Atoi(hashPts[2]) - hash.RequestIP = hashPts[3] - } - - // Validate the hash. - err = v.StructCtx(ctx, hash) - if err != nil { - return nil, err - } - - if int64(hash.ExpiresAt) < now.UTC().Unix() { - err = errors.WithMessage(ErrResetExpired, "Password reset has expired.") return nil, err } diff --git a/internal/user/user_test.go b/internal/user/user_test.go index 37dac93..2dc15fd 100644 --- a/internal/user/user_test.go +++ b/internal/user/user_test.go @@ -33,14 +33,13 @@ func testMain(m *testing.M) int { // TestFindRequestQuery validates findRequestQuery func TestFindRequestQuery(t *testing.T) { - where := "first_name = ? or email = ?" var ( limit uint = 12 offset uint = 34 ) req := UserFindRequest{ - Where: &where, + Where: "first_name = ? or email = ?", Args: []interface{}{ "lee", "lee@geeksinthewoods.com", @@ -195,7 +194,7 @@ func TestCreateValidation(t *testing.T) { FirstName: "Lee", LastName: "Brown", Email: req.Email, - Timezone: "America/Anchorage", + Timezone: nil, // Copy this fields from the result. ID: res.ID, @@ -847,7 +846,7 @@ func TestCrud(t *testing.T) { } // Archive (soft-delete) the user. - err = Archive(ctx, tt.claims(user, accountId), test.MasterDB, UserArchiveRequest{ID: user.ID}, now) + err = Archive(ctx, tt.claims(user, accountId), test.MasterDB, UserArchiveRequest{ID: user.ID, force: true}, now) if err != nil && errors.Cause(err) != tt.updateErr { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.updateErr) @@ -888,7 +887,7 @@ func TestCrud(t *testing.T) { t.Logf("\t%s\tUnarchive ok.", tests.Success) // Delete (hard-delete) the user. - err = Delete(ctx, tt.claims(user, accountId), test.MasterDB, UserDeleteRequest{ID: user.ID}) + err = Delete(ctx, tt.claims(user, accountId), test.MasterDB, UserDeleteRequest{ID: user.ID, force: true}) if err != nil && errors.Cause(err) != tt.updateErr { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.updateErr) @@ -936,7 +935,7 @@ func TestFind(t *testing.T) { type userTest struct { name string req UserFindRequest - expected []*User + expected Users error error } @@ -947,7 +946,7 @@ func TestFind(t *testing.T) { // Test sort users. userTests = append(userTests, userTest{"Find all order by created_at asc", UserFindRequest{ - Where: &createdFilter, + Where: createdFilter, Args: []interface{}{startTime, endTime}, Order: []string{"created_at"}, }, @@ -956,13 +955,13 @@ func TestFind(t *testing.T) { }) // Test reverse sorted users. - var expected []*User + var expected Users for i := len(users) - 1; i >= 0; i-- { expected = append(expected, users[i]) } userTests = append(userTests, userTest{"Find all order by created_at desc", UserFindRequest{ - Where: &createdFilter, + Where: createdFilter, Args: []interface{}{startTime, endTime}, Order: []string{"created_at desc"}, }, @@ -974,7 +973,7 @@ func TestFind(t *testing.T) { var limit uint = 2 userTests = append(userTests, userTest{"Find limit", UserFindRequest{ - Where: &createdFilter, + Where: createdFilter, Args: []interface{}{startTime, endTime}, Order: []string{"created_at"}, Limit: &limit, @@ -987,7 +986,7 @@ func TestFind(t *testing.T) { var offset uint = 3 userTests = append(userTests, userTest{"Find limit, offset", UserFindRequest{ - Where: &createdFilter, + Where: createdFilter, Args: []interface{}{startTime, endTime}, Order: []string{"created_at"}, Limit: &limit, @@ -1015,7 +1014,7 @@ func TestFind(t *testing.T) { where := createdFilter + " AND (" + strings.Join(whereParts, " OR ") + ")" userTests = append(userTests, userTest{"Find where", UserFindRequest{ - Where: &where, + Where: where, Args: whereArgs, Order: []string{"created_at"}, }, diff --git a/internal/user_account/invite/invite.go b/internal/user_account/invite/invite.go index e457721..b24c844 100644 --- a/internal/user_account/invite/invite.go +++ b/internal/user_account/invite/invite.go @@ -54,9 +54,9 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r emailUserIDs := make(map[string]string) { // Find all users without passing in claims to search all users. - where := fmt.Sprintf("email in ('%s')", strings.Join(req.Emails, "','")) users, err := user.Find(ctx, auth.Claims{}, dbConn, user.UserFindRequest{ - Where: &where, + Where: fmt.Sprintf("email in ('%s')", + strings.Join(req.Emails, "','")), }) if err != nil { return nil, err @@ -75,9 +75,10 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r args = append(args, userID) } - where := fmt.Sprintf("user_id in ('%s') and status = '%s'", strings.Join(args, "','"), user_account.UserAccountStatus_Active.String()) userAccs, err := user_account.Find(ctx, claims, dbConn, user_account.UserAccountFindRequest{ - Where: &where, + Where: fmt.Sprintf("user_id in ('%s') and status = '%s'", + strings.Join(args, "','"), + user_account.UserAccountStatus_Active.String()), }) if err != nil { return nil, err diff --git a/internal/user_account/invite/invite_test.go b/internal/user_account/invite/invite_test.go index fa4e451..ce72c2f 100644 --- a/internal/user_account/invite/invite_test.go +++ b/internal/user_account/invite/invite_test.go @@ -1,6 +1,7 @@ package invite import ( + "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "os" "strings" "testing" @@ -78,7 +79,7 @@ func TestSendUserInvites(t *testing.T) { } claims := auth.Claims{ - AccountIds: []string{a.ID}, + AccountIDs: []string{a.ID}, StandardClaims: jwt.StandardClaims{ Subject: u.ID, Audience: a.ID, @@ -148,11 +149,12 @@ func TestSendUserInvites(t *testing.T) { // Ensure validation is working by trying ResetConfirm with an empty request. { expectedErr := errors.New("Key: 'AcceptInviteRequest.invite_hash' Error:Field validation for 'invite_hash' failed on the 'required' tag\n" + + "Key: 'AcceptInviteRequest.email' Error:Field validation for 'email' failed on the 'required' tag\n" + "Key: 'AcceptInviteRequest.first_name' Error:Field validation for 'first_name' failed on the 'required' tag\n" + "Key: 'AcceptInviteRequest.last_name' Error:Field validation for 'last_name' failed on the 'required' tag\n" + "Key: 'AcceptInviteRequest.password' Error:Field validation for 'password' failed on the 'required' tag\n" + "Key: 'AcceptInviteRequest.password_confirm' Error:Field validation for 'password_confirm' failed on the 'required' tag") - err = AcceptInvite(ctx, test.MasterDB, AcceptInviteRequest{}, secretKey, now) + _, err = AcceptInvite(ctx, test.MasterDB, AcceptInviteRequest{}, secretKey, now) if err == nil { t.Logf("\t\tWant: %+v", expectedErr) t.Fatalf("\t%s\tResetConfirm failed.", tests.Failed) @@ -172,8 +174,9 @@ func TestSendUserInvites(t *testing.T) { // Ensure the TTL is enforced. { newPass := uuid.NewRandom().String() - err = AcceptInvite(ctx, test.MasterDB, AcceptInviteRequest{ + _, err = AcceptInvite(ctx, test.MasterDB, AcceptInviteRequest{ InviteHash: inviteHashes[0], + Email: inviteEmails[0], FirstName: "Foo", LastName: "Bar", Password: newPass, @@ -188,10 +191,15 @@ func TestSendUserInvites(t *testing.T) { } // Assuming we have received the email and clicked the link, we now can ensure accept works. - for _, inviteHash := range inviteHashes { + for idx, inviteHash := range inviteHashes { + type expectRes struct { + UserID string `json:"user_id" validate:"required,uuid"` + } + var res expectRes newPass := uuid.NewRandom().String() - err = AcceptInvite(ctx, test.MasterDB, AcceptInviteRequest{ + res.UserID, err = AcceptInvite(ctx, test.MasterDB, AcceptInviteRequest{ InviteHash: inviteHash, + Email: inviteEmails[idx], FirstName: "Foo", LastName: "Bar", Password: newPass, @@ -201,20 +209,29 @@ func TestSendUserInvites(t *testing.T) { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tInviteAccept failed.", tests.Failed) } + + // Validate the result. + err := webcontext.Validator().StructCtx(ctx, res) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tInviteAccept failed.", tests.Failed) + } + t.Logf("\t%s\tInviteAccept ok.", tests.Success) } // Ensure the reset hash does not work after its used. { newPass := uuid.NewRandom().String() - err = AcceptInvite(ctx, test.MasterDB, AcceptInviteRequest{ + _, err = AcceptInvite(ctx, test.MasterDB, AcceptInviteRequest{ InviteHash: inviteHashes[0], + Email: inviteEmails[0], FirstName: "Foo", LastName: "Bar", Password: newPass, PasswordConfirm: newPass, }, secretKey, now) - if errors.Cause(err) != ErrInviteUserPasswordSet { + if errors.Cause(err) != ErrUserAccountActive { t.Logf("\t\tGot : %+v", errors.Cause(err)) t.Logf("\t\tWant: %+v", ErrInviteUserPasswordSet) t.Fatalf("\t%s\tInviteAccept verify reuse failed.", tests.Failed) diff --git a/internal/user_account/invite/models.go b/internal/user_account/invite/models.go index 2536aa9..759fc57 100644 --- a/internal/user_account/invite/models.go +++ b/internal/user_account/invite/models.go @@ -2,7 +2,6 @@ package invite import ( "context" - "fmt" "strconv" "strings" "time" @@ -79,8 +78,6 @@ func ParseInviteHash(ctx context.Context, secretKey string, str string, now time } hashPts := strings.Split(hashStr, "|") - fmt.Println(hashPts) - var hash InviteHash if len(hashPts) == 5 { hash.UserID = hashPts[0] diff --git a/internal/user_account/models.go b/internal/user_account/models.go index 610c17a..9fe5bfa 100644 --- a/internal/user_account/models.go +++ b/internal/user_account/models.go @@ -32,13 +32,13 @@ type UserAccount struct { // UserAccountResponse defines the one to many relationship of an user to an account that is returned for display. type UserAccountResponse struct { //ID string `json:"id" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"` - UserID string `json:"user_id" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"` - AccountID string `json:"account_id" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"` - Roles UserAccountRoles `json:"roles" validate:"required,dive,oneof=admin user" enums:"admin,user" swaggertype:"array,string" example:"admin"` - Status web.EnumResponse `json:"status"` // Status is enum with values [active, invited, disabled]. - CreatedAt web.TimeResponse `json:"created_at"` // CreatedAt contains multiple format options for display. - UpdatedAt web.TimeResponse `json:"updated_at"` // UpdatedAt contains multiple format options for display. - ArchivedAt *web.TimeResponse `json:"archived_at,omitempty"` // ArchivedAt contains multiple format options for display. + UserID string `json:"user_id" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"` + AccountID string `json:"account_id" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"` + Roles web.EnumMultiResponse `json:"roles" validate:"required,dive,oneof=admin user" enums:"admin,user" swaggertype:"array,string" example:"admin"` + Status web.EnumResponse `json:"status"` // Status is enum with values [active, invited, disabled]. + CreatedAt web.TimeResponse `json:"created_at"` // CreatedAt contains multiple format options for display. + UpdatedAt web.TimeResponse `json:"updated_at"` // UpdatedAt contains multiple format options for display. + ArchivedAt *web.TimeResponse `json:"archived_at,omitempty"` // ArchivedAt contains multiple format options for display. } // Response transforms UserAccount and UserAccountResponse that is used for display. @@ -52,12 +52,17 @@ func (m *UserAccount) Response(ctx context.Context) *UserAccountResponse { //ID: m.ID, UserID: m.UserID, AccountID: m.AccountID, - Roles: m.Roles, Status: web.NewEnumResponse(ctx, m.Status, UserAccountStatus_ValuesInterface()...), CreatedAt: web.NewTimeResponse(ctx, m.CreatedAt), UpdatedAt: web.NewTimeResponse(ctx, m.UpdatedAt), } + var selectedRoles []interface{} + for _, r := range m.Roles { + selectedRoles = append(selectedRoles, r.String()) + } + r.Roles = web.NewEnumMultiResponse(ctx, selectedRoles, UserAccountRole_ValuesInterface()...) + if m.ArchivedAt != nil && !m.ArchivedAt.Time.IsZero() { at := web.NewTimeResponse(ctx, m.ArchivedAt.Time) r.ArchivedAt = &at @@ -139,7 +144,7 @@ type UserAccountDeleteRequest struct { // UserAccountFindRequest defines the possible options to search for users accounts. // By default archived user accounts will be excluded from response. type UserAccountFindRequest struct { - Where *string `json:"where" example:"user_id = ? and account_id = ?"` + Where string `json:"where" example:"user_id = ? and account_id = ?"` Args []interface{} `json:"args" swaggertype:"array,string" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2,c4653bf9-5978-48b7-89c5-95704aebb7e2"` Order []string `json:"order" example:"created_at desc"` Limit *uint `json:"limit" example:"10"` @@ -291,19 +296,19 @@ type User struct { // UserResponse represents someone with access to our system that is returned for display. type UserResponse struct { - ID string `json:"id" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"` - Name string `json:"name" example:"Gabi"` - FirstName string `json:"first_name" example:"Gabi"` - LastName string `json:"last_name" example:"May"` - Email string `json:"email" example:"gabi@geeksinthewoods.com"` - Timezone string `json:"timezone" example:"America/Anchorage"` - AccountID string `json:"account_id" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"` - Roles UserAccountRoles `json:"roles" validate:"required,dive,oneof=admin user" enums:"admin,user" swaggertype:"array,string" example:"admin"` - Status web.EnumResponse `json:"status"` // Status is enum with values [active, invited, disabled]. - CreatedAt web.TimeResponse `json:"created_at"` // CreatedAt contains multiple format options for display. - UpdatedAt web.TimeResponse `json:"updated_at"` // UpdatedAt contains multiple format options for display. - ArchivedAt *web.TimeResponse `json:"archived_at,omitempty"` // ArchivedAt contains multiple format options for display. - Gravatar web.GravatarResponse `json:"gravatar"` + ID string `json:"id" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"` + Name string `json:"name" example:"Gabi"` + FirstName string `json:"first_name" example:"Gabi"` + LastName string `json:"last_name" example:"May"` + Email string `json:"email" example:"gabi@geeksinthewoods.com"` + Timezone string `json:"timezone" example:"America/Anchorage"` + AccountID string `json:"account_id" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"` + Roles web.EnumMultiResponse `json:"roles" validate:"required,dive,oneof=admin user" enums:"admin,user" swaggertype:"array,string" example:"admin"` + Status web.EnumResponse `json:"status"` // Status is enum with values [active, invited, disabled]. + CreatedAt web.TimeResponse `json:"created_at"` // CreatedAt contains multiple format options for display. + UpdatedAt web.TimeResponse `json:"updated_at"` // UpdatedAt contains multiple format options for display. + ArchivedAt *web.TimeResponse `json:"archived_at,omitempty"` // ArchivedAt contains multiple format options for display. + Gravatar web.GravatarResponse `json:"gravatar"` } // Response transforms User and UserResponse that is used for display. @@ -320,13 +325,18 @@ func (m *User) Response(ctx context.Context) *UserResponse { LastName: m.LastName, Email: m.Email, AccountID: m.AccountID, - Roles: m.Roles, Status: web.NewEnumResponse(ctx, m.Status, UserAccountStatus_Values), CreatedAt: web.NewTimeResponse(ctx, m.CreatedAt), UpdatedAt: web.NewTimeResponse(ctx, m.UpdatedAt), Gravatar: web.NewGravatarResponse(ctx, m.Email), } + var selectedRoles []interface{} + for _, r := range m.Roles { + selectedRoles = append(selectedRoles, r.String()) + } + r.Roles = web.NewEnumMultiResponse(ctx, selectedRoles, UserAccountRole_ValuesInterface()...) + if m.Timezone != nil { r.Timezone = *m.Timezone } diff --git a/internal/user_account/user_account.go b/internal/user_account/user_account.go index 01e0b28..98de009 100644 --- a/internal/user_account/user_account.go +++ b/internal/user_account/user_account.go @@ -114,8 +114,8 @@ func selectQuery() *sqlbuilder.SelectBuilder { // to the query. func findRequestQuery(req UserAccountFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) { query := selectQuery() - if req.Where != nil { - query.Where(query.And(*req.Where)) + if req.Where != "" { + query.Where(query.And(req.Where)) } if len(req.Order) > 0 { query.OrderBy(req.Order...) diff --git a/internal/user_account/user_account_test.go b/internal/user_account/user_account_test.go index d6d3c73..7ff38c7 100644 --- a/internal/user_account/user_account_test.go +++ b/internal/user_account/user_account_test.go @@ -32,14 +32,14 @@ func testMain(m *testing.M) int { // TestFindRequestQuery validates findRequestQuery func TestFindRequestQuery(t *testing.T) { - where := "account_id = ? or user_id = ?" + var ( limit uint = 12 offset uint = 34 ) req := UserAccountFindRequest{ - Where: &where, + Where: "account_id = ? or user_id = ?", Args: []interface{}{ "xy7", "qwert", @@ -598,9 +598,8 @@ func TestCrud(t *testing.T) { // Find the account for the user to verify the updates where made. There should only // be one account associated with the user for this test. - ff := "user_id = ? or account_id = ?" findRes, err := Find(tests.Context(), tt.claims(userID, accountID), test.MasterDB, UserAccountFindRequest{ - Where: &ff, + Where: "user_id = ? or account_id = ?", Args: []interface{}{userID, accountID}, Order: []string{"created_at"}, }) @@ -609,7 +608,7 @@ func TestCrud(t *testing.T) { t.Logf("\t\tWant: %+v", tt.findErr) t.Fatalf("\t%s\tVerify update user account failed.", tests.Failed) } else if tt.findErr == nil { - expected := []*UserAccount{ + var expected UserAccounts = []*UserAccount{ &UserAccount{ //ID: ua.ID, UserID: ua.UserID, @@ -651,7 +650,7 @@ func TestCrud(t *testing.T) { t.Fatalf("\t%s\tVerify archive user account failed when including archived.", tests.Failed) } - expected := []*UserAccount{ + var expected UserAccounts = []*UserAccount{ &UserAccount{ //ID: ua.ID, UserID: ua.UserID, @@ -737,7 +736,7 @@ func TestFind(t *testing.T) { type accountTest struct { name string req UserAccountFindRequest - expected []*UserAccount + expected UserAccounts error error } @@ -748,7 +747,7 @@ func TestFind(t *testing.T) { // Test sort users. accountTests = append(accountTests, accountTest{"Find all order by created_at asx", UserAccountFindRequest{ - Where: &createdFilter, + Where: createdFilter, Args: []interface{}{startTime, endTime}, Order: []string{"created_at"}, }, @@ -763,7 +762,7 @@ func TestFind(t *testing.T) { } accountTests = append(accountTests, accountTest{"Find all order by created_at desc", UserAccountFindRequest{ - Where: &createdFilter, + Where: createdFilter, Args: []interface{}{startTime, endTime}, Order: []string{"created_at desc"}, }, @@ -775,7 +774,7 @@ func TestFind(t *testing.T) { var limit uint = 2 accountTests = append(accountTests, accountTest{"Find limit", UserAccountFindRequest{ - Where: &createdFilter, + Where: createdFilter, Args: []interface{}{startTime, endTime}, Order: []string{"created_at"}, Limit: &limit, @@ -788,7 +787,7 @@ func TestFind(t *testing.T) { var offset uint = 3 accountTests = append(accountTests, accountTest{"Find limit, offset", UserAccountFindRequest{ - Where: &createdFilter, + Where: createdFilter, Args: []interface{}{startTime, endTime}, Order: []string{"created_at"}, Limit: &limit, @@ -813,10 +812,10 @@ func TestFind(t *testing.T) { whereArgs = append(whereArgs, ua.AccountID) expected = append(expected, &ua) } - where := createdFilter + " AND (" + strings.Join(whereParts, " OR ") + ")" + accountTests = append(accountTests, accountTest{"Find where", UserAccountFindRequest{ - Where: &where, + Where: createdFilter + " AND (" + strings.Join(whereParts, " OR ") + ")", Args: whereArgs, Order: []string{"created_at"}, }, diff --git a/internal/user_auth/auth.go b/internal/user_auth/auth.go index 65f440d..34bf1ca 100644 --- a/internal/user_auth/auth.go +++ b/internal/user_auth/auth.go @@ -40,11 +40,18 @@ const ( // Authenticate finds a user by their email and verifies their password. On success // it returns a Token that can be used to authenticate access to the application in // the future. -func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, email, password string, expires time.Duration, now time.Time, scopes ...string) (Token, error) { +func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, req AuthenticateRequest, expires time.Duration, now time.Time, scopes ...string) (Token, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.Authenticate") defer span.Finish() - u, err := user.ReadByEmail(ctx, auth.Claims{}, dbConn, email, false) + // Validate the request. + v := webcontext.Validator() + err := v.Struct(req) + if err != nil { + return Token{}, err + } + + u, err := user.ReadByEmail(ctx, auth.Claims{}, dbConn, req.Email, false) if err != nil { if errors.Cause(err) == user.ErrNotFound { err = errors.WithStack(ErrAuthenticationFailure) @@ -55,7 +62,7 @@ func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, e } // Append the salt from the user record to the supplied password. - saltedPassword := password + u.PasswordSalt + saltedPassword := req.Password + u.PasswordSalt // Compare the provided password with the saved hash. Use the bcrypt comparison // function so it is cryptographically secure. Return authentication error for @@ -66,7 +73,7 @@ func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, e } // The user is successfully authenticated with the supplied email and password. - return generateToken(ctx, dbConn, tknGen, auth.Claims{}, u.ID, "", expires, now, scopes...) + return generateToken(ctx, dbConn, tknGen, auth.Claims{}, u.ID, req.AccountID, expires, now, scopes...) } // SwitchAccount allows users to switch between multiple accounts, this changes the claim audience. diff --git a/internal/user_auth/auth_test.go b/internal/user_auth/auth_test.go index 13e901b..db837b6 100644 --- a/internal/user_auth/auth_test.go +++ b/internal/user_auth/auth_test.go @@ -48,7 +48,11 @@ func TestAuthenticate(t *testing.T) { now := time.Now().Add(time.Hour * -1) // Try to authenticate an invalid user. - _, err := Authenticate(ctx, test.MasterDB, tknGen, "doesnotexist@gmail.com", "xy7", time.Hour, now) + _, err := Authenticate(ctx, test.MasterDB, tknGen, + AuthenticateRequest{ + Email: "doesnotexist@gmail.com", + Password: "xy7", + }, time.Hour, now) if errors.Cause(err) != ErrAuthenticationFailure { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", ErrAuthenticationFailure) @@ -88,7 +92,12 @@ func TestAuthenticate(t *testing.T) { now = now.Add(time.Minute * 5) // Try to authenticate valid user with invalid password. - _, err = Authenticate(ctx, test.MasterDB, tknGen, usrAcc.User.Email, "xy7", time.Hour, now) + _, err = Authenticate(ctx, test.MasterDB, tknGen, + AuthenticateRequest{ + Email: usrAcc.User.Email, + Password: "xy7", + }, + time.Hour, now) if errors.Cause(err) != ErrAuthenticationFailure { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", ErrAuthenticationFailure) @@ -97,7 +106,11 @@ func TestAuthenticate(t *testing.T) { t.Logf("\t%s\tAuthenticate user w/invalid password ok.", tests.Success) // Verify that the user can be authenticated with the created user. - tkn1, err := Authenticate(ctx, test.MasterDB, tknGen, usrAcc.User.Email, usrAcc.User.Password, time.Hour, now) + tkn1, err := Authenticate(ctx, test.MasterDB, tknGen, + AuthenticateRequest{ + Email: usrAcc.User.Email, + Password: usrAcc.User.Password, + }, time.Hour, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tAuthenticate user failed.", tests.Failed) @@ -170,7 +183,11 @@ func TestUserUpdatePassword(t *testing.T) { t.Logf("\t%s\tCreate user account ok.", tests.Success) // Verify that the user can be authenticated with the created user. - _, err = Authenticate(ctx, test.MasterDB, tknGen, usrAcc.User.Email, usrAcc.User.Password, time.Hour, now) + _, err = Authenticate(ctx, test.MasterDB, tknGen, + AuthenticateRequest{ + Email: usrAcc.User.Email, + Password: usrAcc.User.Password, + }, time.Hour, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tAuthenticate failed.", tests.Failed) @@ -190,7 +207,11 @@ func TestUserUpdatePassword(t *testing.T) { t.Logf("\t%s\tUpdatePassword ok.", tests.Success) // Verify that the user can be authenticated with the updated password. - _, err = Authenticate(ctx, test.MasterDB, tknGen, usrAcc.User.Email, newPass, time.Hour, now) + _, err = Authenticate(ctx, test.MasterDB, tknGen, + AuthenticateRequest{ + Email: usrAcc.User.Email, + Password: newPass, + }, time.Hour, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tAuthenticate failed.", tests.Failed) @@ -257,7 +278,11 @@ func TestUserResetPassword(t *testing.T) { t.Logf("\t%s\tResetConfirm ok.", tests.Success) // Verify that the user can be authenticated with the updated password. - _, err = Authenticate(ctx, test.MasterDB, tknGen, usrAcc.User.Email, newPass, time.Hour, now) + _, err = Authenticate(ctx, test.MasterDB, tknGen, + AuthenticateRequest{ + Email: usrAcc.User.Email, + Password: newPass, + }, time.Hour, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tAuthenticate failed.", tests.Failed) @@ -456,7 +481,11 @@ func TestSwitchAccount(t *testing.T) { { // Verify that the user can be authenticated with the created user. var claims1 auth.Claims - tkn1, err := Authenticate(ctx, test.MasterDB, tknGen, authTest.root.User.Email, authTest.root.User.Password, time.Hour, now) + tkn1, err := Authenticate(ctx, test.MasterDB, tknGen, + AuthenticateRequest{ + Email: authTest.root.User.Email, + Password: authTest.root.User.Password, + }, time.Hour, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tAuthenticate user failed.", tests.Failed) @@ -856,7 +885,11 @@ func TestVirtualLogin(t *testing.T) { { // Verify that the user can be authenticated with the created user. var claims1 auth.Claims - tkn1, err := Authenticate(ctx, test.MasterDB, tknGen, authTest.root.User.Email, authTest.root.User.Password, time.Hour, now) + tkn1, err := Authenticate(ctx, test.MasterDB, tknGen, + AuthenticateRequest{ + Email: authTest.root.User.Email, + Password: authTest.root.User.Password, + }, time.Hour, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tAuthenticate user failed.", tests.Failed) diff --git a/internal/user_auth/models.go b/internal/user_auth/models.go index 31bd1ef..9cf61a0 100644 --- a/internal/user_auth/models.go +++ b/internal/user_auth/models.go @@ -8,8 +8,9 @@ import ( // AuthenticateRequest defines what information is required to authenticate a user. type AuthenticateRequest struct { - Email string `json:"email" validate:"required,email" example:"gabi.may@geeksinthewoods.com"` - Password string `json:"password" validate:"required" example:"NeverTellSecret"` + Email string `json:"email" validate:"required,email" example:"gabi.may@geeksinthewoods.com"` + Password string `json:"password" validate:"required" example:"NeverTellSecret"` + AccountID string `json:"account_id" validate:"omitempty,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"` } // Token is the payload we deliver to users when they authenticate.