1
0
mirror of https://github.com/raseels-repos/golang-saas-starter-kit.git synced 2025-06-15 00:15:15 +02:00

fix where, auth use request arg

This commit is contained in:
Lee Brown
2019-08-05 17:12:28 -08:00
parent 0471af921c
commit 4c25d50c76
39 changed files with 532 additions and 347 deletions

View File

@ -55,7 +55,7 @@ func (p *Project) Find(ctx context.Context, w http.ResponseWriter, r *http.Reque
if err != nil { if err != nil {
return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusBadRequest)) return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusBadRequest))
} }
req.Where = &where req.Where = where
req.Args = args req.Args = args
} }

View File

@ -42,6 +42,10 @@ func (c *Signup) Signup(ctx context.Context, w http.ResponseWriter, r *http.Requ
// Claims are optional as authentication is not required ATM for this method. // Claims are optional as authentication is not required ATM for this method.
claims, _ := auth.ClaimsFromContext(ctx) claims, _ := auth.ClaimsFromContext(ctx)
// Hack to allow custom validation to be handled by business logic package.
ctx = context.WithValue(ctx, signup.KeyTagUniqueEmail, true)
ctx = context.WithValue(ctx, signup.KeyTagUniqueName, true)
var req signup.SignupRequest var req signup.SignupRequest
if err := web.Decode(ctx, r, &req); err != nil { if err := web.Decode(ctx, r, &req); err != nil {
if _, ok := errors.Cause(err).(*weberror.Error); !ok { if _, ok := errors.Cause(err).(*weberror.Error); !ok {

View File

@ -60,7 +60,7 @@ func (u *User) Find(ctx context.Context, w http.ResponseWriter, r *http.Request,
if err != nil { if err != nil {
return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusBadRequest)) return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusBadRequest))
} }
req.Where = &where req.Where = where
req.Args = args req.Args = args
} }
@ -442,7 +442,9 @@ func (u *User) SwitchAccount(ctx context.Context, w http.ResponseWriter, r *http
return err return err
} }
tkn, err := user_auth.SwitchAccount(ctx, u.MasterDB, u.TokenGenerator, claims, params["account_id"], sessionTtl, v.Now) tkn, err := user_auth.SwitchAccount(ctx, u.MasterDB, u.TokenGenerator, claims, user_auth.SwitchAccountRequest{
AccountID: params["account_id"],
}, sessionTtl, v.Now)
if err != nil { if err != nil {
cause := errors.Cause(err) cause := errors.Cause(err)
switch cause { switch cause {
@ -486,10 +488,16 @@ func (u *User) Token(ctx context.Context, w http.ResponseWriter, r *http.Request
return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusUnauthorized)) return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusUnauthorized))
} }
accountID := r.URL.Query().Get("account_id")
// Optional to include scope. // Optional to include scope.
scope := r.URL.Query().Get("scope") scope := r.URL.Query().Get("scope")
tkn, err := user_auth.Authenticate(ctx, u.MasterDB, u.TokenGenerator, email, pass, sessionTtl, v.Now, scope) tkn, err := user_auth.Authenticate(ctx, u.MasterDB, u.TokenGenerator, user_auth.AuthenticateRequest{
Email: email,
Password: pass,
AccountID: accountID,
}, sessionTtl, v.Now, scope)
if err != nil { if err != nil {
cause := errors.Cause(err) cause := errors.Cause(err)
switch cause { switch cause {
@ -505,30 +513,5 @@ func (u *User) Token(ctx context.Context, w http.ResponseWriter, r *http.Request
} }
} }
accountID := r.URL.Query().Get("account_id")
if accountID != "" && accountID != tkn.AccountID {
claims, err := u.TokenGenerator.ParseClaims(tkn.AccessToken)
if err != nil {
return err
}
tkn, err = user_auth.SwitchAccount(ctx, u.MasterDB, u.TokenGenerator, claims, accountID, sessionTtl, v.Now)
if err != nil {
cause := errors.Cause(err)
switch cause {
case user_auth.ErrAuthenticationFailure:
return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusUnauthorized))
default:
_, ok := cause.(validator.ValidationErrors)
if ok {
return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusBadRequest))
}
return errors.Wrap(err, "switch account")
}
}
}
return web.RespondJson(ctx, w, tkn, http.StatusOK) return web.RespondJson(ctx, w, tkn, http.StatusOK)
} }

View File

@ -55,7 +55,7 @@ func (u *UserAccount) Find(ctx context.Context, w http.ResponseWriter, r *http.R
if err != nil { if err != nil {
return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusBadRequest)) return web.RespondJsonError(ctx, w, weberror.NewError(ctx, err, http.StatusBadRequest))
} }
req.Where = &where req.Where = where
req.Args = args req.Args = args
} }

View File

@ -102,9 +102,13 @@ func TestAccountCRUDAdmin(t *testing.T) {
"address1": tr.Account.Address1, "address1": tr.Account.Address1,
"city": tr.Account.City, "city": tr.Account.City,
"status": map[string]interface{}{ "status": map[string]interface{}{
"value": "active", "value": "active",
"title": "Active", "title": "Active",
"options": []map[string]interface{}{{"selected": false, "title": "[Active Pending Disabled]", "value": "[active pending disabled]"}}, "options": []map[string]interface{}{
{"selected": true, "title": "Active", "value": "active"},
{"selected": false, "title": "Pending", "value": "pending"},
{"selected": false, "title": "Disabled", "value": "disabled"},
},
}, },
"signup_user_id": &tr.Account.SignupUserID.String, "signup_user_id": &tr.Account.SignupUserID.String,
} }
@ -322,9 +326,13 @@ func TestAccountCRUDUser(t *testing.T) {
"address1": tr.Account.Address1, "address1": tr.Account.Address1,
"city": tr.Account.City, "city": tr.Account.City,
"status": map[string]interface{}{ "status": map[string]interface{}{
"value": "active", "value": "active",
"title": "Active", "title": "Active",
"options": []map[string]interface{}{{"selected": false, "title": "[Active Pending Disabled]", "value": "[active pending disabled]"}}, "options": []map[string]interface{}{
{"selected": true, "title": "Active", "value": "active"},
{"selected": false, "title": "Pending", "value": "pending"},
{"selected": false, "title": "Disabled", "value": "disabled"},
},
}, },
"signup_user_id": &tr.Account.SignupUserID.String, "signup_user_id": &tr.Account.SignupUserID.String,
} }

View File

@ -79,7 +79,7 @@ func TestProjectCRUDAdmin(t *testing.T) {
"updated_at": web.NewTimeResponse(ctx, actual.UpdatedAt.Value), "updated_at": web.NewTimeResponse(ctx, actual.UpdatedAt.Value),
"id": actual.ID, "id": actual.ID,
"account_id": req.AccountID, "account_id": req.AccountID,
"status": web.NewEnumResponse(ctx, "active", project.ProjectStatus_Values), "status": web.NewEnumResponse(ctx, "active", project.ProjectStatus_ValuesInterface()...),
"created_at": web.NewTimeResponse(ctx, actual.CreatedAt.Value), "created_at": web.NewTimeResponse(ctx, actual.CreatedAt.Value),
"name": req.Name, "name": req.Name,
} }

View File

@ -56,7 +56,10 @@ func newMockSignup() mockSignup {
} }
expires := time.Now().UTC().Sub(s.User.CreatedAt) + time.Hour expires := time.Now().UTC().Sub(s.User.CreatedAt) + time.Hour
tkn, err := user_auth.Authenticate(tests.Context(), test.MasterDB, authenticator, req.User.Email, req.User.Password, expires, now) tkn, err := user_auth.Authenticate(tests.Context(), test.MasterDB, authenticator, user_auth.AuthenticateRequest{
Email: req.User.Email,
Password: req.User.Password,
}, expires, now)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -139,9 +142,13 @@ func TestSignup(t *testing.T) {
"address1": req.Account.Address1, "address1": req.Account.Address1,
"city": req.Account.City, "city": req.Account.City,
"status": map[string]interface{}{ "status": map[string]interface{}{
"value": "active", "value": "active",
"title": "Active", "title": "Active",
"options": []map[string]interface{}{{"selected": false, "title": "[Active Pending Disabled]", "value": "[active pending disabled]"}}, "options": []map[string]interface{}{
{"selected": true, "title": "Active", "value": "active"},
{"selected": false, "title": "Pending", "value": "pending"},
{"selected": false, "title": "Disabled", "value": "disabled"},
},
}, },
"signup_user_id": &actual.Account.SignupUserID, "signup_user_id": &actual.Account.SignupUserID,
}, },

View File

@ -95,7 +95,10 @@ func testMain(m *testing.M) int {
} }
expires := time.Now().UTC().Sub(signup1.User.CreatedAt) + time.Hour expires := time.Now().UTC().Sub(signup1.User.CreatedAt) + time.Hour
adminTkn, err := user_auth.Authenticate(tests.Context(), test.MasterDB, authenticator, signupReq1.User.Email, signupReq1.User.Password, expires, now) adminTkn, err := user_auth.Authenticate(tests.Context(), test.MasterDB, authenticator, user_auth.AuthenticateRequest{
Email: signupReq1.User.Email,
Password: signupReq1.User.Password,
}, expires, now)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -146,7 +149,10 @@ func testMain(m *testing.M) int {
panic(err) panic(err)
} }
userTkn, err := user_auth.Authenticate(tests.Context(), test.MasterDB, authenticator, usr.Email, userReq.Password, expires, now) userTkn, err := user_auth.Authenticate(tests.Context(), test.MasterDB, authenticator, user_auth.AuthenticateRequest{
Email: usr.Email,
Password: userReq.Password,
}, expires, now)
if err != nil { if err != nil {
panic(err) panic(err)
} }

View File

@ -89,13 +89,18 @@ func TestUserAccountCRUDAdmin(t *testing.T) {
} }
created = actual created = actual
var roles []interface{}
for _, r := range req.Roles {
roles = append(roles, r)
}
expectedMap := map[string]interface{}{ expectedMap := map[string]interface{}{
"updated_at": web.NewTimeResponse(ctx, actual.UpdatedAt.Value), "updated_at": web.NewTimeResponse(ctx, actual.UpdatedAt.Value),
//"id": actual.ID, //"id": actual.ID,
"account_id": req.AccountID, "account_id": req.AccountID,
"user_id": req.UserID, "user_id": req.UserID,
"status": web.NewEnumResponse(ctx, "active", user_account.UserAccountStatus_Values), "status": web.NewEnumResponse(ctx, "active", user_account.UserAccountStatus_ValuesInterface()...),
"roles": req.Roles, "roles": web.NewEnumMultiResponse(ctx, roles, user_account.UserAccountRole_ValuesInterface()...),
"created_at": web.NewTimeResponse(ctx, actual.CreatedAt.Value), "created_at": web.NewTimeResponse(ctx, actual.CreatedAt.Value),
} }

View File

@ -1419,7 +1419,7 @@ func TestUserToken(t *testing.T) {
// Test user token with invalid email. // Test user token with invalid email.
{ {
expectedStatus := http.StatusUnauthorized expectedStatus := http.StatusBadRequest
rt := requestTest{ rt := requestTest{
fmt.Sprintf("Token %d using invalid email", expectedStatus), fmt.Sprintf("Token %d using invalid email", expectedStatus),
@ -1434,7 +1434,9 @@ func TestUserToken(t *testing.T) {
t.Logf("\tTest: %s - %s %s", rt.name, rt.method, rt.url) t.Logf("\tTest: %s - %s %s", rt.name, rt.method, rt.url)
r := httptest.NewRequest(rt.method, rt.url, nil) r := httptest.NewRequest(rt.method, rt.url, nil)
r.SetBasicAuth("invalid email.com", "some random password")
invalidEmail := "invalid email.com"
r.SetBasicAuth(invalidEmail, "some random password")
w := httptest.NewRecorder() w := httptest.NewRecorder()
r.Header.Set("Content-Type", web.MIMEApplicationJSONCharsetUTF8) r.Header.Set("Content-Type", web.MIMEApplicationJSONCharsetUTF8)
@ -1456,8 +1458,17 @@ func TestUserToken(t *testing.T) {
expected := weberror.ErrorResponse{ expected := weberror.ErrorResponse{
StatusCode: expectedStatus, StatusCode: expectedStatus,
Error: http.StatusText(expectedStatus), Error: "Field validation error",
Details: user_auth.ErrAuthenticationFailure.Error(), Fields: []weberror.FieldError{
{
Field: "email",
Value: invalidEmail,
Tag: "email",
Error: "email must be a valid email address",
Display: "email must be a valid email address",
},
},
Details: actual.Details,
StackTrace: actual.StackTrace, StackTrace: actual.StackTrace,
} }

View File

@ -110,9 +110,8 @@ func (h *Projects) Index(ctx context.Context, w http.ResponseWriter, r *http.Req
} }
loadFunc := func(ctx context.Context, sorting string, fields []datatable.DisplayField) (resp [][]datatable.ColumnValue, err error) { loadFunc := func(ctx context.Context, sorting string, fields []datatable.DisplayField) (resp [][]datatable.ColumnValue, err error) {
whereFilter := "account_id = ?"
res, err := project.Find(ctx, claims, h.MasterDB, project.ProjectFindRequest{ res, err := project.Find(ctx, claims, h.MasterDB, project.ProjectFindRequest{
Where: &whereFilter, Where: "account_id = ?",
Args: []interface{}{claims.Audience}, Args: []interface{}{claims.Audience},
Order: strings.Split(sorting, ","), Order: strings.Split(sorting, ","),
}) })

View File

@ -209,6 +209,8 @@ func (h *User) ResetPassword(ctx context.Context, w http.ResponseWriter, r *http
// ResetConfirm handles changing a users password after they have clicked on the link emailed. // ResetConfirm handles changing a users password after they have clicked on the link emailed.
func (h *User) ResetConfirm(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { func (h *User) ResetConfirm(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
resetHash := params["hash"]
ctxValues, err := webcontext.ContextValues(ctx) ctxValues, err := webcontext.ContextValues(ctx)
if err != nil { if err != nil {
return err return err
@ -217,31 +219,36 @@ func (h *User) ResetConfirm(ctx context.Context, w http.ResponseWriter, r *http.
// //
req := new(user.UserResetConfirmRequest) req := new(user.UserResetConfirmRequest)
data := make(map[string]interface{}) data := make(map[string]interface{})
f := func() error { f := func() (bool, error) {
if r.Method == http.MethodPost { if r.Method == http.MethodPost {
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
return err return false, err
} }
decoder := schema.NewDecoder() decoder := schema.NewDecoder()
if err := decoder.Decode(req, r.PostForm); err != nil { if err := decoder.Decode(req, r.PostForm); err != nil {
return err return false, err
} }
// Append the query param value to the request. // Append the query param value to the request.
req.ResetHash = params["hash"] req.ResetHash = resetHash
u, err := user.ResetConfirm(ctx, h.MasterDB, *req, h.SecretKey, ctxValues.Now) u, err := user.ResetConfirm(ctx, h.MasterDB, *req, h.SecretKey, ctxValues.Now)
if err != nil { if err != nil {
switch errors.Cause(err) { switch errors.Cause(err) {
case user.ErrResetExpired:
webcontext.SessionFlashError(ctx,
"Reset Expired",
"The reset has expired.")
return false, nil
default: default:
if verr, ok := weberror.NewValidationError(ctx, err); ok { if verr, ok := weberror.NewValidationError(ctx, err); ok {
data["validationErrors"] = verr.(*weberror.Error) data["validationErrors"] = verr.(*weberror.Error)
return nil return false, nil
} else { } else {
return err return false, err
} }
} }
} }
@ -249,34 +256,51 @@ func (h *User) ResetConfirm(ctx context.Context, w http.ResponseWriter, r *http.
// Authenticated the user. Probably should use the default session TTL from UserLogin. // Authenticated the user. Probably should use the default session TTL from UserLogin.
token, err := user_auth.Authenticate(ctx, h.MasterDB, h.Authenticator, u.Email, req.Password, time.Hour, ctxValues.Now) token, err := user_auth.Authenticate(ctx, h.MasterDB, h.Authenticator, u.Email, req.Password, time.Hour, ctxValues.Now)
if err != nil { if err != nil {
switch errors.Cause(err) { if verr, ok := weberror.NewValidationError(ctx, err); ok {
case account.ErrForbidden: data["validationErrors"] = verr.(*weberror.Error)
return web.RespondError(ctx, w, weberror.NewError(ctx, err, http.StatusForbidden)) return false, nil
default: } else {
if verr, ok := weberror.NewValidationError(ctx, err); ok { return false, err
data["validationErrors"] = verr.(*weberror.Error)
return nil
} else {
return err
}
} }
} }
// Add the token to the users session. // Add the token to the users session.
err = handleSessionToken(ctx, h.MasterDB, w, r, token) err = handleSessionToken(ctx, h.MasterDB, w, r, token)
if err != nil { if err != nil {
return err return false, err
} }
// Redirect the user to the dashboard. // Redirect the user to the dashboard.
http.Redirect(w, r, "/", http.StatusFound) http.Redirect(w, r, "/", http.StatusFound)
return true, nil
} }
return nil _, err = user.ParseResetHash(ctx, h.SecretKey, resetHash, ctxValues.Now)
if err != nil {
switch errors.Cause(err) {
case user.ErrResetExpired:
webcontext.SessionFlashError(ctx,
"Reset Expired",
"The reset has expired.")
return false, nil
default:
if verr, ok := weberror.NewValidationError(ctx, err); ok {
data["validationErrors"] = verr.(*weberror.Error)
return false, nil
} else {
return false, err
}
}
}
return false, nil
} }
if err := f(); err != nil { end, err := f()
if err != nil {
return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8) return web.RenderError(ctx, w, r, err, h.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8)
} else if end {
return nil
} }
data["form"] = req data["form"] = req
@ -572,9 +596,8 @@ func (h *User) VirtualLogin(ctx context.Context, w http.ResponseWriter, r *http.
return nil return nil
} }
usrAccFilter := "account_id = ?"
usrAccs, err := user_account.Find(ctx, claims, h.MasterDB, user_account.UserAccountFindRequest{ usrAccs, err := user_account.Find(ctx, claims, h.MasterDB, user_account.UserAccountFindRequest{
Where: &usrAccFilter, Where: "account_id = ?",
Args: []interface{}{claims.Audience}, Args: []interface{}{claims.Audience},
}) })
if err != nil { if err != nil {
@ -597,10 +620,10 @@ func (h *User) VirtualLogin(ctx context.Context, w http.ResponseWriter, r *http.
userPhs = append(userPhs, "?") userPhs = append(userPhs, "?")
} }
usrFilter := fmt.Sprintf("id IN (%s)", strings.Join(userPhs, ", "))
users, err := user.Find(ctx, claims, h.MasterDB, user.UserFindRequest{ users, err := user.Find(ctx, claims, h.MasterDB, user.UserFindRequest{
Where: &usrFilter, Where: fmt.Sprintf("id IN (%s)",
Args: userIDs, strings.Join(userPhs, ", ")),
Args: userIDs,
}) })
if err != nil { if err != nil {
return err return err

View File

@ -52,12 +52,18 @@ func urlUsersUpdate(userID string) string {
return fmt.Sprintf("/users/%s/update", userID) return fmt.Sprintf("/users/%s/update", userID)
} }
// UserLoginRequest extends the AuthenicateRequest with the RememberMe flag. // UserCreateRequest extends the UserCreateRequest with a list of roles.
type UserCreateRequest struct { type UserCreateRequest struct {
user.UserCreateRequest user.UserCreateRequest
Roles user_account.UserAccountRoles `json:"roles" validate:"required,dive,oneof=admin user" enums:"admin,user" swaggertype:"array,string" example:"admin"` Roles user_account.UserAccountRoles `json:"roles" validate:"required,dive,oneof=admin user" enums:"admin,user" swaggertype:"array,string" example:"admin"`
} }
// UserUpdateRequest extends the UserUpdateRequest with a list of roles.
type UserUpdateRequest struct {
user.UserUpdateRequest
Roles user_account.UserAccountRoles `json:"roles" validate:"required,dive,oneof=admin user" enums:"admin,user" swaggertype:"array,string" example:"admin"`
}
// Index handles listing all the users for the current account. // Index handles listing all the users for the current account.
func (h *Users) Index(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { func (h *Users) Index(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
@ -198,8 +204,6 @@ func (h *Users) Create(ctx context.Context, w http.ResponseWriter, r *http.Reque
} }
decoder := schema.NewDecoder() decoder := schema.NewDecoder()
decoder.IgnoreUnknownKeys(true)
if err := decoder.Decode(req, r.PostForm); err != nil { if err := decoder.Decode(req, r.PostForm); err != nil {
return false, err return false, err
} }
@ -279,11 +283,11 @@ func (h *Users) Create(ctx context.Context, w http.ResponseWriter, r *http.Reque
return err return err
} }
var roleValues []interface{} var selectedRoles []interface{}
for _, v := range user_account.UserAccountRole_Values { for _, r := range req.Roles {
roleValues = append(roleValues, string(v)) selectedRoles = append(selectedRoles, r.String())
} }
data["roles"] = web.NewEnumResponse(ctx, nil, roleValues...) data["roles"] = web.NewEnumMultiResponse(ctx, selectedRoles, user_account.UserAccountRole_ValuesInterface()...)
data["form"] = req data["form"] = req
@ -389,7 +393,7 @@ func (h *Users) Update(ctx context.Context, w http.ResponseWriter, r *http.Reque
} }
// //
req := new(user.UserUpdateRequest) req := new(UserUpdateRequest)
data := make(map[string]interface{}) data := make(map[string]interface{})
f := func() (bool, error) { f := func() (bool, error) {
if r.Method == http.MethodPost { if r.Method == http.MethodPost {
@ -400,13 +404,27 @@ func (h *Users) Update(ctx context.Context, w http.ResponseWriter, r *http.Reque
decoder := schema.NewDecoder() decoder := schema.NewDecoder()
decoder.IgnoreUnknownKeys(true) decoder.IgnoreUnknownKeys(true)
if err := decoder.Decode(req, r.PostForm); err != nil { if err := decoder.Decode(req, r.PostForm); err != nil {
return false, err return false, err
} }
req.ID = userID req.ID = userID
err = user.Update(ctx, claims, h.MasterDB, *req, ctxValues.Now) // Bypass the uniq check on email here for the moment, it will be caught before the user_account is
// created by user.Create.
ctx = context.WithValue(ctx, webcontext.KeyTagUnique, true)
// Validate the request.
err = webcontext.Validator().StructCtx(ctx, req)
if err != nil {
if verr, ok := weberror.NewValidationError(ctx, err); ok {
data["validationErrors"] = verr.(*weberror.Error)
return false, nil
} else {
return false, err
}
}
err = user.Update(ctx, claims, h.MasterDB, req.UserUpdateRequest, ctxValues.Now)
if err != nil { if err != nil {
switch errors.Cause(err) { switch errors.Cause(err) {
default: default:
@ -419,6 +437,25 @@ func (h *Users) Update(ctx context.Context, w http.ResponseWriter, r *http.Reque
} }
} }
if req.Roles != nil {
err = user_account.Update(ctx, claims, h.MasterDB, user_account.UserAccountUpdateRequest{
UserID: userID,
AccountID: claims.Audience,
Roles: &req.Roles,
}, ctxValues.Now)
if err != nil {
switch errors.Cause(err) {
default:
if verr, ok := weberror.NewValidationError(ctx, err); ok {
data["validationErrors"] = verr.(*weberror.Error)
return false, nil
} else {
return false, err
}
}
}
}
if r.PostForm.Get("Password") != "" { if r.PostForm.Get("Password") != "" {
pwdReq := new(user.UserUpdatePasswordRequest) pwdReq := new(user.UserUpdatePasswordRequest)
@ -469,11 +506,20 @@ func (h *Users) Update(ctx context.Context, w http.ResponseWriter, r *http.Reque
return err return err
} }
usrAcc, err := user_account.Read(ctx, claims, h.MasterDB, user_account.UserAccountReadRequest{
UserID: userID,
AccountID: claims.Audience,
})
if err != nil {
return err
}
if req.ID == "" { if req.ID == "" {
req.FirstName = &usr.FirstName req.FirstName = &usr.FirstName
req.LastName = &usr.LastName req.LastName = &usr.LastName
req.Email = &usr.Email req.Email = &usr.Email
req.Timezone = usr.Timezone req.Timezone = usr.Timezone
req.Roles = usrAcc.Roles
} }
data["user"] = usr.Response(ctx) data["user"] = usr.Response(ctx)
@ -483,9 +529,15 @@ func (h *Users) Update(ctx context.Context, w http.ResponseWriter, r *http.Reque
return err return err
} }
var selectedRoles []interface{}
for _, r := range req.Roles {
selectedRoles = append(selectedRoles, r.String())
}
data["roles"] = web.NewEnumMultiResponse(ctx, selectedRoles, user_account.UserAccountRole_ValuesInterface()...)
data["form"] = req data["form"] = req
if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(user.UserUpdateRequest{})); ok { if verr, ok := weberror.NewValidationError(ctx, webcontext.Validator().Struct(UserUpdateRequest{})); ok {
data["userValidationDefaults"] = verr.(*weberror.Error) data["userValidationDefaults"] = verr.(*weberror.Error)
} }

View File

@ -28,7 +28,6 @@
{{ template "validation-error" . }} {{ template "validation-error" . }}
<form class="user" method="post" novalidate> <form class="user" method="post" novalidate>
<input type="hidden" name="ResetHash" value="{{ $.form.ResetHash }}" />
<div class="form-group row"> <div class="form-group row">
<div class="col-sm-6 mb-3 mb-sm-0"> <div class="col-sm-6 mb-3 mb-sm-0">
<input type="password" class="form-control form-control-user {{ ValidationFieldClass $.validationErrors "Password" }}" name="Password" value="{{ $.form.Password }}" placeholder="Password" required> <input type="password" class="form-control form-control-user {{ ValidationFieldClass $.validationErrors "Password" }}" name="Password" value="{{ $.form.Password }}" placeholder="Password" required>

View File

@ -9,90 +9,83 @@
</div> </div>
<form class="user" method="post" novalidate> <form class="user" method="post" novalidate>
<div class="card shadow"> <div class="card shadow">
<div class="card-body"> <div class="card-body">
<div class="row"> <div class="row">
<div class="col-md-6"> <div class="col-md-6">
<div class="form-group"> <div class="form-group">
<label for="inputFirstName">First Name</label> <label for="inputFirstName">First Name</label>
<input type="text" class="form-control {{ ValidationFieldClass $.validationErrors "UserCreateRequest.FirstName" }}" placeholder="enter first name" name="FirstName" value="{{ .form.FirstName }}" required> <input type="text"
{{template "invalid-feedback" dict "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors "fieldName" "UserCreateRequest.FirstName" }} class="form-control {{ ValidationFieldClass $.validationErrors "UserCreateRequest.FirstName" }}"
placeholder="enter first name" name="FirstName" value="{{ .form.FirstName }}" required>
{{template "invalid-feedback" dict "fieldName" "UserCreateRequest.FirstName" "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors }}
</div> </div>
<div class="form-group"> <div class="form-group">
<label for="inputLastName">Last Name</label> <label for="inputLastName">Last Name</label>
<input type="text" class="form-control {{ ValidationFieldClass $.validationErrors "UserCreateRequest.LastName" }}" placeholder="enter last name" name="LastName" value="{{ .form.LastName }}" required> <input type="text"
{{template "invalid-feedback" dict "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors "fieldName" "UserCreateRequest.LastName" }} class="form-control {{ ValidationFieldClass $.validationErrors "UserCreateRequest.LastName" }}"
placeholder="enter last name" name="LastName" value="{{ .form.LastName }}" required>
{{template "invalid-feedback" dict "fieldName" "UserCreateRequest.LastName" "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors }}
</div> </div>
<div class="form-group"> <div class="form-group">
<label for="inputEmail">Email</label> <label for="inputEmail">Email</label>
<input type="text" class="form-control {{ ValidationFieldClass $.validationErrors "UserCreateRequest.Email" }}" placeholder="enter email" name="Email" value="{{ .form.Email }}" required> <input type="text" class="form-control {{ ValidationFieldClass $.validationErrors "UserCreateRequest.Email" }}"
{{template "invalid-feedback" dict "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors "fieldName" "UserCreateRequest.Email" }} placeholder="enter email" name="Email" value="{{ .form.Email }}" required>
{{template "invalid-feedback" dict "fieldName" "UserCreateRequest.Email" "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors }}
</div> </div>
<div class="form-group"> <div class="form-group">
<label for="inputTimezone">Timezone</label> <label for="inputTimezone">Timezone</label>
<select class="form-control {{ ValidationFieldClass $.validationErrors "UserCreateRequest.Timezone" }}" id="inputTimezone" name="Timezone"> <select class="form-control {{ ValidationFieldClass $.validationErrors "UserCreateRequest.Timezone" }}"
id="inputTimezone" name="Timezone">
<option value="">Not set</option> <option value="">Not set</option>
{{ range $idx, $t := .timezones }} {{ range $idx, $t := .timezones }}
<option value="{{ $t }}" {{ if CmpString $t $.form.Timezone }}selected="selected"{{ end }}>{{ $t }}</option> <option value="{{ $t }}" {{ if CmpString $t $.form.Timezone }}selected="selected"{{ end }}>{{ $t }}</option>
{{ end }} {{ end }}
</select> </select>
{{template "invalid-feedback" dict "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors "fieldName" "UserCreateRequest.Timezone" }} {{template "invalid-feedback" dict "fieldName" "UserCreateRequest.Timezone" "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors }}
</div> </div>
<div class="form-group"> <div class="form-group">
<label for="inputPassword">Password</label> <label for="inputPassword">Password</label>
<input type="password" class="form-control {{ ValidationFieldClass $.validationErrors "UserCreateRequest.Password" }}" id="inputPassword" placeholder="" name="Password" value="{{ .form.Password }}" required> <input type="password"
<span class="help-block "><small><a a href="javascript:void(0)" id="btnGeneratePassword"><i class="fas fa-random mr-1"></i>Generate random password </a></small></span> class="form-control {{ ValidationFieldClass $.validationErrors "UserCreateRequest.Password" }}"
{{template "invalid-feedback" dict "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors "fieldName" "UserCreateRequest.Password" }} id="inputPassword" placeholder="" name="Password" value="{{ .form.Password }}" required>
<span class="help-block "><small>
<a a href="javascript:void(0)" id="btnGeneratePassword">
<i class="fas fa-random mr-1"></i>Generate random password </a>
</small></span>
{{template "invalid-feedback" dict "fieldName" "UserCreateRequest.Password" "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors }}
</div> </div>
<div class="form-group"> <div class="form-group">
<label for="inputPasswordConfirm">Confirm Password</label> <label for="inputPasswordConfirm">Confirm Password</label>
<input type="password" class="form-control {{ ValidationFieldClass $.validationErrors "UserCreateRequest.PasswordConfirm" }}" id="inputPasswordConfirm" placeholder="" name="PasswordConfirm" value="{{ .form.PasswordConfirm }}" required> <input type="password"
{{template "invalid-feedback" dict "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors "fieldName" "UserCreateRequest.PasswordConfirm" }} class="form-control {{ ValidationFieldClass $.validationErrors "UserCreateRequest.PasswordConfirm" }}"
id="inputPasswordConfirm" placeholder="" name="PasswordConfirm" value="{{ .form.PasswordConfirm }}" required>
{{template "invalid-feedback" dict "fieldName" "UserCreateRequest.PasswordConfirm" "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors }}
</div> </div>
<div class="form-group"> <div class="form-group">
<label for="inputRoles">Roles</label> <label for="inputRoles">Roles</label>
<span class="help-block "><small>- Select at least one role.</small></span> <span class="help-block "><small>- Select at least one role.</small></span>
{{ range $r := .roles }}
{{ range $r := .roles.Options }}
{{ $selectRole := false }}
{{ range $fr := $.form.Roles }}
{{ if eq $r.Value $fr }}{{ $selectRole = true }}{{ end }}
{{ end }}
<div class="form-check"> <div class="form-check">
<input class="form-check-input" type="checkbox" value="{{ $r.Value }}" id="defaultCheck1" {{ if $selectRole }}checked="checked"{{ end }}> <input class="form-check-input {{ ValidationFieldClass $.validationErrors "Roles" }}"
<label class="form-check-label" for="defaultCheck1"> type="checkbox" name="Roles"
value="{{ $r.Value }}" id="inputRole{{ $r.Value }}"
{{ if $r.Selected }}checked="checked"{{ end }}>
<label class="form-check-label" for="inputRole{{ $r.Value }}">
{{ $r.Title }} {{ $r.Title }}
</label> </label>
</div> </div>
{{ end }} {{ end }}
{{template "invalid-feedback" dict "fieldName" "Roles" "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors }}
<select class="form-control {{ ValidationFieldClass $.validationErrors "Roles" }}" id="inputRoles" name="Roles" multiple="multiple">
{{ range $r := .roles.Options }}
{{ $selectRole := false }}
{{ range $fr := $.form.Roles }}
{{ if eq $r.Value $fr }}{{ $selectRole = true }}{{ end }}
{{ end }}
<option value="{{ $r.Value }}" {{ if $selectRole }}selected="selected"{{ end }}>{{ $r.Title }}</option>
{{ end }}
</select>
{{template "invalid-feedback" dict "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors "fieldName" "Roles" }}
</div> </div>
</div> </div>
</div> </div>
</div> </div>
</div> </div>
<div class="row mt-4"> <div class="row mt-4">
<div class="col"> <div class="col">
<input id="btnSubmit" type="submit" name="action" value="Save" class="btn btn-primary"/> <input type="submit" value="Save" class="btn btn-primary"/>
<a href="/users" class="ml-2 btn btn-secondary" >Cancel</a> <a href="/users" class="ml-2 btn btn-secondary" >Cancel</a>
</div> </div>
</div> </div>

View File

@ -8,93 +8,99 @@
<h1 class="h3 mb-0 text-gray-800">Update User</h1> <h1 class="h3 mb-0 text-gray-800">Update User</h1>
</div> </div>
<form class="user" method="post" novalidate> <form class="user" method="post" novalidate>
<div class="card shadow"> <div class="card shadow">
<div class="card-body"> <div class="card-body">
<div class="row mb-2"> <div class="row mb-2">
<div class="col-12"> <div class="col-12">
<h4 class="card-title">User Details</h4> <h4 class="card-title">User Details</h4>
</div> </div>
</div> </div>
<div class="row mb-2"> <div class="row mb-2">
<div class="col-md-6"> <div class="col-md-6">
<div class="form-group"> <div class="form-group">
<label for="inputFirstName">First Name</label> <label for="inputFirstName">First Name</label>
<input type="text" class="form-control {{ ValidationFieldClass $.validationErrors "FirstName" }}" placeholder="enter first name" name="FirstName" value="{{ .form.FirstName }}" required> <input type="text"
{{template "invalid-feedback" dict "validationDefaults" $.userValidationDefaults "validationErrors" $.validationErrors "fieldName" "FirstName" }} class="form-control {{ ValidationFieldClass $.validationErrors "UserUpdateRequest.FirstName" }}"
placeholder="enter first name" name="FirstName" value="{{ .form.FirstName }}" required>
{{template "invalid-feedback" dict "fieldName" "UserUpdateRequest.FirstName" "validationDefaults" $.userValidationDefaults "validationErrors" $.validationErrors }}
</div>
<div class="form-group">
<label for="inputLastName">Last Name</label>
<input type="text"
class="form-control {{ ValidationFieldClass $.validationErrors "UserUpdateRequest.LastName" }}"
placeholder="enter last name" name="LastName" value="{{ .form.LastName }}" required>
{{template "invalid-feedback" dict "fieldName" "UserUpdateRequest.LastName" "validationDefaults" $.userValidationDefaults "validationErrors" $.validationErrors }}
</div>
<div class="form-group">
<label for="inputEmail">Email</label>
<input type="text"
class="form-control {{ ValidationFieldClass $.validationErrors "UserUpdateRequest.Email" }}"
placeholder="enter email" name="Email" value="{{ .form.Email }}" required>
{{template "invalid-feedback" dict "fieldName" "UserUpdateRequest.Email" "validationDefaults" $.userValidationDefaults "validationErrors" $.validationErrors }}
</div>
<div class="form-group">
<label for="inputRoles">Roles</label>
<span class="help-block "><small>- Select at least one role.</small></span>
{{ range $r := .roles }}
<div class="form-check">
<input class="form-check-input {{ ValidationFieldClass $.validationErrors "Roles" }}"
type="checkbox" name="Roles"
value="{{ $r.Value }}" id="inputRole{{ $r.Value }}"
{{ if $r.Selected }}checked="checked"{{ end }}>
<label class="form-check-label" for="inputRole{{ $r.Value }}">
{{ $r.Title }}
</label>
</div>
{{ end }}
{{template "invalid-feedback" dict "fieldName" "Roles" "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors }}
</div>
<div class="form-group">
<label for="inputTimezone">Timezone</label>
<select class="form-control {{ ValidationFieldClass $.validationErrors "UserUpdateRequest.Timezone" }}" name="Timezone">
<option value="">Not set</option>
{{ range $idx, $t := .timezones }}
<option value="{{ $t }}" {{ if CmpString $t $.form.Timezone }}selected="selected"{{ end }}>{{ $t }}</option>
{{ end }}
</select>
{{template "invalid-feedback" dict "fieldName" "UserUpdateRequest.Timezone" "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors }}
</div>
</div>
</div> </div>
<div class="form-group">
<label for="inputLastName">Last Name</label>
<input type="text" class="form-control {{ ValidationFieldClass $.validationErrors "LastName" }}" placeholder="enter last name" name="LastName" value="{{ .form.LastName }}" required>
{{template "invalid-feedback" dict "validationDefaults" $.userValidationDefaults "validationErrors" $.validationErrors "fieldName" "LastName" }}
</div>
<div class="form-group">
<label for="inputEmail">Email</label>
<input type="text" class="form-control {{ ValidationFieldClass $.validationErrors "Email" }}" placeholder="enter email" name="Email" value="{{ .form.Email }}" required>
{{template "invalid-feedback" dict "validationDefaults" $.userValidationDefaults "validationErrors" $.validationErrors "fieldName" "Email" }}
</div>
<div class="form-group">
<label for="inputTimezone">Timezone</label>
<select class="form-control {{ ValidationFieldClass $.validationErrors "Timezone" }}" name="Timezone">
<option value="">Not set</option>
{{ range $idx, $t := .timezones }}
<option value="{{ $t }}" {{ if CmpString $t $.form.Timezone }}selected="selected"{{ end }}>{{ $t }}</option>
{{ end }}
</select>
{{template "invalid-feedback" dict "validationDefaults" $.validationDefaults "validationErrors" $.validationErrors "fieldName" "Timezone" }}
</div>
</div>
</div>
<div class="row"> <div class="row">
<div class="col"> <div class="col">
<input id="btnSubmit" type="submit" name="action" value="Save" class="btn btn-primary"/> <input id="btnSubmit" type="submit" value="Save" class="btn btn-primary"/>
<a href="/users/{{ .user.ID }}" class="ml-2 btn btn-secondary" >Cancel</a> <a href="/users/{{ .user.ID }}" class="ml-2 btn btn-secondary" >Cancel</a>
</div> </div>
</div> </div>
</div> </div>
</div> </div>
</form> </form>
<form class="user" method="post" novalidate> <form class="user" method="post" novalidate>
<div class="card mt-4"> <div class="card mt-4">
<div class="card-body"> <div class="card-body">
<div class="row mb-2"> <div class="row mb-2">
<div class="col-12"> <div class="col-12">
<h4 class="card-title">Change Password</h4> <h4 class="card-title">Change Password</h4>
<p><small><b>Optional</b>. You can change the users' password by specifying a new one below. Otherwise leave the fields empty.</small></p> <p><small><b>Optional</b>. You can change the users' password by specifying a new one below. Otherwise leave the fields empty.</small></p>
</div> </div>
</div> </div>
<div class="row mb-2"> <div class="row mb-2">
<div class="col-md-6"> <div class="col-md-6">
<div class="form-group"> <div class="form-group">
<label for="inputPassword">Password</label> <label for="inputPassword">Password</label>
<input type="password" class="form-control" id="inputPassword" placeholder="" name="Password" value=""> <input type="password" class="form-control" id="inputPassword" placeholder="" name="Password" value="">
<span class="help-block "><small><a a href="javascript:void(0)" id="btnGeneratePassword"><i class="fas fa-random mr-1"></i>Generate random password </a></small></span> <span class="help-block "><small><a a href="javascript:void(0)" id="btnGeneratePassword"><i class="fas fa-random mr-1"></i>Generate random password </a></small></span>
{{template "invalid-feedback" dict "validationDefaults" $.passwordValidationDefaults "validationErrors" $.validationErrors "fieldName" "Password" }} {{template "invalid-feedback" dict "validationDefaults" $.passwordValidationDefaults "validationErrors" $.validationErrors "fieldName" "Password" }}
</div>
<div class="form-group">
<label for="inputPasswordConfirm">Confirm Password</label>
<input type="password" class="form-control" id="inputPasswordConfirm" placeholder="" name="PasswordConfirm" value="">
{{template "invalid-feedback" dict "validationDefaults" $.passwordValidationDefaults "validationErrors" $.validationErrors "fieldName" "PasswordConfirm" }}
</div>
</div>
</div> </div>
<div class="form-group">
<label for="inputPasswordConfirm">Confirm Password</label>
<input type="password" class="form-control" id="inputPasswordConfirm" placeholder="" name="PasswordConfirm" value="">
{{template "invalid-feedback" dict "validationDefaults" $.passwordValidationDefaults "validationErrors" $.validationErrors "fieldName" "PasswordConfirm" }}
</div>
</div>
</div>
<div class="row"> <div class="row">
<div class="col"> <div class="col">
<input id="btnSubmit2" type="submit" name="action" value="Change Password" class="btn btn-primary btn-sm"/> <input id="btnSubmit2" type="submit" name="action" value="Change Password" class="btn btn-primary btn-sm"/>
@ -102,7 +108,6 @@
</div> </div>
</div> </div>
</div> </div>
</form> </form>
{{end}} {{end}}
{{define "js"}} {{define "js"}}

View File

@ -63,13 +63,13 @@
<small>Role</small><br/> <small>Role</small><br/>
{{ if .userAccount }} {{ if .userAccount }}
<b> <b>
{{ range $r := .userAccount.Roles }} {{ range $r := .userAccount.Roles }}{{ if $r.Selected }}
{{ if eq $r "admin" }} {{ if eq $r.Value "admin" }}
<span class="text-pink"><i class="far fa-kiss-wink-heart mr-1"></i>{{ $r }}</span> <span class="text-pink"><i class="far fa-kiss-wink-heart mr-1"></i>{{ $r.Title }}</span>
{{else}} {{else}}
<span class="text-purple"><i class="far fa-user-circle mr-1"></i>{{ $r }}</span> <span class="text-purple"><i class="far fa-user-circle mr-1"></i>{{ $r.Title }}</span>
{{end}} {{end}}
{{ end }} {{ end }}{{ end }}
</b> </b>
{{ end }} {{ end }}
</p> </p>

View File

@ -153,8 +153,8 @@ func selectQuery() *sqlbuilder.SelectBuilder {
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountFindRequest) (Accounts, error) { func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountFindRequest) (Accounts, error) {
query := selectQuery() query := selectQuery()
if req.Where != nil { if req.Where != "" {
query.Where(query.And(*req.Where)) query.Where(query.And(req.Where))
} }
if len(req.Order) > 0 { if len(req.Order) > 0 {
query.OrderBy(req.Order...) query.OrderBy(req.Order...)

View File

@ -65,8 +65,8 @@ func applyClaimsSelect(ctx context.Context, claims auth.Claims, query *sqlbuilde
// TODO: Need to figure out why can't parse the args when appending the where to the query. // TODO: Need to figure out why can't parse the args when appending the where to the query.
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPreferenceFindRequest) ([]*AccountPreference, error) { func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPreferenceFindRequest) ([]*AccountPreference, error) {
query := sqlbuilder.NewSelectBuilder() query := sqlbuilder.NewSelectBuilder()
if req.Where != nil { if req.Where != "" {
query.Where(query.And(*req.Where)) query.Where(query.And(req.Where))
} }
if len(req.Order) > 0 { if len(req.Order) > 0 {
query.OrderBy(req.Order...) query.OrderBy(req.Order...)

View File

@ -397,7 +397,7 @@ func TestFind(t *testing.T) {
// Test sort accounts. // Test sort accounts.
prefTests = append(prefTests, accountTest{"Find all order by created_at asc", prefTests = append(prefTests, accountTest{"Find all order by created_at asc",
AccountPreferenceFindRequest{ AccountPreferenceFindRequest{
Where: &createdFilter, Where: createdFilter,
Args: []interface{}{startTime, endTime}, Args: []interface{}{startTime, endTime},
Order: []string{"created_at"}, Order: []string{"created_at"},
}, },
@ -412,7 +412,7 @@ func TestFind(t *testing.T) {
} }
prefTests = append(prefTests, accountTest{"Find all order by created_at desc", prefTests = append(prefTests, accountTest{"Find all order by created_at desc",
AccountPreferenceFindRequest{ AccountPreferenceFindRequest{
Where: &createdFilter, Where: createdFilter,
Args: []interface{}{startTime, endTime}, Args: []interface{}{startTime, endTime},
Order: []string{"created_at desc"}, Order: []string{"created_at desc"},
}, },
@ -424,7 +424,7 @@ func TestFind(t *testing.T) {
var limit uint = 2 var limit uint = 2
prefTests = append(prefTests, accountTest{"Find limit", prefTests = append(prefTests, accountTest{"Find limit",
AccountPreferenceFindRequest{ AccountPreferenceFindRequest{
Where: &createdFilter, Where: createdFilter,
Args: []interface{}{startTime, endTime}, Args: []interface{}{startTime, endTime},
Order: []string{"created_at"}, Order: []string{"created_at"},
Limit: &limit, Limit: &limit,
@ -437,7 +437,7 @@ func TestFind(t *testing.T) {
var offset uint = 1 var offset uint = 1
prefTests = append(prefTests, accountTest{"Find limit, offset", prefTests = append(prefTests, accountTest{"Find limit, offset",
AccountPreferenceFindRequest{ AccountPreferenceFindRequest{
Where: &createdFilter, Where: createdFilter,
Args: []interface{}{startTime, endTime}, Args: []interface{}{startTime, endTime},
Order: []string{"created_at"}, Order: []string{"created_at"},
Limit: &limit, Limit: &limit,
@ -462,10 +462,9 @@ func TestFind(t *testing.T) {
expected = append(expected, &u) expected = append(expected, &u)
} }
where := createdFilter + " AND (" + strings.Join(whereParts, " OR ") + ")"
prefTests = append(prefTests, accountTest{"Find where", prefTests = append(prefTests, accountTest{"Find where",
AccountPreferenceFindRequest{ AccountPreferenceFindRequest{
Where: &where, Where: createdFilter + " AND (" + strings.Join(whereParts, " OR ") + ")",
Args: whereArgs, Args: whereArgs,
Order: []string{"created_at"}, Order: []string{"created_at"},
}, },

View File

@ -84,7 +84,7 @@ type AccountPreferenceDeleteRequest struct {
// AccountPreferenceFindRequest defines the possible options to search for accounts. By default // AccountPreferenceFindRequest defines the possible options to search for accounts. By default
// archived accounts will be excluded from response. // archived accounts will be excluded from response.
type AccountPreferenceFindRequest struct { type AccountPreferenceFindRequest struct {
Where *string `json:"where" example:"name = ?"` Where string `json:"where" example:"name = ?"`
Args []interface{} `json:"args" swaggertype:"array,string" example:"Company Name,active"` Args []interface{} `json:"args" swaggertype:"array,string" example:"Company Name,active"`
Order []string `json:"order" example:"created_at desc"` Order []string `json:"order" example:"created_at desc"`
Limit *uint `json:"limit" example:"10"` Limit *uint `json:"limit" example:"10"`

View File

@ -842,7 +842,7 @@ func TestFind(t *testing.T) {
type accountTest struct { type accountTest struct {
name string name string
req AccountFindRequest req AccountFindRequest
expected []*Account expected Accounts
error error error error
} }
@ -853,7 +853,7 @@ func TestFind(t *testing.T) {
// Test sort accounts. // Test sort accounts.
accountTests = append(accountTests, accountTest{"Find all order by created_at asc", accountTests = append(accountTests, accountTest{"Find all order by created_at asc",
AccountFindRequest{ AccountFindRequest{
Where: &createdFilter, Where: createdFilter,
Args: []interface{}{startTime, endTime}, Args: []interface{}{startTime, endTime},
Order: []string{"created_at"}, Order: []string{"created_at"},
}, },
@ -868,7 +868,7 @@ func TestFind(t *testing.T) {
} }
accountTests = append(accountTests, accountTest{"Find all order by created_at desc", accountTests = append(accountTests, accountTest{"Find all order by created_at desc",
AccountFindRequest{ AccountFindRequest{
Where: &createdFilter, Where: createdFilter,
Args: []interface{}{startTime, endTime}, Args: []interface{}{startTime, endTime},
Order: []string{"created_at desc"}, Order: []string{"created_at desc"},
}, },
@ -880,7 +880,7 @@ func TestFind(t *testing.T) {
var limit uint = 2 var limit uint = 2
accountTests = append(accountTests, accountTest{"Find limit", accountTests = append(accountTests, accountTest{"Find limit",
AccountFindRequest{ AccountFindRequest{
Where: &createdFilter, Where: createdFilter,
Args: []interface{}{startTime, endTime}, Args: []interface{}{startTime, endTime},
Order: []string{"created_at"}, Order: []string{"created_at"},
Limit: &limit, Limit: &limit,
@ -893,7 +893,7 @@ func TestFind(t *testing.T) {
var offset uint = 3 var offset uint = 3
accountTests = append(accountTests, accountTest{"Find limit, offset", accountTests = append(accountTests, accountTest{"Find limit, offset",
AccountFindRequest{ AccountFindRequest{
Where: &createdFilter, Where: createdFilter,
Args: []interface{}{startTime, endTime}, Args: []interface{}{startTime, endTime},
Order: []string{"created_at"}, Order: []string{"created_at"},
Limit: &limit, Limit: &limit,
@ -918,10 +918,9 @@ func TestFind(t *testing.T) {
expected = append(expected, &u) expected = append(expected, &u)
} }
where := createdFilter + " AND (" + strings.Join(whereParts, " OR ") + ")"
accountTests = append(accountTests, accountTest{"Find where", accountTests = append(accountTests, accountTest{"Find where",
AccountFindRequest{ AccountFindRequest{
Where: &where, Where: createdFilter + " AND (" + strings.Join(whereParts, " OR ") + ")",
Args: whereArgs, Args: whereArgs,
Order: []string{"created_at"}, Order: []string{"created_at"},
}, },

View File

@ -170,7 +170,7 @@ type AccountDeleteRequest struct {
// AccountFindRequest defines the possible options to search for accounts. By default // AccountFindRequest defines the possible options to search for accounts. By default
// archived accounts will be excluded from response. // archived accounts will be excluded from response.
type AccountFindRequest struct { type AccountFindRequest struct {
Where *string `json:"where" example:"name = ? and status = ?"` Where string `json:"where" example:"name = ? and status = ?"`
Args []interface{} `json:"args" swaggertype:"array,string" example:"Company Name,active"` Args []interface{} `json:"args" swaggertype:"array,string" example:"Company Name,active"`
Order []string `json:"order" example:"created_at desc"` Order []string `json:"order" example:"created_at desc"`
Limit *uint `json:"limit" example:"10"` Limit *uint `json:"limit" example:"10"`

View File

@ -51,7 +51,10 @@ func Decode(ctx context.Context, r *http.Request, val interface{}) error {
} }
} }
if err := webcontext.Validator().Struct(val); err != nil { // Hack since we have no DB connection.
ctx = context.WithValue(ctx, webcontext.KeyTagUnique, true)
if err := webcontext.Validator().StructCtx(ctx, val); err != nil {
verr, _ := weberror.NewValidationError(ctx, err) verr, _ := weberror.NewValidationError(ctx, err)
return verr return verr
} }

View File

@ -108,7 +108,7 @@ type ProjectDeleteRequest struct {
// ProjectFindRequest defines the possible options to search for projects. By default // ProjectFindRequest defines the possible options to search for projects. By default
// archived project will be excluded from response. // archived project will be excluded from response.
type ProjectFindRequest struct { type ProjectFindRequest struct {
Where *string `json:"where" example:"name = ? and status = ?"` Where string `json:"where" example:"name = ? and status = ?"`
Args []interface{} `json:"args" swaggertype:"array,string" example:"Moon Launch,active"` Args []interface{} `json:"args" swaggertype:"array,string" example:"Moon Launch,active"`
Order []string `json:"order" example:"created_at desc"` Order []string `json:"order" example:"created_at desc"`
Limit *uint `json:"limit" example:"10"` Limit *uint `json:"limit" example:"10"`

View File

@ -104,8 +104,8 @@ func selectQuery() *sqlbuilder.SelectBuilder {
func findRequestQuery(req ProjectFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) { func findRequestQuery(req ProjectFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) {
query := selectQuery() query := selectQuery()
if req.Where != nil { if req.Where != "" {
query.Where(query.And(*req.Where)) query.Where(query.And(req.Where))
} }
if len(req.Order) > 0 { if len(req.Order) > 0 {

View File

@ -24,14 +24,14 @@ func testMain(m *testing.M) int {
// TestFindRequestQuery validates findRequestQuery // TestFindRequestQuery validates findRequestQuery
func TestFindRequestQuery(t *testing.T) { func TestFindRequestQuery(t *testing.T) {
where := "field1 = ? or field2 = ?"
var ( var (
limit uint = 12 limit uint = 12
offset uint = 34 offset uint = 34
) )
req := ProjectFindRequest{ req := ProjectFindRequest{
Where: &where, Where: "field1 = ? or field2 = ?",
Args: []interface{}{ Args: []interface{}{
"lee brown", "lee brown",
"103 East Main St.", "103 East Main St.",

View File

@ -5,6 +5,11 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"github.com/pkg/errors"
"github.com/sudo-suhas/symcrypto"
"strconv"
"strings"
"time" "time"
"github.com/lib/pq" "github.com/lib/pq"
@ -141,7 +146,8 @@ type UserUpdatePasswordRequest struct {
// UserArchiveRequest defines the information needed to archive an user. This will archive (soft-delete) the // UserArchiveRequest defines the information needed to archive an user. This will archive (soft-delete) the
// existing database entry. // existing database entry.
type UserArchiveRequest struct { type UserArchiveRequest struct {
ID string `json:"id" validate:"required,uuid" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"` ID string `json:"id" validate:"required,uuid" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"`
force bool
} }
// UserRestoreRequest defines the information needed to restore an user. // UserRestoreRequest defines the information needed to restore an user.
@ -151,13 +157,14 @@ type UserRestoreRequest struct {
// UserDeleteRequest defines the information needed to delete a user. // UserDeleteRequest defines the information needed to delete a user.
type UserDeleteRequest struct { type UserDeleteRequest struct {
ID string `json:"id" validate:"required,uuid" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"` ID string `json:"id" validate:"required,uuid" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"`
force bool
} }
// UserFindRequest defines the possible options to search for users. By default // UserFindRequest defines the possible options to search for users. By default
// archived users will be excluded from response. // archived users will be excluded from response.
type UserFindRequest struct { type UserFindRequest struct {
Where *string `json:"where" example:"name = ? and email = ?"` Where string `json:"where" example:"name = ? and email = ?"`
Args []interface{} `json:"args" swaggertype:"array,string" example:"Company Name,gabi.may@geeksinthewoods.com"` Args []interface{} `json:"args" swaggertype:"array,string" example:"Company Name,gabi.may@geeksinthewoods.com"`
Order []string `json:"order" example:"created_at desc"` Order []string `json:"order" example:"created_at desc"`
Limit *uint `json:"limit" example:"10"` Limit *uint `json:"limit" example:"10"`
@ -185,3 +192,63 @@ type UserResetConfirmRequest struct {
Password string `json:"password" validate:"required" example:"SecretString"` Password string `json:"password" validate:"required" example:"SecretString"`
PasswordConfirm string `json:"password_confirm" validate:"required,eqfield=Password" example:"SecretString"` PasswordConfirm string `json:"password_confirm" validate:"required,eqfield=Password" example:"SecretString"`
} }
// NewResetHash generates a new encrypt reset hash that is web safe for use in URLs.
func NewResetHash(ctx context.Context, secretKey, resetId, requestIp string, ttl time.Duration, now time.Time) (string, error) {
// Generate a string that embeds additional information.
hashPts := []string{
resetId,
strconv.Itoa(int(now.UTC().Unix())),
strconv.Itoa(int(now.UTC().Add(ttl).Unix())),
requestIp,
}
hashStr := strings.Join(hashPts, "|")
// This returns the nonce appended with the encrypted string.
crypto, err := symcrypto.New(secretKey)
if err != nil {
return "", errors.WithStack(err)
}
encrypted, err := crypto.Encrypt(hashStr)
if err != nil {
return "", errors.WithStack(err)
}
return encrypted, nil
}
// ParseResetHash extracts the details encrypted in the hash string.
func ParseResetHash(ctx context.Context, secretKey string, str string, now time.Time) (*ResetHash, error) {
crypto, err := symcrypto.New(secretKey)
if err != nil {
return nil, errors.WithStack(err)
}
hashStr, err := crypto.Decrypt(str)
if err != nil {
return nil, errors.WithStack(err)
}
hashPts := strings.Split(hashStr, "|")
var hash ResetHash
if len(hashPts) == 4 {
hash.ResetID = hashPts[0]
hash.CreatedAt, _ = strconv.Atoi(hashPts[1])
hash.ExpiresAt, _ = strconv.Atoi(hashPts[2])
hash.RequestIP = hashPts[3]
}
// Validate the hash.
err = webcontext.Validator().StructCtx(ctx, hash)
if err != nil {
return nil, err
}
if int64(hash.ExpiresAt) < now.UTC().Unix() {
err = errors.WithMessage(ErrResetExpired, "Password reset has expired.")
return nil, err
}
return &hash, nil
}

View File

@ -3,9 +3,6 @@ package user
import ( import (
"context" "context"
"database/sql" "database/sql"
"github.com/sudo-suhas/symcrypto"
"strconv"
"strings"
"time" "time"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
@ -185,8 +182,8 @@ func selectQuery() *sqlbuilder.SelectBuilder {
// to the query. // to the query.
func findRequestQuery(req UserFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) { func findRequestQuery(req UserFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) {
query := selectQuery() query := selectQuery()
if req.Where != nil { if req.Where != "" {
query.Where(query.And(*req.Where)) query.Where(query.And(req.Where))
} }
if len(req.Order) > 0 { if len(req.Order) > 0 {
query.OrderBy(req.Order...) query.OrderBy(req.Order...)
@ -490,8 +487,6 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserUp
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Update") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Update")
defer span.Finish() defer span.Finish()
v := webcontext.Validator()
// Validation email address is unique in the database. // Validation email address is unique in the database.
if req.Email != nil { if req.Email != nil {
// Validation email address is unique in the database. // Validation email address is unique in the database.
@ -505,6 +500,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserUp
} }
// Validate the request. // Validate the request.
v := webcontext.Validator()
err := v.StructCtx(ctx, req) err := v.StructCtx(ctx, req)
if err != nil { if err != nil {
return err return err
@ -647,7 +643,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserA
err = CanModifyUser(ctx, claims, dbConn, req.ID) err = CanModifyUser(ctx, claims, dbConn, req.ID)
if err != nil { if err != nil {
return err return err
} else if claims.Subject != "" && claims.Subject == req.ID { } else if claims.Subject != "" && claims.Subject == req.ID && !req.force {
return errors.WithStack(ErrForbidden) return errors.WithStack(ErrForbidden)
} }
@ -772,7 +768,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserDe
err = CanModifyUser(ctx, claims, dbConn, req.ID) err = CanModifyUser(ctx, claims, dbConn, req.ID)
if err != nil { if err != nil {
return err return err
} else if claims.Subject != "" && claims.Subject == req.ID { } else if claims.Subject != "" && claims.Subject == req.ID && !req.force {
return errors.WithStack(ErrForbidden) return errors.WithStack(ErrForbidden)
} }
@ -899,23 +895,9 @@ func ResetPassword(ctx context.Context, dbConn *sqlx.DB, resetUrl func(string) s
requestIp = vals.RequestIP requestIp = vals.RequestIP
} }
// Generate a string that embeds additional information. encrypted, err := NewResetHash(ctx, secretKey, resetId, requestIp, req.TTL, now)
hashPts := []string{
resetId,
strconv.Itoa(int(now.UTC().Unix())),
strconv.Itoa(int(now.UTC().Add(req.TTL).Unix())),
requestIp,
}
hashStr := strings.Join(hashPts, "|")
// This returns the nonce appended with the encrypted string for "hello world".
crypto, err := symcrypto.New(secretKey)
if err != nil { if err != nil {
return "", errors.WithStack(err) return "", err
}
encrypted, err := crypto.Encrypt(hashStr)
if err != nil {
return "", errors.WithStack(err)
} }
data := map[string]interface{}{ data := map[string]interface{}{
@ -946,32 +928,8 @@ func ResetConfirm(ctx context.Context, dbConn *sqlx.DB, req UserResetConfirmRequ
return nil, err return nil, err
} }
crypto, err := symcrypto.New(secretKey) hash, err := ParseResetHash(ctx, secretKey, req.ResetHash, now)
if err != nil { if err != nil {
return nil, errors.WithStack(err)
}
hashStr, err := crypto.Decrypt(req.ResetHash)
if err != nil {
return nil, errors.WithStack(err)
}
hashPts := strings.Split(hashStr, "|")
var hash ResetHash
if len(hashPts) == 4 {
hash.ResetID = hashPts[0]
hash.CreatedAt, _ = strconv.Atoi(hashPts[1])
hash.ExpiresAt, _ = strconv.Atoi(hashPts[2])
hash.RequestIP = hashPts[3]
}
// Validate the hash.
err = v.StructCtx(ctx, hash)
if err != nil {
return nil, err
}
if int64(hash.ExpiresAt) < now.UTC().Unix() {
err = errors.WithMessage(ErrResetExpired, "Password reset has expired.")
return nil, err return nil, err
} }

View File

@ -33,14 +33,13 @@ func testMain(m *testing.M) int {
// TestFindRequestQuery validates findRequestQuery // TestFindRequestQuery validates findRequestQuery
func TestFindRequestQuery(t *testing.T) { func TestFindRequestQuery(t *testing.T) {
where := "first_name = ? or email = ?"
var ( var (
limit uint = 12 limit uint = 12
offset uint = 34 offset uint = 34
) )
req := UserFindRequest{ req := UserFindRequest{
Where: &where, Where: "first_name = ? or email = ?",
Args: []interface{}{ Args: []interface{}{
"lee", "lee",
"lee@geeksinthewoods.com", "lee@geeksinthewoods.com",
@ -195,7 +194,7 @@ func TestCreateValidation(t *testing.T) {
FirstName: "Lee", FirstName: "Lee",
LastName: "Brown", LastName: "Brown",
Email: req.Email, Email: req.Email,
Timezone: "America/Anchorage", Timezone: nil,
// Copy this fields from the result. // Copy this fields from the result.
ID: res.ID, ID: res.ID,
@ -847,7 +846,7 @@ func TestCrud(t *testing.T) {
} }
// Archive (soft-delete) the user. // Archive (soft-delete) the user.
err = Archive(ctx, tt.claims(user, accountId), test.MasterDB, UserArchiveRequest{ID: user.ID}, now) err = Archive(ctx, tt.claims(user, accountId), test.MasterDB, UserArchiveRequest{ID: user.ID, force: true}, now)
if err != nil && errors.Cause(err) != tt.updateErr { if err != nil && errors.Cause(err) != tt.updateErr {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.updateErr) t.Logf("\t\tWant: %+v", tt.updateErr)
@ -888,7 +887,7 @@ func TestCrud(t *testing.T) {
t.Logf("\t%s\tUnarchive ok.", tests.Success) t.Logf("\t%s\tUnarchive ok.", tests.Success)
// Delete (hard-delete) the user. // Delete (hard-delete) the user.
err = Delete(ctx, tt.claims(user, accountId), test.MasterDB, UserDeleteRequest{ID: user.ID}) err = Delete(ctx, tt.claims(user, accountId), test.MasterDB, UserDeleteRequest{ID: user.ID, force: true})
if err != nil && errors.Cause(err) != tt.updateErr { if err != nil && errors.Cause(err) != tt.updateErr {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.updateErr) t.Logf("\t\tWant: %+v", tt.updateErr)
@ -936,7 +935,7 @@ func TestFind(t *testing.T) {
type userTest struct { type userTest struct {
name string name string
req UserFindRequest req UserFindRequest
expected []*User expected Users
error error error error
} }
@ -947,7 +946,7 @@ func TestFind(t *testing.T) {
// Test sort users. // Test sort users.
userTests = append(userTests, userTest{"Find all order by created_at asc", userTests = append(userTests, userTest{"Find all order by created_at asc",
UserFindRequest{ UserFindRequest{
Where: &createdFilter, Where: createdFilter,
Args: []interface{}{startTime, endTime}, Args: []interface{}{startTime, endTime},
Order: []string{"created_at"}, Order: []string{"created_at"},
}, },
@ -956,13 +955,13 @@ func TestFind(t *testing.T) {
}) })
// Test reverse sorted users. // Test reverse sorted users.
var expected []*User var expected Users
for i := len(users) - 1; i >= 0; i-- { for i := len(users) - 1; i >= 0; i-- {
expected = append(expected, users[i]) expected = append(expected, users[i])
} }
userTests = append(userTests, userTest{"Find all order by created_at desc", userTests = append(userTests, userTest{"Find all order by created_at desc",
UserFindRequest{ UserFindRequest{
Where: &createdFilter, Where: createdFilter,
Args: []interface{}{startTime, endTime}, Args: []interface{}{startTime, endTime},
Order: []string{"created_at desc"}, Order: []string{"created_at desc"},
}, },
@ -974,7 +973,7 @@ func TestFind(t *testing.T) {
var limit uint = 2 var limit uint = 2
userTests = append(userTests, userTest{"Find limit", userTests = append(userTests, userTest{"Find limit",
UserFindRequest{ UserFindRequest{
Where: &createdFilter, Where: createdFilter,
Args: []interface{}{startTime, endTime}, Args: []interface{}{startTime, endTime},
Order: []string{"created_at"}, Order: []string{"created_at"},
Limit: &limit, Limit: &limit,
@ -987,7 +986,7 @@ func TestFind(t *testing.T) {
var offset uint = 3 var offset uint = 3
userTests = append(userTests, userTest{"Find limit, offset", userTests = append(userTests, userTest{"Find limit, offset",
UserFindRequest{ UserFindRequest{
Where: &createdFilter, Where: createdFilter,
Args: []interface{}{startTime, endTime}, Args: []interface{}{startTime, endTime},
Order: []string{"created_at"}, Order: []string{"created_at"},
Limit: &limit, Limit: &limit,
@ -1015,7 +1014,7 @@ func TestFind(t *testing.T) {
where := createdFilter + " AND (" + strings.Join(whereParts, " OR ") + ")" where := createdFilter + " AND (" + strings.Join(whereParts, " OR ") + ")"
userTests = append(userTests, userTest{"Find where", userTests = append(userTests, userTest{"Find where",
UserFindRequest{ UserFindRequest{
Where: &where, Where: where,
Args: whereArgs, Args: whereArgs,
Order: []string{"created_at"}, Order: []string{"created_at"},
}, },

View File

@ -54,9 +54,9 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r
emailUserIDs := make(map[string]string) emailUserIDs := make(map[string]string)
{ {
// Find all users without passing in claims to search all users. // Find all users without passing in claims to search all users.
where := fmt.Sprintf("email in ('%s')", strings.Join(req.Emails, "','"))
users, err := user.Find(ctx, auth.Claims{}, dbConn, user.UserFindRequest{ users, err := user.Find(ctx, auth.Claims{}, dbConn, user.UserFindRequest{
Where: &where, Where: fmt.Sprintf("email in ('%s')",
strings.Join(req.Emails, "','")),
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -75,9 +75,10 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r
args = append(args, userID) args = append(args, userID)
} }
where := fmt.Sprintf("user_id in ('%s') and status = '%s'", strings.Join(args, "','"), user_account.UserAccountStatus_Active.String())
userAccs, err := user_account.Find(ctx, claims, dbConn, user_account.UserAccountFindRequest{ userAccs, err := user_account.Find(ctx, claims, dbConn, user_account.UserAccountFindRequest{
Where: &where, Where: fmt.Sprintf("user_id in ('%s') and status = '%s'",
strings.Join(args, "','"),
user_account.UserAccountStatus_Active.String()),
}) })
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -1,6 +1,7 @@
package invite package invite
import ( import (
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"os" "os"
"strings" "strings"
"testing" "testing"
@ -78,7 +79,7 @@ func TestSendUserInvites(t *testing.T) {
} }
claims := auth.Claims{ claims := auth.Claims{
AccountIds: []string{a.ID}, AccountIDs: []string{a.ID},
StandardClaims: jwt.StandardClaims{ StandardClaims: jwt.StandardClaims{
Subject: u.ID, Subject: u.ID,
Audience: a.ID, Audience: a.ID,
@ -148,11 +149,12 @@ func TestSendUserInvites(t *testing.T) {
// Ensure validation is working by trying ResetConfirm with an empty request. // Ensure validation is working by trying ResetConfirm with an empty request.
{ {
expectedErr := errors.New("Key: 'AcceptInviteRequest.invite_hash' Error:Field validation for 'invite_hash' failed on the 'required' tag\n" + expectedErr := errors.New("Key: 'AcceptInviteRequest.invite_hash' Error:Field validation for 'invite_hash' failed on the 'required' tag\n" +
"Key: 'AcceptInviteRequest.email' Error:Field validation for 'email' failed on the 'required' tag\n" +
"Key: 'AcceptInviteRequest.first_name' Error:Field validation for 'first_name' failed on the 'required' tag\n" + "Key: 'AcceptInviteRequest.first_name' Error:Field validation for 'first_name' failed on the 'required' tag\n" +
"Key: 'AcceptInviteRequest.last_name' Error:Field validation for 'last_name' failed on the 'required' tag\n" + "Key: 'AcceptInviteRequest.last_name' Error:Field validation for 'last_name' failed on the 'required' tag\n" +
"Key: 'AcceptInviteRequest.password' Error:Field validation for 'password' failed on the 'required' tag\n" + "Key: 'AcceptInviteRequest.password' Error:Field validation for 'password' failed on the 'required' tag\n" +
"Key: 'AcceptInviteRequest.password_confirm' Error:Field validation for 'password_confirm' failed on the 'required' tag") "Key: 'AcceptInviteRequest.password_confirm' Error:Field validation for 'password_confirm' failed on the 'required' tag")
err = AcceptInvite(ctx, test.MasterDB, AcceptInviteRequest{}, secretKey, now) _, err = AcceptInvite(ctx, test.MasterDB, AcceptInviteRequest{}, secretKey, now)
if err == nil { if err == nil {
t.Logf("\t\tWant: %+v", expectedErr) t.Logf("\t\tWant: %+v", expectedErr)
t.Fatalf("\t%s\tResetConfirm failed.", tests.Failed) t.Fatalf("\t%s\tResetConfirm failed.", tests.Failed)
@ -172,8 +174,9 @@ func TestSendUserInvites(t *testing.T) {
// Ensure the TTL is enforced. // Ensure the TTL is enforced.
{ {
newPass := uuid.NewRandom().String() newPass := uuid.NewRandom().String()
err = AcceptInvite(ctx, test.MasterDB, AcceptInviteRequest{ _, err = AcceptInvite(ctx, test.MasterDB, AcceptInviteRequest{
InviteHash: inviteHashes[0], InviteHash: inviteHashes[0],
Email: inviteEmails[0],
FirstName: "Foo", FirstName: "Foo",
LastName: "Bar", LastName: "Bar",
Password: newPass, Password: newPass,
@ -188,10 +191,15 @@ func TestSendUserInvites(t *testing.T) {
} }
// Assuming we have received the email and clicked the link, we now can ensure accept works. // Assuming we have received the email and clicked the link, we now can ensure accept works.
for _, inviteHash := range inviteHashes { for idx, inviteHash := range inviteHashes {
type expectRes struct {
UserID string `json:"user_id" validate:"required,uuid"`
}
var res expectRes
newPass := uuid.NewRandom().String() newPass := uuid.NewRandom().String()
err = AcceptInvite(ctx, test.MasterDB, AcceptInviteRequest{ res.UserID, err = AcceptInvite(ctx, test.MasterDB, AcceptInviteRequest{
InviteHash: inviteHash, InviteHash: inviteHash,
Email: inviteEmails[idx],
FirstName: "Foo", FirstName: "Foo",
LastName: "Bar", LastName: "Bar",
Password: newPass, Password: newPass,
@ -201,20 +209,29 @@ func TestSendUserInvites(t *testing.T) {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tInviteAccept failed.", tests.Failed) t.Fatalf("\t%s\tInviteAccept failed.", tests.Failed)
} }
// Validate the result.
err := webcontext.Validator().StructCtx(ctx, res)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tInviteAccept failed.", tests.Failed)
}
t.Logf("\t%s\tInviteAccept ok.", tests.Success) t.Logf("\t%s\tInviteAccept ok.", tests.Success)
} }
// Ensure the reset hash does not work after its used. // Ensure the reset hash does not work after its used.
{ {
newPass := uuid.NewRandom().String() newPass := uuid.NewRandom().String()
err = AcceptInvite(ctx, test.MasterDB, AcceptInviteRequest{ _, err = AcceptInvite(ctx, test.MasterDB, AcceptInviteRequest{
InviteHash: inviteHashes[0], InviteHash: inviteHashes[0],
Email: inviteEmails[0],
FirstName: "Foo", FirstName: "Foo",
LastName: "Bar", LastName: "Bar",
Password: newPass, Password: newPass,
PasswordConfirm: newPass, PasswordConfirm: newPass,
}, secretKey, now) }, secretKey, now)
if errors.Cause(err) != ErrInviteUserPasswordSet { if errors.Cause(err) != ErrUserAccountActive {
t.Logf("\t\tGot : %+v", errors.Cause(err)) t.Logf("\t\tGot : %+v", errors.Cause(err))
t.Logf("\t\tWant: %+v", ErrInviteUserPasswordSet) t.Logf("\t\tWant: %+v", ErrInviteUserPasswordSet)
t.Fatalf("\t%s\tInviteAccept verify reuse failed.", tests.Failed) t.Fatalf("\t%s\tInviteAccept verify reuse failed.", tests.Failed)

View File

@ -2,7 +2,6 @@ package invite
import ( import (
"context" "context"
"fmt"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -79,8 +78,6 @@ func ParseInviteHash(ctx context.Context, secretKey string, str string, now time
} }
hashPts := strings.Split(hashStr, "|") hashPts := strings.Split(hashStr, "|")
fmt.Println(hashPts)
var hash InviteHash var hash InviteHash
if len(hashPts) == 5 { if len(hashPts) == 5 {
hash.UserID = hashPts[0] hash.UserID = hashPts[0]

View File

@ -32,13 +32,13 @@ type UserAccount struct {
// UserAccountResponse defines the one to many relationship of an user to an account that is returned for display. // UserAccountResponse defines the one to many relationship of an user to an account that is returned for display.
type UserAccountResponse struct { type UserAccountResponse struct {
//ID string `json:"id" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"` //ID string `json:"id" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"`
UserID string `json:"user_id" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"` UserID string `json:"user_id" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"`
AccountID string `json:"account_id" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"` AccountID string `json:"account_id" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
Roles UserAccountRoles `json:"roles" validate:"required,dive,oneof=admin user" enums:"admin,user" swaggertype:"array,string" example:"admin"` Roles web.EnumMultiResponse `json:"roles" validate:"required,dive,oneof=admin user" enums:"admin,user" swaggertype:"array,string" example:"admin"`
Status web.EnumResponse `json:"status"` // Status is enum with values [active, invited, disabled]. Status web.EnumResponse `json:"status"` // Status is enum with values [active, invited, disabled].
CreatedAt web.TimeResponse `json:"created_at"` // CreatedAt contains multiple format options for display. CreatedAt web.TimeResponse `json:"created_at"` // CreatedAt contains multiple format options for display.
UpdatedAt web.TimeResponse `json:"updated_at"` // UpdatedAt contains multiple format options for display. UpdatedAt web.TimeResponse `json:"updated_at"` // UpdatedAt contains multiple format options for display.
ArchivedAt *web.TimeResponse `json:"archived_at,omitempty"` // ArchivedAt contains multiple format options for display. ArchivedAt *web.TimeResponse `json:"archived_at,omitempty"` // ArchivedAt contains multiple format options for display.
} }
// Response transforms UserAccount and UserAccountResponse that is used for display. // Response transforms UserAccount and UserAccountResponse that is used for display.
@ -52,12 +52,17 @@ func (m *UserAccount) Response(ctx context.Context) *UserAccountResponse {
//ID: m.ID, //ID: m.ID,
UserID: m.UserID, UserID: m.UserID,
AccountID: m.AccountID, AccountID: m.AccountID,
Roles: m.Roles,
Status: web.NewEnumResponse(ctx, m.Status, UserAccountStatus_ValuesInterface()...), Status: web.NewEnumResponse(ctx, m.Status, UserAccountStatus_ValuesInterface()...),
CreatedAt: web.NewTimeResponse(ctx, m.CreatedAt), CreatedAt: web.NewTimeResponse(ctx, m.CreatedAt),
UpdatedAt: web.NewTimeResponse(ctx, m.UpdatedAt), UpdatedAt: web.NewTimeResponse(ctx, m.UpdatedAt),
} }
var selectedRoles []interface{}
for _, r := range m.Roles {
selectedRoles = append(selectedRoles, r.String())
}
r.Roles = web.NewEnumMultiResponse(ctx, selectedRoles, UserAccountRole_ValuesInterface()...)
if m.ArchivedAt != nil && !m.ArchivedAt.Time.IsZero() { if m.ArchivedAt != nil && !m.ArchivedAt.Time.IsZero() {
at := web.NewTimeResponse(ctx, m.ArchivedAt.Time) at := web.NewTimeResponse(ctx, m.ArchivedAt.Time)
r.ArchivedAt = &at r.ArchivedAt = &at
@ -139,7 +144,7 @@ type UserAccountDeleteRequest struct {
// UserAccountFindRequest defines the possible options to search for users accounts. // UserAccountFindRequest defines the possible options to search for users accounts.
// By default archived user accounts will be excluded from response. // By default archived user accounts will be excluded from response.
type UserAccountFindRequest struct { type UserAccountFindRequest struct {
Where *string `json:"where" example:"user_id = ? and account_id = ?"` Where string `json:"where" example:"user_id = ? and account_id = ?"`
Args []interface{} `json:"args" swaggertype:"array,string" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2,c4653bf9-5978-48b7-89c5-95704aebb7e2"` Args []interface{} `json:"args" swaggertype:"array,string" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2,c4653bf9-5978-48b7-89c5-95704aebb7e2"`
Order []string `json:"order" example:"created_at desc"` Order []string `json:"order" example:"created_at desc"`
Limit *uint `json:"limit" example:"10"` Limit *uint `json:"limit" example:"10"`
@ -291,19 +296,19 @@ type User struct {
// UserResponse represents someone with access to our system that is returned for display. // UserResponse represents someone with access to our system that is returned for display.
type UserResponse struct { type UserResponse struct {
ID string `json:"id" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"` ID string `json:"id" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"`
Name string `json:"name" example:"Gabi"` Name string `json:"name" example:"Gabi"`
FirstName string `json:"first_name" example:"Gabi"` FirstName string `json:"first_name" example:"Gabi"`
LastName string `json:"last_name" example:"May"` LastName string `json:"last_name" example:"May"`
Email string `json:"email" example:"gabi@geeksinthewoods.com"` Email string `json:"email" example:"gabi@geeksinthewoods.com"`
Timezone string `json:"timezone" example:"America/Anchorage"` Timezone string `json:"timezone" example:"America/Anchorage"`
AccountID string `json:"account_id" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"` AccountID string `json:"account_id" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
Roles UserAccountRoles `json:"roles" validate:"required,dive,oneof=admin user" enums:"admin,user" swaggertype:"array,string" example:"admin"` Roles web.EnumMultiResponse `json:"roles" validate:"required,dive,oneof=admin user" enums:"admin,user" swaggertype:"array,string" example:"admin"`
Status web.EnumResponse `json:"status"` // Status is enum with values [active, invited, disabled]. Status web.EnumResponse `json:"status"` // Status is enum with values [active, invited, disabled].
CreatedAt web.TimeResponse `json:"created_at"` // CreatedAt contains multiple format options for display. CreatedAt web.TimeResponse `json:"created_at"` // CreatedAt contains multiple format options for display.
UpdatedAt web.TimeResponse `json:"updated_at"` // UpdatedAt contains multiple format options for display. UpdatedAt web.TimeResponse `json:"updated_at"` // UpdatedAt contains multiple format options for display.
ArchivedAt *web.TimeResponse `json:"archived_at,omitempty"` // ArchivedAt contains multiple format options for display. ArchivedAt *web.TimeResponse `json:"archived_at,omitempty"` // ArchivedAt contains multiple format options for display.
Gravatar web.GravatarResponse `json:"gravatar"` Gravatar web.GravatarResponse `json:"gravatar"`
} }
// Response transforms User and UserResponse that is used for display. // Response transforms User and UserResponse that is used for display.
@ -320,13 +325,18 @@ func (m *User) Response(ctx context.Context) *UserResponse {
LastName: m.LastName, LastName: m.LastName,
Email: m.Email, Email: m.Email,
AccountID: m.AccountID, AccountID: m.AccountID,
Roles: m.Roles,
Status: web.NewEnumResponse(ctx, m.Status, UserAccountStatus_Values), Status: web.NewEnumResponse(ctx, m.Status, UserAccountStatus_Values),
CreatedAt: web.NewTimeResponse(ctx, m.CreatedAt), CreatedAt: web.NewTimeResponse(ctx, m.CreatedAt),
UpdatedAt: web.NewTimeResponse(ctx, m.UpdatedAt), UpdatedAt: web.NewTimeResponse(ctx, m.UpdatedAt),
Gravatar: web.NewGravatarResponse(ctx, m.Email), Gravatar: web.NewGravatarResponse(ctx, m.Email),
} }
var selectedRoles []interface{}
for _, r := range m.Roles {
selectedRoles = append(selectedRoles, r.String())
}
r.Roles = web.NewEnumMultiResponse(ctx, selectedRoles, UserAccountRole_ValuesInterface()...)
if m.Timezone != nil { if m.Timezone != nil {
r.Timezone = *m.Timezone r.Timezone = *m.Timezone
} }

View File

@ -114,8 +114,8 @@ func selectQuery() *sqlbuilder.SelectBuilder {
// to the query. // to the query.
func findRequestQuery(req UserAccountFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) { func findRequestQuery(req UserAccountFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) {
query := selectQuery() query := selectQuery()
if req.Where != nil { if req.Where != "" {
query.Where(query.And(*req.Where)) query.Where(query.And(req.Where))
} }
if len(req.Order) > 0 { if len(req.Order) > 0 {
query.OrderBy(req.Order...) query.OrderBy(req.Order...)

View File

@ -32,14 +32,14 @@ func testMain(m *testing.M) int {
// TestFindRequestQuery validates findRequestQuery // TestFindRequestQuery validates findRequestQuery
func TestFindRequestQuery(t *testing.T) { func TestFindRequestQuery(t *testing.T) {
where := "account_id = ? or user_id = ?"
var ( var (
limit uint = 12 limit uint = 12
offset uint = 34 offset uint = 34
) )
req := UserAccountFindRequest{ req := UserAccountFindRequest{
Where: &where, Where: "account_id = ? or user_id = ?",
Args: []interface{}{ Args: []interface{}{
"xy7", "xy7",
"qwert", "qwert",
@ -598,9 +598,8 @@ func TestCrud(t *testing.T) {
// Find the account for the user to verify the updates where made. There should only // Find the account for the user to verify the updates where made. There should only
// be one account associated with the user for this test. // be one account associated with the user for this test.
ff := "user_id = ? or account_id = ?"
findRes, err := Find(tests.Context(), tt.claims(userID, accountID), test.MasterDB, UserAccountFindRequest{ findRes, err := Find(tests.Context(), tt.claims(userID, accountID), test.MasterDB, UserAccountFindRequest{
Where: &ff, Where: "user_id = ? or account_id = ?",
Args: []interface{}{userID, accountID}, Args: []interface{}{userID, accountID},
Order: []string{"created_at"}, Order: []string{"created_at"},
}) })
@ -609,7 +608,7 @@ func TestCrud(t *testing.T) {
t.Logf("\t\tWant: %+v", tt.findErr) t.Logf("\t\tWant: %+v", tt.findErr)
t.Fatalf("\t%s\tVerify update user account failed.", tests.Failed) t.Fatalf("\t%s\tVerify update user account failed.", tests.Failed)
} else if tt.findErr == nil { } else if tt.findErr == nil {
expected := []*UserAccount{ var expected UserAccounts = []*UserAccount{
&UserAccount{ &UserAccount{
//ID: ua.ID, //ID: ua.ID,
UserID: ua.UserID, UserID: ua.UserID,
@ -651,7 +650,7 @@ func TestCrud(t *testing.T) {
t.Fatalf("\t%s\tVerify archive user account failed when including archived.", tests.Failed) t.Fatalf("\t%s\tVerify archive user account failed when including archived.", tests.Failed)
} }
expected := []*UserAccount{ var expected UserAccounts = []*UserAccount{
&UserAccount{ &UserAccount{
//ID: ua.ID, //ID: ua.ID,
UserID: ua.UserID, UserID: ua.UserID,
@ -737,7 +736,7 @@ func TestFind(t *testing.T) {
type accountTest struct { type accountTest struct {
name string name string
req UserAccountFindRequest req UserAccountFindRequest
expected []*UserAccount expected UserAccounts
error error error error
} }
@ -748,7 +747,7 @@ func TestFind(t *testing.T) {
// Test sort users. // Test sort users.
accountTests = append(accountTests, accountTest{"Find all order by created_at asx", accountTests = append(accountTests, accountTest{"Find all order by created_at asx",
UserAccountFindRequest{ UserAccountFindRequest{
Where: &createdFilter, Where: createdFilter,
Args: []interface{}{startTime, endTime}, Args: []interface{}{startTime, endTime},
Order: []string{"created_at"}, Order: []string{"created_at"},
}, },
@ -763,7 +762,7 @@ func TestFind(t *testing.T) {
} }
accountTests = append(accountTests, accountTest{"Find all order by created_at desc", accountTests = append(accountTests, accountTest{"Find all order by created_at desc",
UserAccountFindRequest{ UserAccountFindRequest{
Where: &createdFilter, Where: createdFilter,
Args: []interface{}{startTime, endTime}, Args: []interface{}{startTime, endTime},
Order: []string{"created_at desc"}, Order: []string{"created_at desc"},
}, },
@ -775,7 +774,7 @@ func TestFind(t *testing.T) {
var limit uint = 2 var limit uint = 2
accountTests = append(accountTests, accountTest{"Find limit", accountTests = append(accountTests, accountTest{"Find limit",
UserAccountFindRequest{ UserAccountFindRequest{
Where: &createdFilter, Where: createdFilter,
Args: []interface{}{startTime, endTime}, Args: []interface{}{startTime, endTime},
Order: []string{"created_at"}, Order: []string{"created_at"},
Limit: &limit, Limit: &limit,
@ -788,7 +787,7 @@ func TestFind(t *testing.T) {
var offset uint = 3 var offset uint = 3
accountTests = append(accountTests, accountTest{"Find limit, offset", accountTests = append(accountTests, accountTest{"Find limit, offset",
UserAccountFindRequest{ UserAccountFindRequest{
Where: &createdFilter, Where: createdFilter,
Args: []interface{}{startTime, endTime}, Args: []interface{}{startTime, endTime},
Order: []string{"created_at"}, Order: []string{"created_at"},
Limit: &limit, Limit: &limit,
@ -813,10 +812,10 @@ func TestFind(t *testing.T) {
whereArgs = append(whereArgs, ua.AccountID) whereArgs = append(whereArgs, ua.AccountID)
expected = append(expected, &ua) expected = append(expected, &ua)
} }
where := createdFilter + " AND (" + strings.Join(whereParts, " OR ") + ")"
accountTests = append(accountTests, accountTest{"Find where", accountTests = append(accountTests, accountTest{"Find where",
UserAccountFindRequest{ UserAccountFindRequest{
Where: &where, Where: createdFilter + " AND (" + strings.Join(whereParts, " OR ") + ")",
Args: whereArgs, Args: whereArgs,
Order: []string{"created_at"}, Order: []string{"created_at"},
}, },

View File

@ -40,11 +40,18 @@ const (
// Authenticate finds a user by their email and verifies their password. On success // Authenticate finds a user by their email and verifies their password. On success
// it returns a Token that can be used to authenticate access to the application in // it returns a Token that can be used to authenticate access to the application in
// the future. // the future.
func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, email, password string, expires time.Duration, now time.Time, scopes ...string) (Token, error) { func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, req AuthenticateRequest, expires time.Duration, now time.Time, scopes ...string) (Token, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.Authenticate") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.Authenticate")
defer span.Finish() defer span.Finish()
u, err := user.ReadByEmail(ctx, auth.Claims{}, dbConn, email, false) // Validate the request.
v := webcontext.Validator()
err := v.Struct(req)
if err != nil {
return Token{}, err
}
u, err := user.ReadByEmail(ctx, auth.Claims{}, dbConn, req.Email, false)
if err != nil { if err != nil {
if errors.Cause(err) == user.ErrNotFound { if errors.Cause(err) == user.ErrNotFound {
err = errors.WithStack(ErrAuthenticationFailure) err = errors.WithStack(ErrAuthenticationFailure)
@ -55,7 +62,7 @@ func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, e
} }
// Append the salt from the user record to the supplied password. // Append the salt from the user record to the supplied password.
saltedPassword := password + u.PasswordSalt saltedPassword := req.Password + u.PasswordSalt
// Compare the provided password with the saved hash. Use the bcrypt comparison // Compare the provided password with the saved hash. Use the bcrypt comparison
// function so it is cryptographically secure. Return authentication error for // function so it is cryptographically secure. Return authentication error for
@ -66,7 +73,7 @@ func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, e
} }
// The user is successfully authenticated with the supplied email and password. // The user is successfully authenticated with the supplied email and password.
return generateToken(ctx, dbConn, tknGen, auth.Claims{}, u.ID, "", expires, now, scopes...) return generateToken(ctx, dbConn, tknGen, auth.Claims{}, u.ID, req.AccountID, expires, now, scopes...)
} }
// SwitchAccount allows users to switch between multiple accounts, this changes the claim audience. // SwitchAccount allows users to switch between multiple accounts, this changes the claim audience.

View File

@ -48,7 +48,11 @@ func TestAuthenticate(t *testing.T) {
now := time.Now().Add(time.Hour * -1) now := time.Now().Add(time.Hour * -1)
// Try to authenticate an invalid user. // Try to authenticate an invalid user.
_, err := Authenticate(ctx, test.MasterDB, tknGen, "doesnotexist@gmail.com", "xy7", time.Hour, now) _, err := Authenticate(ctx, test.MasterDB, tknGen,
AuthenticateRequest{
Email: "doesnotexist@gmail.com",
Password: "xy7",
}, time.Hour, now)
if errors.Cause(err) != ErrAuthenticationFailure { if errors.Cause(err) != ErrAuthenticationFailure {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", ErrAuthenticationFailure) t.Logf("\t\tWant: %+v", ErrAuthenticationFailure)
@ -88,7 +92,12 @@ func TestAuthenticate(t *testing.T) {
now = now.Add(time.Minute * 5) now = now.Add(time.Minute * 5)
// Try to authenticate valid user with invalid password. // Try to authenticate valid user with invalid password.
_, err = Authenticate(ctx, test.MasterDB, tknGen, usrAcc.User.Email, "xy7", time.Hour, now) _, err = Authenticate(ctx, test.MasterDB, tknGen,
AuthenticateRequest{
Email: usrAcc.User.Email,
Password: "xy7",
},
time.Hour, now)
if errors.Cause(err) != ErrAuthenticationFailure { if errors.Cause(err) != ErrAuthenticationFailure {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", ErrAuthenticationFailure) t.Logf("\t\tWant: %+v", ErrAuthenticationFailure)
@ -97,7 +106,11 @@ func TestAuthenticate(t *testing.T) {
t.Logf("\t%s\tAuthenticate user w/invalid password ok.", tests.Success) t.Logf("\t%s\tAuthenticate user w/invalid password ok.", tests.Success)
// Verify that the user can be authenticated with the created user. // Verify that the user can be authenticated with the created user.
tkn1, err := Authenticate(ctx, test.MasterDB, tknGen, usrAcc.User.Email, usrAcc.User.Password, time.Hour, now) tkn1, err := Authenticate(ctx, test.MasterDB, tknGen,
AuthenticateRequest{
Email: usrAcc.User.Email,
Password: usrAcc.User.Password,
}, time.Hour, now)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tAuthenticate user failed.", tests.Failed) t.Fatalf("\t%s\tAuthenticate user failed.", tests.Failed)
@ -170,7 +183,11 @@ func TestUserUpdatePassword(t *testing.T) {
t.Logf("\t%s\tCreate user account ok.", tests.Success) t.Logf("\t%s\tCreate user account ok.", tests.Success)
// Verify that the user can be authenticated with the created user. // Verify that the user can be authenticated with the created user.
_, err = Authenticate(ctx, test.MasterDB, tknGen, usrAcc.User.Email, usrAcc.User.Password, time.Hour, now) _, err = Authenticate(ctx, test.MasterDB, tknGen,
AuthenticateRequest{
Email: usrAcc.User.Email,
Password: usrAcc.User.Password,
}, time.Hour, now)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tAuthenticate failed.", tests.Failed) t.Fatalf("\t%s\tAuthenticate failed.", tests.Failed)
@ -190,7 +207,11 @@ func TestUserUpdatePassword(t *testing.T) {
t.Logf("\t%s\tUpdatePassword ok.", tests.Success) t.Logf("\t%s\tUpdatePassword ok.", tests.Success)
// Verify that the user can be authenticated with the updated password. // Verify that the user can be authenticated with the updated password.
_, err = Authenticate(ctx, test.MasterDB, tknGen, usrAcc.User.Email, newPass, time.Hour, now) _, err = Authenticate(ctx, test.MasterDB, tknGen,
AuthenticateRequest{
Email: usrAcc.User.Email,
Password: newPass,
}, time.Hour, now)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tAuthenticate failed.", tests.Failed) t.Fatalf("\t%s\tAuthenticate failed.", tests.Failed)
@ -257,7 +278,11 @@ func TestUserResetPassword(t *testing.T) {
t.Logf("\t%s\tResetConfirm ok.", tests.Success) t.Logf("\t%s\tResetConfirm ok.", tests.Success)
// Verify that the user can be authenticated with the updated password. // Verify that the user can be authenticated with the updated password.
_, err = Authenticate(ctx, test.MasterDB, tknGen, usrAcc.User.Email, newPass, time.Hour, now) _, err = Authenticate(ctx, test.MasterDB, tknGen,
AuthenticateRequest{
Email: usrAcc.User.Email,
Password: newPass,
}, time.Hour, now)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tAuthenticate failed.", tests.Failed) t.Fatalf("\t%s\tAuthenticate failed.", tests.Failed)
@ -456,7 +481,11 @@ func TestSwitchAccount(t *testing.T) {
{ {
// Verify that the user can be authenticated with the created user. // Verify that the user can be authenticated with the created user.
var claims1 auth.Claims var claims1 auth.Claims
tkn1, err := Authenticate(ctx, test.MasterDB, tknGen, authTest.root.User.Email, authTest.root.User.Password, time.Hour, now) tkn1, err := Authenticate(ctx, test.MasterDB, tknGen,
AuthenticateRequest{
Email: authTest.root.User.Email,
Password: authTest.root.User.Password,
}, time.Hour, now)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tAuthenticate user failed.", tests.Failed) t.Fatalf("\t%s\tAuthenticate user failed.", tests.Failed)
@ -856,7 +885,11 @@ func TestVirtualLogin(t *testing.T) {
{ {
// Verify that the user can be authenticated with the created user. // Verify that the user can be authenticated with the created user.
var claims1 auth.Claims var claims1 auth.Claims
tkn1, err := Authenticate(ctx, test.MasterDB, tknGen, authTest.root.User.Email, authTest.root.User.Password, time.Hour, now) tkn1, err := Authenticate(ctx, test.MasterDB, tknGen,
AuthenticateRequest{
Email: authTest.root.User.Email,
Password: authTest.root.User.Password,
}, time.Hour, now)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tAuthenticate user failed.", tests.Failed) t.Fatalf("\t%s\tAuthenticate user failed.", tests.Failed)

View File

@ -8,8 +8,9 @@ import (
// AuthenticateRequest defines what information is required to authenticate a user. // AuthenticateRequest defines what information is required to authenticate a user.
type AuthenticateRequest struct { type AuthenticateRequest struct {
Email string `json:"email" validate:"required,email" example:"gabi.may@geeksinthewoods.com"` Email string `json:"email" validate:"required,email" example:"gabi.may@geeksinthewoods.com"`
Password string `json:"password" validate:"required" example:"NeverTellSecret"` Password string `json:"password" validate:"required" example:"NeverTellSecret"`
AccountID string `json:"account_id" validate:"omitempty,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`
} }
// Token is the payload we deliver to users when they authenticate. // Token is the payload we deliver to users when they authenticate.