From c121b7d2893e609bdef887d183c0a1dea44a730d Mon Sep 17 00:00:00 2001 From: Lee Brown Date: Tue, 28 May 2019 04:44:01 -0500 Subject: [PATCH] finish internal/user user and auth unittests --- example-project/go.mod | 1 + example-project/go.sum | 2 + .../internal/platform/auth/auth.go | 6 +- .../internal/platform/tests/main.go | 2 - example-project/internal/schema/migrations.go | 4 +- example-project/internal/schema/schema.go | 2 +- example-project/internal/user/auth.go | 14 +- example-project/internal/user/models.go | 102 +- example-project/internal/user/user.go | 122 +- example-project/internal/user/user_account.go | 25 +- example-project/internal/user/user_test.go | 1104 ++++++++++++----- 11 files changed, 913 insertions(+), 471 deletions(-) diff --git a/example-project/go.mod b/example-project/go.mod index 4d6b567..fb7245d 100644 --- a/example-project/go.mod +++ b/example-project/go.mod @@ -6,6 +6,7 @@ require ( github.com/aws/aws-sdk-go v1.19.33 github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/dimfeld/httptreemux v5.0.1+incompatible + github.com/geeks-accelerator/sqlxmigrate v0.0.0-20190527223850-4a863a2d30db github.com/gitwak/gondolier v0.0.0-20190521205431-504d297a6c42 // indirect github.com/gitwak/sqlxmigrate v0.0.0-20190527063335-e98d5d44fc0b github.com/go-playground/locales v0.12.1 diff --git a/example-project/go.sum b/example-project/go.sum index 93dd9b7..9c5f785 100644 --- a/example-project/go.sum +++ b/example-project/go.sum @@ -13,6 +13,8 @@ github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumC github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dimfeld/httptreemux v5.0.1+incompatible h1:Qj3gVcDNoOthBAqftuD596rm4wg/adLLz5xh5CmpiCA= github.com/dimfeld/httptreemux v5.0.1+incompatible/go.mod h1:rbUlSV+CCpv/SuqUTP/8Bk2O3LyUV436/yaRGkhP6Z0= +github.com/geeks-accelerator/sqlxmigrate v0.0.0-20190527223850-4a863a2d30db h1:mjErP7mTFHQ3cw/ibAkW3CvQ8gM4k19EkfzRzRINDAE= +github.com/geeks-accelerator/sqlxmigrate v0.0.0-20190527223850-4a863a2d30db/go.mod h1:dzpCjo4q7chhMVuHDzs/odROkieZ5Wjp70rNDuX83jU= github.com/gitwak/gondolier v0.0.0-20190521205431-504d297a6c42 h1:+lo4HFeG6LlcgwvsvQC8H5FG8yr/kDn89E51BTw3loE= github.com/gitwak/gondolier v0.0.0-20190521205431-504d297a6c42/go.mod h1:ecEQ8e4eHeWKPf+g6ByatPM7l4QZgR3G5ZIZKvEAdCE= github.com/gitwak/sqlxmigrate v0.0.0-20190522211042-9625063dea5d h1:oaUPMY0F+lNUkyB5tzsQS3EC0m9Cxdglesp63i3UPso= diff --git a/example-project/internal/platform/auth/auth.go b/example-project/internal/platform/auth/auth.go index c7199ba..772fe54 100644 --- a/example-project/internal/platform/auth/auth.go +++ b/example-project/internal/platform/auth/auth.go @@ -185,7 +185,7 @@ func NewAuthenticator(awsSession *session.Session, awsSecretID string, now time. // refreshed on instance launch. Could store keys in a kv store and update that value // when new keys are generated if len(keyContents) == 0 || curKeyId == "" { - privateKey, err := keygen() + privateKey, err := Keygen() if err != nil { return nil, errors.Wrap(err, "failed to generate new private key") } @@ -307,8 +307,8 @@ func (a *Authenticator) ParseClaims(tknStr string) (Claims, error) { return claims, nil } -// keygen creates an x509 private key for signing auth tokens. -func keygen() ([]byte, error) { +// Keygen creates an x509 private key for signing auth tokens. +func Keygen() ([]byte, error) { key, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { return []byte{}, errors.Wrap(err, "generating keys") diff --git a/example-project/internal/platform/tests/main.go b/example-project/internal/platform/tests/main.go index dbbd516..fb59571 100644 --- a/example-project/internal/platform/tests/main.go +++ b/example-project/internal/platform/tests/main.go @@ -57,8 +57,6 @@ func New() *Test { dbHost := fmt.Sprintf("postgres://%s:%s@127.0.0.1:%s/%s?timezone=UTC&sslmode=disable", container.User, container.Pass, container.Port, container.Database) - fmt.Println(dbHost) - // ============================================================ // Start Postgres diff --git a/example-project/internal/schema/migrations.go b/example-project/internal/schema/migrations.go index 38b8299..bc12210 100644 --- a/example-project/internal/schema/migrations.go +++ b/example-project/internal/schema/migrations.go @@ -4,7 +4,7 @@ import ( "database/sql" "log" - "github.com/gitwak/sqlxmigrate" + "github.com/geeks-accelerator/sqlxmigrate" "github.com/jmoiron/sqlx" _ "github.com/lib/pq" "github.com/pkg/errors" @@ -106,7 +106,7 @@ func migrationList(db *sqlx.DB, log *log.Logger) []*sqlxmigrate.Migration { { ID: "20190522-01c", Migrate: func(tx *sql.Tx) error { - q1 := `CREATE TYPE user_account_role_t as enum('admin', 'user')` + q1 := `CREATE TYPE user_account_role_t as enum('ADMIN', 'USER')` if _, err := tx.Exec(q1); err != nil { return errors.WithMessagef(err, "Query failed %s", q1) } diff --git a/example-project/internal/schema/schema.go b/example-project/internal/schema/schema.go index e8373cf..84ef30f 100644 --- a/example-project/internal/schema/schema.go +++ b/example-project/internal/schema/schema.go @@ -3,7 +3,7 @@ package schema import ( "log" - "github.com/gitwak/sqlxmigrate" + "github.com/geeks-accelerator/sqlxmigrate" "github.com/jmoiron/sqlx" ) diff --git a/example-project/internal/user/auth.go b/example-project/internal/user/auth.go index d7fba20..5ddff4e 100644 --- a/example-project/internal/user/auth.go +++ b/example-project/internal/user/auth.go @@ -29,7 +29,7 @@ func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, n query.Where(query.Equal("email", email)) // Run the find, use empty claims to bypass ACLs - res, err := find(ctx, auth.Claims{}, dbConn, query, false) + res, err := find(ctx, auth.Claims{}, dbConn, query, []interface{}{}, false) if err != nil { return Token{}, err } else if res == nil || len(res) == 0 { @@ -39,7 +39,7 @@ func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, n u := res[0] // Append the salt from the user record to the supplied password. - saltedPassword := password + string(u.PasswordHash) + saltedPassword := password + string(u.PasswordSalt) // Compare the provided password with the saved hash. Use the bcrypt // comparison function so it is cryptographically secure. @@ -61,15 +61,17 @@ func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, n roles []string ) if len(accounts) > 0 { - accountId = accounts[0].ID - roles = accounts[0].Roles + accountId = accounts[0].AccountID + for _, r := range accounts[0].Roles { + roles = append(roles, r.String()) + } } // Generate a list of all the account IDs associated with the user so // the use has the ability to switch between accounts. accountIds := []string{} for _, a := range accounts { - accountIds = append(accountIds, a.ID) + accountIds = append(accountIds, a.AccountID) } // If we are this far the request is valid. Create some claims for the user. @@ -81,5 +83,5 @@ func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, n return Token{}, errors.Wrap(err, "generating token") } - return Token{Token: tkn}, nil + return Token{Token: tkn, claims: claims}, nil } diff --git a/example-project/internal/user/models.go b/example-project/internal/user/models.go index caf8568..a6270d1 100644 --- a/example-project/internal/user/models.go +++ b/example-project/internal/user/models.go @@ -5,6 +5,7 @@ import ( "database/sql/driver" "time" + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth" "github.com/lib/pq" "github.com/pkg/errors" "gopkg.in/go-playground/validator.v9" @@ -34,8 +35,8 @@ type CreateUserRequest struct { Email string `json:"email" validate:"required,email,unique"` Password string `json:"password" validate:"required"` PasswordConfirm string `json:"password_confirm" validate:"eqfield=Password"` - Status *UserStatus `json:"status" validate:"oneof=active disabled"` - Timezone *string `json:"timezone"` + Status *UserStatus `json:"status" validate:"omitempty,oneof=active disabled"` + Timezone *string `json:"timezone" validate:"omitempty"` } // UpdateUserRequest defines what information may be provided to modify an existing @@ -46,10 +47,10 @@ type CreateUserRequest struct { // marshalling/unmarshalling. type UpdateUserRequest struct { ID string `validate:"required,uuid"` - Name *string `json:"name"` - Email *string `json:"email" validate:"email,unique"` - Status *UserStatus `json:"status" validate:"oneof=active disabled"` - Timezone *string `json:"timezone"` + Name *string `json:"name" validate:"omitempty"` + Email *string `json:"email" validate:"omitempty,email,unique"` + Status *UserStatus `json:"status" validate:"omitempty,oneof=active disabled"` + Timezone *string `json:"timezone" validate:"omitempty"` } // UpdatePassword defines what information may be provided to update user password. @@ -73,28 +74,28 @@ type UserFindRequest struct { // Each association of an user to an account has a set of roles defined for the user // that will be applied when accessing the account. type UserAccount struct { - ID string `db:"id" json:"id"` - UserID string `db:"user_id" json:"user_id"` - AccountID string `db:"account_id" json:"account_id"` - Roles []string `db:"roles" json:"roles"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` - ArchivedAt pq.NullTime `db:"archived_at" json:"archived_at"` + ID string `db:"id" json:"id"` + UserID string `db:"user_id" json:"user_id"` + AccountID string `db:"account_id" json:"account_id"` + Roles UserAccountRoles `db:"roles" json:"roles"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + ArchivedAt pq.NullTime `db:"archived_at" json:"archived_at"` } // AddAccountRequest defines the information needed to add a new account to a user. type AddAccountRequest struct { - UserID string `validate:"required,uuid"` - AccountID string `validate:"required,uuid"` - Roles []string `json:"roles" validate:"oneof=admin user"` + UserID string `validate:"required,uuid"` + AccountID string `validate:"required,uuid"` + Roles UserAccountRoles `json:"roles" validate:"required,dive,oneof=ADMIN USER"` } // UpdateAccountRequest defines the information needed to update the roles for // an existing user account. type UpdateAccountRequest struct { - UserID string `validate:"required,uuid"` - AccountID string `validate:"required,uuid"` - Roles []string `json:"roles" validate:"oneof=admin user"` + UserID string `validate:"required,uuid"` + AccountID string `validate:"required,uuid"` + Roles UserAccountRoles `json:"roles" validate:"oneof=ADMIN USER"` unArchive bool } @@ -122,6 +123,9 @@ type UserAccountFindRequest struct { IncludedArchived bool } +// UserStatus represents the status of a user. +type UserStatus string + // UserStatus values const ( UserStatus_Active UserStatus = "active" @@ -134,9 +138,6 @@ var UserStatus_Values = []UserStatus{ UserStatus_Disabled, } -// UserStatus represents the status of a user. -type UserStatus string - // Scan supports reading the UserStatus value from the database. func (s *UserStatus) Scan(value interface{}) error { asBytes, ok := value.([]byte) @@ -156,7 +157,6 @@ func (s UserStatus) Value() (driver.Value, error) { return nil, errs } - // validation would go here return string(s), nil } @@ -165,7 +165,61 @@ func (s UserStatus) String() string { return string(s) } +// UserAccountRole represents the role of a user for an account. +type UserAccountRole string + +// UserAccountRole values +const ( + UserAccountRole_Admin UserAccountRole = auth.RoleAdmin + UserAccountRole_User UserAccountRole = auth.RoleUser +) + +// UserAccountRole_Values provides list of valid UserAccountRole values +var UserAccountRole_Values = []UserAccountRole{ + UserAccountRole_Admin, + UserAccountRole_User, +} + +// String converts the UserAccountRole value to a string. +func (s UserAccountRole) String() string { + return string(s) +} + +// UserAccountRoles represents a set of roles for a user for an account. +type UserAccountRoles []UserAccountRole + +// Scan supports reading the UserAccountRole value from the database. +func (s *UserAccountRoles) Scan(value interface{}) error { + arr := &pq.StringArray{} + if err := arr.Scan(value); err != nil { + return err + } + + for _, v := range *arr { + *s = append(*s, UserAccountRole(v)) + } + + return nil +} + +// Value converts the UserAccountRole value to be stored in the database. +func (s UserAccountRoles) Value() (driver.Value, error) { + v := validator.New() + + var arr pq.StringArray + for _, r := range s { + errs := v.Var(r, "required,oneof=ADMIN USER") + if errs != nil { + return nil, errs + } + arr = append(arr, r.String()) + } + + return arr.Value() +} + // Token is the payload we deliver to users when they authenticate. type Token struct { - Token string `json:"token"` + Token string `json:"token"` + claims auth.Claims `json:"-"` } diff --git a/example-project/internal/user/user.go b/example-project/internal/user/user.go index 78bfeae..cdedd2f 100644 --- a/example-project/internal/user/user.go +++ b/example-project/internal/user/user.go @@ -3,6 +3,7 @@ package user import ( "context" "database/sql" + "fmt" "time" "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth" @@ -43,7 +44,7 @@ func mapRowsToUser(rows *sql.Rows) (*User, error) { u User err error ) - err = rows.Scan(&u.ID, &u.Email, &u.PasswordSalt, &u.PasswordHash, &u.PasswordReset, &u.Status, &u.Timezone, &u.CreatedAt, &u.UpdatedAt, &u.ArchivedAt) + err = rows.Scan(&u.ID, &u.Name, &u.Email, &u.PasswordSalt, &u.PasswordHash, &u.PasswordReset, &u.Status, &u.Timezone, &u.CreatedAt, &u.UpdatedAt, &u.ArchivedAt) if err != nil { return nil, errors.WithStack(err) } @@ -59,18 +60,17 @@ func CanReadUserId(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, use // When the claims Subject - UserId - does not match the requested user, the // claims audience - AccountId - should have a record. if claims.Subject != userID { - query := sqlbuilder.NewSelectBuilder().Select("id").From(usersAccountsTableName) - query.Where(query.And( + query.Where(query.Or( query.Equal("account_id", claims.Audience), query.Equal("user_id", userID), )) - sql, args := query.Build() - sql = dbConn.Rebind(sql) + queryStr, args := query.Build() + queryStr = dbConn.Rebind(queryStr) var userAccountId string - err := dbConn.QueryRowContext(ctx, sql, args...).Scan(&userAccountId) - if err != nil { + err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&userAccountId) + if err != nil && err != sql.ErrNoRows { err = errors.Wrapf(err, "query - %s", query.String()) return err } @@ -130,7 +130,7 @@ func applyClaimsUserSelect(ctx context.Context, claims auth.Claims, query *sqlbu if claims.Subject != "" { or = append(or, subQuery.Equal("user_id", claims.Subject)) } - subQuery.Where(or...) + subQuery.Where(subQuery.Or(or...)) // Append sub query query.Where(query.In("id", subQuery)) @@ -147,10 +147,12 @@ func selectQuery() *sqlbuilder.SelectBuilder { } // userFindRequestQuery generates the select query for the given find request. -func userFindRequestQuery(req UserFindRequest) *sqlbuilder.SelectBuilder { +// TODO: Need to figure out why can't parse the args when appending the where +// to the query. +func userFindRequestQuery(req UserFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) { query := selectQuery() if req.Where != nil { - query.Where(*req.Where) + query.Where(query.And(*req.Where)) } if len(req.Order) > 0 { query.OrderBy(req.Order...) @@ -162,90 +164,17 @@ func userFindRequestQuery(req UserFindRequest) *sqlbuilder.SelectBuilder { query.Offset(int(*req.Offset)) } - b := sqlbuilder.Buildf(query.String(), req.Args...) - query.BuilderAs(b, usersTableName) - - return query + return query, req.Args } -// List enables streaming retrieval of Users from the database. The query results -// will be written to the interface{} resultReceiver channel enabling processing the results while -// they're still being fetched. After all pages have been processed the channel is closed -// Possible types sent to the channel are limited to: -// - error -// - User -// -// rr := make(chan interface{}) -// -// go List(ctx, claims, db, rr) -// -// for r := range rr { -// switch v := r.(type) { -// case User: -// // v is of type User -// // process the user here -// case error: -// // v is of type error -// // handle the error here -// } -// } -func List(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindRequest, results chan<- interface{}) { - query := userFindRequestQuery(req) - list(ctx, claims, dbConn, query, req.IncludedArchived, results) -} - -// List enables streaming retrieval of Users from the database for the supplied query. -func list(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, includedArchived bool, results chan<- interface{}) { - span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.List") - defer span.Finish() - - // Close the channel on complete - defer close(results) - - query.Select(usersMapColumns) - query.From(usersTableName) - - if !includedArchived { - query.Where(query.IsNull("archived_at")) - } - - // Check to see if a sub query needs to be applied for the claims - err := applyClaimsUserSelect(ctx, claims, query) - if err != nil { - results <- err - return - } - sql, args := query.Build() - sql = dbConn.Rebind(sql) - - // fetch all places from the db - rows, err := dbConn.QueryContext(ctx, sql, args...) - if err != nil { - err = errors.Wrapf(err, "query - %s", query.String()) - results <- errors.WithMessage(err, "list users failed") - return - } - - // iterate over each row - for rows.Next() { - u, err := mapRowsToUser(rows) - if err != nil { - err = errors.Wrapf(err, "query - %s", query.String()) - results <- err - return - } - results <- u - } -} - -// Find gets all the users from the database based on the request params +// Find gets all the users from the database based on the request params. func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindRequest) ([]*User, error) { - query := userFindRequestQuery(req) - return find(ctx, claims, dbConn, query, req.IncludedArchived) + query, args := userFindRequestQuery(req) + return find(ctx, claims, dbConn, query, args, req.IncludedArchived) } -// find gets all the users from the database based on the query -func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, includedArchived bool) ([]*User, error) { +// find internal method for getting all the users from the database using a select query. +func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*User, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Find") defer span.Finish() @@ -261,11 +190,15 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu if err != nil { return nil, err } - sql, args := query.Build() - sql = dbConn.Rebind(sql) + queryStr, queryArgs := query.Build() + queryStr = dbConn.Rebind(queryStr) + args = append(args, queryArgs...) + + fmt.Println(queryStr) + fmt.Println(args) // fetch all places from the db - rows, err := dbConn.QueryContext(ctx, sql, args...) + rows, err := dbConn.QueryContext(ctx, queryStr, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessage(err, "find users failed") @@ -283,6 +216,8 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu resp = append(resp, u) } + fmt.Println("len", len(resp)) + return resp, nil } @@ -295,7 +230,7 @@ func FindById(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id strin query := selectQuery() query.Where(query.Equal("id", id)) - res, err := find(ctx, claims, dbConn, query, includedArchived) + res, err := find(ctx, claims, dbConn, query, []interface{}{}, includedArchived) if err != nil { return nil, err } else if res == nil || len(res) == 0 { @@ -309,7 +244,6 @@ func FindById(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id strin // Validation an email address is unique excluding the current user ID. func uniqueEmail(ctx context.Context, dbConn *sqlx.DB, email, userId string) (bool, error) { - query := sqlbuilder.NewSelectBuilder().Select("id").From(usersTableName) query.Where(query.And( query.Equal("email", email), diff --git a/example-project/internal/user/user_account.go b/example-project/internal/user/user_account.go index 0b71800..23021ab 100644 --- a/example-project/internal/user/user_account.go +++ b/example-project/internal/user/user_account.go @@ -71,7 +71,9 @@ func accountSelectQuery() *sqlbuilder.SelectBuilder { } // userFindRequestQuery generates the select query for the given find request. -func accountFindRequestQuery(req UserAccountFindRequest) *sqlbuilder.SelectBuilder { +// TODO: Need to figure out why can't parse the args when appending the where +// to the query. +func accountFindRequestQuery(req UserAccountFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) { query := accountSelectQuery() if req.Where != nil { query.Where(*req.Where) @@ -89,17 +91,17 @@ func accountFindRequestQuery(req UserAccountFindRequest) *sqlbuilder.SelectBuild b := sqlbuilder.Buildf(query.String(), req.Args...) query.BuilderAs(b, usersAccountsMapColumns) - return query + return query, req.Args } // Find gets all the users from the database based on the request params func FindAccounts(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountFindRequest) ([]*UserAccount, error) { - query := accountFindRequestQuery(req) - return findAccounts(ctx, claims, dbConn, query, req.IncludedArchived) + query, args := accountFindRequestQuery(req) + return findAccounts(ctx, claims, dbConn, query, args, req.IncludedArchived) } // Find gets all the users from the database based on the select query -func findAccounts(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, includedArchived bool) ([]*UserAccount, error) { +func findAccounts(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*UserAccount, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.FindAccounts") defer span.Finish() @@ -115,11 +117,12 @@ func findAccounts(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, quer if err != nil { return nil, err } - sql, args := query.Build() - sql = dbConn.Rebind(sql) + queryStr, queryArgs := query.Build() + queryStr = dbConn.Rebind(queryStr) + args = append(args, queryArgs...) // fetch all places from the db - rows, err := dbConn.QueryContext(ctx, sql, args...) + rows, err := dbConn.QueryContext(ctx, queryStr, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessage(err, "find accounts failed") @@ -151,7 +154,7 @@ func FindAccountsByUserID(ctx context.Context, claims auth.Claims, dbConn *sqlx. query.OrderBy("id") // Execute the find accounts method. - res, err := findAccounts(ctx, claims, dbConn, query, includedArchived) + res, err := findAccounts(ctx, claims, dbConn, query, []interface{}{}, includedArchived) if err != nil { return nil, err } @@ -194,7 +197,7 @@ func AddAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Ad existQuery.Equal("account_id", req.AccountID), existQuery.Equal("user_id", req.UserID), )) - existing, err := findAccounts(ctx, claims, dbConn, existQuery, true) + existing, err := findAccounts(ctx, claims, dbConn, existQuery, []interface{}{}, true) if err != nil { return err } @@ -217,7 +220,7 @@ func AddAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Ad query := sqlbuilder.NewInsertBuilder() query.InsertInto(usersAccountsTableName) query.Cols("id", "user_id", "account_id", "roles", "created_at", "updated_at") - query.Values(1, id, req.UserID, req.AccountID, req.Roles, now, now) + query.Values(id, req.UserID, req.AccountID, req.Roles, now, now) // Execute the query with the provided context. sql, args := query.Build() diff --git a/example-project/internal/user/user_test.go b/example-project/internal/user/user_test.go index 221684c..d0903c7 100644 --- a/example-project/internal/user/user_test.go +++ b/example-project/internal/user/user_test.go @@ -1,6 +1,12 @@ package user import ( + "math/rand" + "os" + "strings" + "testing" + "time" + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/tests" "github.com/dgrijalva/jwt-go" @@ -8,9 +14,6 @@ import ( "github.com/huandu/go-sqlbuilder" "github.com/pborman/uuid" "github.com/pkg/errors" - "os" - "testing" - "time" ) var test *tests.Test @@ -47,13 +50,16 @@ func TestUserFindRequestQuery(t *testing.T) { Limit: &limit, Offset: &offset, } - expected := "SELECT " + usersMapColumns + " FROM " + usersTableName + " WHERE name = ? or email = ? ORDER BY id asc, created_at desc LIMIT 12 OFFSET 34" + expected := "SELECT " + usersMapColumns + " FROM " + usersTableName + " WHERE (name = ? or email = ?) ORDER BY id asc, created_at desc LIMIT 12 OFFSET 34" - res := userFindRequestQuery(req) + res, args := userFindRequestQuery(req) if diff := cmp.Diff(res.String(), expected); diff != "" { t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff) } + if diff := cmp.Diff(args, req.Args); diff != "" { + t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff) + } } // TestApplyClaimsUserSelect validates applyClaimsUserSelect @@ -77,7 +83,7 @@ func TestApplyClaimsUserSelect(t *testing.T) { Audience: "acc1", }, }, - "SELECT " + usersMapColumns + " FROM " + usersTableName + " WHERE id IN (SELECT user_id FROM " + usersAccountsTableName + " WHERE account_id = 'acc1' AND user_id = 'user1')", + "SELECT " + usersMapColumns + " FROM " + usersTableName + " WHERE id IN (SELECT user_id FROM " + usersAccountsTableName + " WHERE (account_id = 'acc1' OR user_id = 'user1'))", nil, }, {"RoleAdmin", @@ -88,7 +94,7 @@ func TestApplyClaimsUserSelect(t *testing.T) { Audience: "acc1", }, }, - "SELECT " + usersMapColumns + " FROM " + usersTableName + " WHERE id IN (SELECT user_id FROM " + usersAccountsTableName + " WHERE account_id = 'acc1' AND user_id = 'user1')", + "SELECT " + usersMapColumns + " FROM " + usersTableName + " WHERE id IN (SELECT user_id FROM " + usersAccountsTableName + " WHERE (account_id = 'acc1' OR user_id = 'user1'))", nil, }, } @@ -128,16 +134,183 @@ func TestApplyClaimsUserSelect(t *testing.T) { } } -// TestCreateUser validates CreateUser -func TestCreateUser(t *testing.T) { +// TestCreateUser ensures all the validation tags work on Create +func TestCreateUserValidation(t *testing.T) { + + invalidStatus := UserStatus("moon") + + var userTests = []struct { + name string + req CreateUserRequest + expected func(req CreateUserRequest, res *User) *User + error error + }{ + {"Required Fields", + CreateUserRequest{}, + func(req CreateUserRequest, res *User) *User { + return nil + }, + errors.New("Key: 'CreateUserRequest.Name' Error:Field validation for 'Name' failed on the 'required' tag\n" + + "Key: 'CreateUserRequest.Email' Error:Field validation for 'Email' failed on the 'required' tag\n" + + "Key: 'CreateUserRequest.Password' Error:Field validation for 'Password' failed on the 'required' tag"), + }, + {"Valid Email", + CreateUserRequest{ + Name: "Lee Brown", + Email: "xxxxxxxxxx", + Password: "akTechFr0n!ier", + PasswordConfirm: "akTechFr0n!ier", + }, + func(req CreateUserRequest, res *User) *User { + return nil + }, + errors.New("Key: 'CreateUserRequest.Email' Error:Field validation for 'Email' failed on the 'email' tag"), + }, + {"Valid Status", + CreateUserRequest{ + Name: "Lee Brown", + Email: uuid.NewRandom().String() + "@geeksinthewoods.com", + Password: "akTechFr0n!ier", + PasswordConfirm: "akTechFr0n!ier", + Status: &invalidStatus, + }, + func(req CreateUserRequest, res *User) *User { + return nil + }, + errors.New("Key: 'CreateUserRequest.Status' Error:Field validation for 'Status' failed on the 'oneof' tag"), + }, + {"Passwords Match", + CreateUserRequest{ + Name: "Lee Brown", + Email: uuid.NewRandom().String() + "@geeksinthewoods.com", + Password: "akTechFr0n!ier", + PasswordConfirm: "W0rkL1fe#", + }, + func(req CreateUserRequest, res *User) *User { + return nil + }, + errors.New("Key: 'CreateUserRequest.PasswordConfirm' Error:Field validation for 'PasswordConfirm' failed on the 'eqfield' tag"), + }, + {"Default Status & Timezone", + CreateUserRequest{ + Name: "Lee Brown", + Email: uuid.NewRandom().String() + "@geeksinthewoods.com", + Password: "akTechFr0n!ier", + PasswordConfirm: "akTechFr0n!ier", + }, + func(req CreateUserRequest, res *User) *User { + return &User{ + Name: req.Name, + Email: req.Email, + Status: UserStatus_Active, + Timezone: "America/Anchorage", + + // Copy this fields from the result. + ID: res.ID, + PasswordSalt: res.PasswordSalt, + PasswordHash: res.PasswordHash, + PasswordReset: res.PasswordReset, + CreatedAt: res.CreatedAt, + UpdatedAt: res.UpdatedAt, + //ArchivedAt: nil, + } + }, + nil, + }, + } now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC) - // Use disabled status since default is active - us := UserStatus_Disabled - utz := "America/Santiago" + t.Log("Given the need ensure all validation tags are working for user create.") + { + for i, tt := range userTests { + t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name) + { + ctx := tests.Context() - dupEmail := uuid.NewRandom().String() + "@geeksinthewoods.com" + res, err := Create(ctx, auth.Claims{}, test.MasterDB, tt.req, now) + if err != tt.error { + // TODO: need a better way to handle validation errors as they are + // of type interface validator.ValidationErrorsTranslations + var errStr string + if err != nil { + errStr = err.Error() + } + var expectStr string + if tt.error != nil { + expectStr = tt.error.Error() + } + if errStr != expectStr { + t.Logf("\t\tGot : %+v", err) + t.Logf("\t\tWant: %+v", tt.error) + t.Fatalf("\t%s\tCreate failed.", tests.Failed) + } + } + + // If there was an error that was expected, then don't go any further + if tt.error != nil { + t.Logf("\t%s\tCreate ok.", tests.Success) + continue + } + + expected := tt.expected(tt.req, res) + if diff := cmp.Diff(res, expected); diff != "" { + t.Fatalf("\t%s\tExpected result should match. Diff:\n%s", tests.Failed, diff) + } + + t.Logf("\t%s\tCreate ok.", tests.Success) + } + } + } +} + +// TestCreateUserValidationEmailUnique validates emails must be unique on Create. +func TestCreateUserValidationEmailUnique(t *testing.T) { + + now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC) + + t.Log("Given the need ensure duplicate emails are not allowed for user create.") + { + ctx := tests.Context() + + req1 := CreateUserRequest{ + Name: "Lee Brown", + Email: uuid.NewRandom().String() + "@geeksinthewoods.com", + Password: "akTechFr0n!ier", + PasswordConfirm: "akTechFr0n!ier", + } + user1, err := Create(ctx, auth.Claims{}, test.MasterDB, req1, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tCreate failed.", tests.Failed) + } + + req2 := CreateUserRequest{ + Name: "Lucas Brown", + Email: user1.Email, + Password: "W0rkL1fe#", + PasswordConfirm: "W0rkL1fe#", + } + expectedErr := errors.New("Key: 'CreateUserRequest.Email' Error:Field validation for 'Email' failed on the 'unique' tag") + _, err = Create(ctx, auth.Claims{}, test.MasterDB, req2, now) + if err == nil { + t.Logf("\t\tWant: %+v", expectedErr) + t.Fatalf("\t%s\tCreate failed.", tests.Failed) + } + + if err.Error() != expectedErr.Error() { + t.Logf("\t\tGot : %+v", err) + t.Logf("\t\tWant: %+v", expectedErr) + t.Fatalf("\t%s\tCreate failed.", tests.Failed) + } + + t.Logf("\t%s\tCreate ok.", tests.Success) + } +} + +// TestCreateUserClaims validates ACLs are correctly applied to Create by claims. +func TestCreateUserClaims(t *testing.T) { + defer tests.Recover(t) var userTests = []struct { name string @@ -145,30 +318,18 @@ func TestCreateUser(t *testing.T) { req CreateUserRequest error error }{ + // Internal request, should bypass ACL. {"EmptyClaims", auth.Claims{}, CreateUserRequest{ Name: "Lee Brown", - Email: dupEmail, + Email: uuid.NewRandom().String() + "@geeksinthewoods.com", Password: "akTechFr0n!ier", PasswordConfirm: "akTechFr0n!ier", - Status: &us, - Timezone: &utz, }, nil, }, - {"DuplicateEmailValidation", - auth.Claims{}, - CreateUserRequest{ - Name: "Lee Brown", - Email: dupEmail, - Password: "akTechFr0n!ier", - PasswordConfirm: "akTechFr0n!ier", - Status: &us, - Timezone: &utz, - }, - errors.New("Key: 'CreateUserRequest.Email' Error:Field validation for 'Email' failed on the 'unique' tag"), - }, + // Role of user, only admins can create new users. {"RoleUser", auth.Claims{ Roles: []string{auth.RoleUser}, @@ -182,11 +343,10 @@ func TestCreateUser(t *testing.T) { Email: uuid.NewRandom().String() + "@geeksinthewoods.com", Password: "akTechFr0n!ier", PasswordConfirm: "akTechFr0n!ier", - Status: &us, - Timezone: &utz, }, ErrForbidden, }, + // Role of admin, can create users. {"RoleAdmin", auth.Claims{ Roles: []string{auth.RoleAdmin}, @@ -200,163 +360,77 @@ func TestCreateUser(t *testing.T) { Email: uuid.NewRandom().String() + "@geeksinthewoods.com", Password: "akTechFr0n!ier", PasswordConfirm: "akTechFr0n!ier", - Status: &us, - Timezone: &utz, }, nil, }, } - t.Log("Given the need to validate ACLs are enforced by claims for user create.") - { - for i, tt := range userTests { - t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name) - { - ctx := tests.Context() - - dbConn := test.MasterDB - defer dbConn.Close() - - res, err := Create(ctx, tt.claims, dbConn, tt.req, now) - if err != tt.error { - // TODO: need a better way to handle validation errors as they are - // of type interface validator.ValidationErrorsTranslations - var errStr string - if err != nil { - errStr = err.Error() - } - var expectStr string - if tt.error != nil { - expectStr = tt.error.Error() - } - if errStr != expectStr { - t.Logf("\t\tGot : %+v", err) - t.Logf("\t\tWant: %+v", tt.error) - t.Fatalf("\t%s\tapplyClaimsUserSelect failed.", tests.Failed) - } - } - - // If there was an error that was expected, then don't go any further - if tt.error != nil { - continue - } - - expected := &User{ - Name: tt.req.Name, - Email: tt.req.Email, - Status: *tt.req.Status, - Timezone: *tt.req.Timezone, - - // Copy this fields from the result. - ID: res.ID, - PasswordSalt: res.PasswordSalt, - PasswordHash: res.PasswordHash, - PasswordReset: res.PasswordReset, - CreatedAt: res.CreatedAt, - UpdatedAt: res.UpdatedAt, - //ArchivedAt: nil, - } - - if diff := cmp.Diff(res, expected); diff != "" { - t.Fatalf("\t%s\tExpected result should match. Diff:\n%s", tests.Failed, diff) - } - - t.Logf("\t%s\tapplyClaimsUserSelect ok.", tests.Success) - } - } - } -} - -// TestUpdateUser validates Update -func TestUpdateUser(t *testing.T) { - now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC) - // Use disabled status since default is active - us := UserStatus_Disabled - utz := "America/Santiago" - - create := CreateUserRequest{ - Name: "Lee Brown", - Password: "akTechFr0n!ier", - PasswordConfirm: "akTechFr0n!ier", - Status: &us, - Timezone: &utz, - } - - dupEmail := uuid.NewRandom().String() + "@geeksinthewoods.com" - - var userTests = []struct { - name string - claims auth.Claims - req UpdateUserRequest - error error - }{ - {"EmptyClaims", - auth.Claims{}, - UpdateUserRequest{ - Name: "Lee Brown", - Email: dupEmail, - Status: &us, - Timezone: &utz, - }, - nil, - }, - {"DuplicateEmailValidation", - auth.Claims{}, - UpdateUserRequest{ - Name: "Lee Brown", - Email: dupEmail, - Status: &us, - Timezone: &utz, - }, - errors.New("Key: 'CreateUserRequest.Email' Error:Field validation for 'Email' failed on the 'unique' tag"), - }, - {"RoleUser", - auth.Claims{ - Roles: []string{auth.RoleUser}, - StandardClaims: jwt.StandardClaims{ - Subject: "user1", - Audience: "acc1", - }, - }, - UpdateUserRequest{ - Name: "Lee Brown", - Email: &uuid.NewRandom().String(), - Status: &us, - Timezone: &utz, - }, - ErrForbidden, - }, - {"RoleAdmin", - auth.Claims{ - Roles: []string{auth.RoleAdmin}, - StandardClaims: jwt.StandardClaims{ - Subject: "user1", - Audience: "acc1", - }, - }, - UpdateUserRequest{ - Name: "Lee Brown", - Email: uuid.NewRandom().String() + "@geeksinthewoods.com", - Status: &us, - Timezone: &utz, - }, - nil, - }, - } - - t.Log("Given the need to validate ACLs are enforced by claims for user update.") + t.Log("Given the need to ensure claims are applied as ACL for create user.") { for i, tt := range userTests { t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name) { ctx := tests.Context() - dbConn := test.MasterDB - defer dbConn.Close() + _, err := Create(ctx, auth.Claims{}, test.MasterDB, tt.req, now) + if err != nil && errors.Cause(err) != tt.error { + t.Logf("\t\tGot : %+v", err) + t.Logf("\t\tWant: %+v", tt.error) + t.Fatalf("\t%s\tCreate failed.", tests.Failed) + } - err := Update(ctx, tt.claims, dbConn, tt.req, now) + t.Logf("\t%s\tCreate ok.", tests.Success) + } + } + } +} + +// TestUpdateUser ensures all the validation tags work on Update +func TestUpdateUserValidation(t *testing.T) { + // TODO: actually create the user so can test the output of findbyId + type userTest struct { + name string + req UpdateUserRequest + error error + } + + var userTests = []userTest{ + {"Required Fields", + UpdateUserRequest{}, + errors.New("Key: 'UpdateUserRequest.ID' Error:Field validation for 'ID' failed on the 'required' tag"), + }, + } + + invalidEmail := "xxxxxxxxxx" + userTests = append(userTests, userTest{"Valid Email", + UpdateUserRequest{ + ID: uuid.NewRandom().String(), + Email: &invalidEmail, + }, + errors.New("Key: 'UpdateUserRequest.Email' Error:Field validation for 'Email' failed on the 'email' tag"), + }) + + invalidStatus := UserStatus("xxxxxxxxx") + userTests = append(userTests, userTest{"Valid Status", + UpdateUserRequest{ + ID: uuid.NewRandom().String(), + Status: &invalidStatus, + }, + errors.New("Key: 'UpdateUserRequest.Status' Error:Field validation for 'Status' failed on the 'oneof' tag"), + }) + + now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC) + + t.Log("Given the need ensure all validation tags are working for user update.") + { + for i, tt := range userTests { + t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name) + { + ctx := tests.Context() + + err := Update(ctx, auth.Claims{}, test.MasterDB, tt.req, now) if err != tt.error { // TODO: need a better way to handle validation errors as they are // of type interface validator.ValidationErrorsTranslations @@ -371,188 +445,562 @@ func TestUpdateUser(t *testing.T) { if errStr != expectStr { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.error) - t.Fatalf("\t%s\tapplyClaimsUserSelect failed.", tests.Failed) + t.Fatalf("\t%s\tUpdate failed.", tests.Failed) } } - // If there was an error that was expected, then don't go any further - if tt.error != nil { - continue - } - - expected := &User{ - Name: tt.req.Name, - Email: tt.req.Email, - Status: *tt.req.Status, - Timezone: *tt.req.Timezone, - - // Copy this fields from the result. - ID: res.ID, - PasswordSalt: res.PasswordSalt, - PasswordHash: res.PasswordHash, - PasswordReset: res.PasswordReset, - CreatedAt: res.CreatedAt, - UpdatedAt: res.UpdatedAt, - //ArchivedAt: nil, - } - - if diff := cmp.Diff(res, expected); diff != "" { - t.Fatalf("\t%s\tExpected result should match. Diff:\n%s", tests.Failed, diff) - } - - t.Logf("\t%s\tapplyClaimsUserSelect ok.", tests.Success) + t.Logf("\t%s\tUpdate ok.", tests.Success) } } } } -/* -// TestUser validates the full set of CRUD operations on User values. -func TestUser(t *testing.T) { +// TestUpdateUserValidationEmailUnique validates emails must be unique on Update. +func TestUpdateUserValidationEmailUnique(t *testing.T) { + + now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC) + + t.Log("Given the need ensure duplicate emails are not allowed for user update.") + { + ctx := tests.Context() + + req1 := CreateUserRequest{ + Name: "Lee Brown", + Email: uuid.NewRandom().String() + "@geeksinthewoods.com", + Password: "akTechFr0n!ier", + PasswordConfirm: "akTechFr0n!ier", + } + user1, err := Create(ctx, auth.Claims{}, test.MasterDB, req1, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tCreate failed.", tests.Failed) + } + + req2 := CreateUserRequest{ + Name: "Lucas Brown", + Email: uuid.NewRandom().String() + "@geeksinthewoods.com", + Password: "W0rkL1fe#", + PasswordConfirm: "W0rkL1fe#", + } + user2, err := Create(ctx, auth.Claims{}, test.MasterDB, req2, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tCreate failed.", tests.Failed) + } + + // Try to set the email for user 1 on user 2 + updateReq := UpdateUserRequest{ + ID: user2.ID, + Email: &user1.Email, + } + expectedErr := errors.New("Key: 'UpdateUserRequest.Email' Error:Field validation for 'Email' failed on the 'unique' tag") + err = Update(ctx, auth.Claims{}, test.MasterDB, updateReq, now) + if err == nil { + t.Logf("\t\tWant: %+v", expectedErr) + t.Fatalf("\t%s\tUpdate failed.", tests.Failed) + } + + if err.Error() != expectedErr.Error() { + t.Logf("\t\tGot : %+v", err) + t.Logf("\t\tWant: %+v", expectedErr) + t.Fatalf("\t%s\tUpdate failed.", tests.Failed) + } + + t.Logf("\t%s\tUpdate ok.", tests.Success) + } +} + +// TestUpdateUserPassword validates update user password works. +func TestUpdateUserPassword(t *testing.T) { + + t.Log("Given the need ensure a user password can be updated.") + { + ctx := tests.Context() + + now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC) + + var tknGen mockTokenGenerator + + // Create a new user for testing. + initPass := uuid.NewRandom().String() + user, err := Create(ctx, auth.Claims{}, test.MasterDB, CreateUserRequest{ + Name: "Lee Brown", + Email: uuid.NewRandom().String() + "@geeksinthewoods.com", + Password: initPass, + PasswordConfirm: initPass, + }, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tCreate failed.", tests.Failed) + } + + // Verify that the user can be authenticated with the created user. + _, err = Authenticate(ctx, test.MasterDB, tknGen, now, user.Email, initPass) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tAuthenticate failed.", tests.Failed) + } + + // Ensure validation is working by trying UpdatePassword with an empty request. + expectedErr := errors.New("Key: 'UpdatePasswordRequest.ID' Error:Field validation for 'ID' failed on the 'required' tag\n" + + "Key: 'UpdatePasswordRequest.Password' Error:Field validation for 'Password' failed on the 'required' tag") + err = UpdatePassword(ctx, auth.Claims{}, test.MasterDB, UpdatePasswordRequest{}, now) + if err == nil { + t.Logf("\t\tWant: %+v", expectedErr) + t.Fatalf("\t%s\tUpdate failed.", tests.Failed) + } else if err.Error() != expectedErr.Error() { + t.Logf("\t\tGot : %+v", err) + t.Logf("\t\tWant: %+v", expectedErr) + t.Fatalf("\t%s\tValidation failed.", tests.Failed) + } + t.Logf("\t%s\tValidation ok.", tests.Success) + + // Update the users password. + newPass := uuid.NewRandom().String() + err = UpdatePassword(ctx, auth.Claims{}, test.MasterDB, UpdatePasswordRequest{ + ID: user.ID, + Password: newPass, + PasswordConfirm: newPass, + }, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tCreate failed.", tests.Failed) + } + t.Logf("\t%s\tUpdatePassword ok.", tests.Success) + + // Verify that the user can be authenticated with the updated password. + _, err = Authenticate(ctx, test.MasterDB, tknGen, now, user.Email, newPass) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tAuthenticate failed.", tests.Failed) + } + t.Logf("\t%s\tAuthenticate ok.", tests.Success) + } +} + +// TestUserCrud validates the full set of CRUD operations and ensures ACLs are correctly applied by claims. +func TestUserCrud(t *testing.T) { defer tests.Recover(t) - t.Log("Given the need to work with User records.") + type userTest struct { + name string + claims func(*User, string) auth.Claims + create CreateUserRequest + update func(*User) UpdateUserRequest + updateErr error + expected func(*User, UpdateUserRequest) *User + findErr error + } + + var userTests []userTest + + // Internal request, should bypass ACL. + userTests = append(userTests, userTest{"EmptyClaims", + func(user *User, accountId string) auth.Claims { + return auth.Claims{} + }, + CreateUserRequest{ + Name: "Lee Brown", + Email: uuid.NewRandom().String() + "@geeksinthewoods.com", + Password: "akTechFr0n!ier", + PasswordConfirm: "akTechFr0n!ier", + }, + func(user *User) UpdateUserRequest { + email := uuid.NewRandom().String() + "@geeksinthewoods.com" + return UpdateUserRequest{ + ID: user.ID, + Email: &email, + } + }, + nil, + func(user *User, req UpdateUserRequest) *User { + return &User{ + Email: *req.Email, + // Copy this fields from the created user. + ID: user.ID, + Name: user.Name, + PasswordSalt: user.PasswordSalt, + PasswordHash: user.PasswordHash, + PasswordReset: user.PasswordReset, + Status: user.Status, + Timezone: user.Timezone, + CreatedAt: user.CreatedAt, + UpdatedAt: user.UpdatedAt, + //ArchivedAt: nil, + } + }, + nil, + }) + + // Role of user but claim user does not match update user so forbidden. + userTests = append(userTests, userTest{"RoleUserDiffUser", + func(user *User, accountId string) auth.Claims { + return auth.Claims{ + Roles: []string{auth.RoleUser}, + StandardClaims: jwt.StandardClaims{ + Subject: uuid.NewRandom().String(), + Audience: accountId, + }, + } + }, + CreateUserRequest{ + Name: "Lee Brown", + Email: uuid.NewRandom().String() + "@geeksinthewoods.com", + Password: "akTechFr0n!ier", + PasswordConfirm: "akTechFr0n!ier", + }, + func(user *User) UpdateUserRequest { + email := uuid.NewRandom().String() + "@geeksinthewoods.com" + return UpdateUserRequest{ + ID: user.ID, + Email: &email, + } + }, + ErrForbidden, + func(user *User, req UpdateUserRequest) *User { + return user + }, + ErrNotFound, + }) + + // Role of user AND claim user matches update user so OK. + userTests = append(userTests, userTest{"RoleUserSameUser", + func(user *User, accountId string) auth.Claims { + return auth.Claims{ + Roles: []string{auth.RoleUser}, + StandardClaims: jwt.StandardClaims{ + Subject: user.ID, + Audience: accountId, + }, + } + }, + CreateUserRequest{ + Name: "Lee Brown", + Email: uuid.NewRandom().String() + "@geeksinthewoods.com", + Password: "akTechFr0n!ier", + PasswordConfirm: "akTechFr0n!ier", + }, + func(user *User) UpdateUserRequest { + email := uuid.NewRandom().String() + "@geeksinthewoods.com" + return UpdateUserRequest{ + ID: user.ID, + Email: &email, + } + }, + nil, + func(user *User, req UpdateUserRequest) *User { + return &User{ + Email: *req.Email, + // Copy this fields from the created user. + ID: user.ID, + Name: user.Name, + PasswordSalt: user.PasswordSalt, + PasswordHash: user.PasswordHash, + PasswordReset: user.PasswordReset, + Status: user.Status, + Timezone: user.Timezone, + CreatedAt: user.CreatedAt, + UpdatedAt: user.UpdatedAt, + //ArchivedAt: nil, + } + }, + nil, + }) + + // Role of admin but claim account does not match update user so forbidden. + userTests = append(userTests, userTest{"RoleAdminDiffUser", + func(user *User, accountId string) auth.Claims { + return auth.Claims{ + Roles: []string{auth.RoleAdmin}, + StandardClaims: jwt.StandardClaims{ + Subject: uuid.NewRandom().String(), + Audience: uuid.NewRandom().String(), + }, + } + }, + CreateUserRequest{ + Name: "Lee Brown", + Email: uuid.NewRandom().String() + "@geeksinthewoods.com", + Password: "akTechFr0n!ier", + PasswordConfirm: "akTechFr0n!ier", + }, + func(user *User) UpdateUserRequest { + email := uuid.NewRandom().String() + "@geeksinthewoods.com" + return UpdateUserRequest{ + ID: user.ID, + Email: &email, + } + }, + ErrForbidden, + func(user *User, req UpdateUserRequest) *User { + return nil + }, + ErrNotFound, + }) + + // Role of admin and claim account matches update user so ok. + userTests = append(userTests, userTest{"RoleAdminSameAccount", + func(user *User, accountId string) auth.Claims { + return auth.Claims{ + Roles: []string{auth.RoleAdmin}, + StandardClaims: jwt.StandardClaims{ + Subject: uuid.NewRandom().String(), + Audience: accountId, + }, + } + }, + CreateUserRequest{ + Name: "Lee Brown", + Email: uuid.NewRandom().String() + "@geeksinthewoods.com", + Password: "akTechFr0n!ier", + PasswordConfirm: "akTechFr0n!ier", + }, + func(user *User) UpdateUserRequest { + email := uuid.NewRandom().String() + "@geeksinthewoods.com" + return UpdateUserRequest{ + ID: user.ID, + Email: &email, + } + }, + nil, + func(user *User, req UpdateUserRequest) *User { + return &User{ + Email: *req.Email, + // Copy this fields from the created user. + ID: user.ID, + Name: user.Name, + PasswordSalt: user.PasswordSalt, + PasswordHash: user.PasswordHash, + PasswordReset: user.PasswordReset, + Status: user.Status, + Timezone: user.Timezone, + CreatedAt: user.CreatedAt, + UpdatedAt: user.UpdatedAt, + //ArchivedAt: nil, + } + }, + nil, + }) + + t.Log("Given the need to ensure claims are applied as ACL for update user.") { - t.Log("\tWhen handling a single User.") - { - ctx := tests.Context() + now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC) - dbConn := test.MasterDB - defer dbConn.Close() + for i, tt := range userTests { + t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name) + { + ctx := tests.Context() + // Always create the new user with empty claims, testing claims for create user + // will be handled separately. + user, err := Create(tests.Context(), auth.Claims{}, test.MasterDB, tt.create, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tCreate failed.", tests.Failed) + } + // Create a new random account and associate that with the user. + accountId := uuid.NewRandom().String() + err = AddAccount(tests.Context(), auth.Claims{}, test.MasterDB, AddAccountRequest{ + UserID: user.ID, + AccountID: accountId, + Roles: []UserAccountRole{UserAccountRole_User}, + }, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tAddAccount failed.", tests.Failed) + } - u, err := Create(ctx, dbConn, &nu, now) - if err != nil { - t.Fatalf("\t%s\tShould be able to create user : %s.", tests.Failed, err) + // Update the user. + updateReq := tt.update(user) + err = Update(ctx, tt.claims(user, accountId), test.MasterDB, updateReq, now) + if err != nil && errors.Cause(err) != tt.updateErr { + t.Logf("\t\tGot : %+v", err) + t.Logf("\t\tWant: %+v", tt.updateErr) + t.Fatalf("\t%s\tUpdate failed.", tests.Failed) + } + t.Logf("\t%s\tUpdate ok.", tests.Success) + + // Find the user and make sure the updates where made. + findRes, err := FindById(ctx, tt.claims(user, accountId), test.MasterDB, user.ID, false) + if err != nil && errors.Cause(err) != tt.findErr { + t.Logf("\t\tGot : %+v", err) + t.Logf("\t\tWant: %+v", tt.findErr) + t.Fatalf("\t%s\tFindById failed.", tests.Failed) + } else { + findExpected := tt.expected(findRes, updateReq) + if diff := cmp.Diff(findRes, findExpected); diff != "" { + t.Fatalf("\t%s\tExpected find result to match update. Diff:\n%s", tests.Failed, diff) + } + t.Logf("\t%s\tFindById ok.", tests.Success) + } + + // Archive (soft-delete) the user. + err = Archive(ctx, tt.claims(user, accountId), test.MasterDB, user.ID, now) + if err != nil && errors.Cause(err) != tt.updateErr { + t.Logf("\t\tGot : %+v", err) + t.Logf("\t\tWant: %+v", tt.updateErr) + t.Fatalf("\t%s\tUpdate failed.", tests.Failed) + } else if tt.updateErr == nil { + // Trying to find the archived user with the includeArchived false should result in not found. + _, err = FindById(ctx, tt.claims(user, accountId), test.MasterDB, user.ID, false) + if err != nil && errors.Cause(err) != ErrNotFound { + t.Logf("\t\tGot : %+v", err) + t.Logf("\t\tWant: %+v", ErrNotFound) + t.Fatalf("\t%s\tArchive FindById failed.", tests.Failed) + } + + // Trying to find the archived user with the includeArchived true should result no error. + _, err = FindById(ctx, tt.claims(user, accountId), test.MasterDB, user.ID, true) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tArchive FindById failed.", tests.Failed) + } + } + t.Logf("\t%s\tArchive ok.", tests.Success) + + // Delete (hard-delete) the user. + err = Delete(ctx, tt.claims(user, accountId), test.MasterDB, user.ID) + if err != nil && errors.Cause(err) != tt.updateErr { + t.Logf("\t\tGot : %+v", err) + t.Logf("\t\tWant: %+v", tt.updateErr) + t.Fatalf("\t%s\tUpdate failed.", tests.Failed) + } else if tt.updateErr == nil { + // Trying to find the deleted user with the includeArchived true should result in not found. + _, err = FindById(ctx, tt.claims(user, accountId), test.MasterDB, user.ID, true) + if errors.Cause(err) != ErrNotFound { + t.Logf("\t\tGot : %+v", err) + t.Logf("\t\tWant: %+v", ErrNotFound) + t.Fatalf("\t%s\tDelete FindById failed.", tests.Failed) + } + } + t.Logf("\t%s\tDelete ok.", tests.Success) } - t.Logf("\t%s\tShould be able to create user.", tests.Success) - - - - // claims is information about the person making the request. - claims := auth.NewClaims(bson.NewObjectId().Hex(), []string{auth.RoleAdmin}, now, time.Hour) - - - savedU, err := user.Retrieve(ctx, claims, dbConn, u.ID.Hex()) - if err != nil { - t.Fatalf("\t%s\tShould be able to retrieve user by ID: %s.", tests.Failed, err) - } - t.Logf("\t%s\tShould be able to retrieve user by ID.", tests.Success) - - if diff := cmp.Diff(u, savedU); diff != "" { - t.Fatalf("\t%s\tShould get back the same user. Diff:\n%s", tests.Failed, diff) - } - t.Logf("\t%s\tShould get back the same user.", tests.Success) - - upd := user.UpdateUser{ - Name: tests.StringPointer("Jacob Walker"), - Email: tests.StringPointer("jacob@ardanlabs.com"), - } - - if err := user.Update(ctx, dbConn, u.ID.Hex(), &upd, now); err != nil { - t.Fatalf("\t%s\tShould be able to update user : %s.", tests.Failed, err) - } - t.Logf("\t%s\tShould be able to update user.", tests.Success) - - savedU, err = user.Retrieve(ctx, claims, dbConn, u.ID.Hex()) - if err != nil { - t.Fatalf("\t%s\tShould be able to retrieve user : %s.", tests.Failed, err) - } - t.Logf("\t%s\tShould be able to retrieve user.", tests.Success) - - if savedU.Name != *upd.Name { - t.Errorf("\t%s\tShould be able to see updates to Name.", tests.Failed) - t.Log("\t\tGot:", savedU.Name) - t.Log("\t\tExp:", *upd.Name) - } else { - t.Logf("\t%s\tShould be able to see updates to Name.", tests.Success) - } - - if savedU.Email != *upd.Email { - t.Errorf("\t%s\tShould be able to see updates to Email.", tests.Failed) - t.Log("\t\tGot:", savedU.Email) - t.Log("\t\tExp:", *upd.Email) - } else { - t.Logf("\t%s\tShould be able to see updates to Email.", tests.Success) - } - - if err := user.Delete(ctx, dbConn, u.ID.Hex()); err != nil { - t.Fatalf("\t%s\tShould be able to delete user : %s.", tests.Failed, err) - } - t.Logf("\t%s\tShould be able to delete user.", tests.Success) - - savedU, err = user.Retrieve(ctx, claims, dbConn, u.ID.Hex()) - if errors.Cause(err) != user.ErrNotFound { - t.Fatalf("\t%s\tShould NOT be able to retrieve user : %s.", tests.Failed, err) - } - t.Logf("\t%s\tShould NOT be able to retrieve user.", tests.Success) - - } } } +// TestUserFind validates all the request params are correctly parsed into a select query. +func TestUserFind(t *testing.T) { -// mockTokenGenerator is used for testing that Authenticate calls its provided -// token generator in a specific way. -type mockTokenGenerator struct{} + now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC) -// GenerateToken implements the TokenGenerator interface. It returns a "token" -// that includes some information about the claims it was passed. -func (mockTokenGenerator) GenerateToken(claims auth.Claims) (string, error) { - return fmt.Sprintf("sub:%q iss:%d", claims.Subject, claims.IssuedAt), nil -} + var users []*User + for i := 0; i <= 4; i++ { + user, err := Create(tests.Context(), auth.Claims{}, test.MasterDB, CreateUserRequest{ + Name: "Lee Brown", + Email: uuid.NewRandom().String() + "@geeksinthewoods.com", + Password: "akTechFr0n!ier", + PasswordConfirm: "akTechFr0n!ier", + }, now.Add(time.Second*time.Duration(i))) + if err != nil { + t.Logf("\t\tGot : %+v", err) + t.Fatalf("\t%s\tCreate failed.", tests.Failed) + } + users = append(users, user) + } -// TestAuthenticate validates the behavior around authenticating users. -func TestAuthenticate(t *testing.T) { - defer tests.Recover(t) + type userTest struct { + name string + req UserFindRequest + expected []*User + error error + } - t.Log("Given the need to authenticate users") + var userTests []userTest + + // Test sort users. + userTests = append(userTests, userTest{"Find all order by created_at asx", + UserFindRequest{ + Order: []string{"created_at"}, + }, + users, + nil, + }) + + // Test reverse sorted users. + var expected []*User + for i := len(users) - 1; i >= 0; i-- { + expected = append(expected, users[i]) + } + userTests = append(userTests, userTest{"Find all order by created_at desc", + UserFindRequest{ + Order: []string{"created_at desc"}, + }, + expected, + nil, + }) + + // Test limit. + var limit uint = 2 + userTests = append(userTests, userTest{"Find limit", + UserFindRequest{ + Order: []string{"created_at"}, + Limit: &limit, + }, + users[0:2], + nil, + }) + + // Test offset. + var offset uint = 3 + userTests = append(userTests, userTest{"Find limit, offset", + UserFindRequest{ + Order: []string{"created_at"}, + Limit: &limit, + Offset: &offset, + }, + users[3:5], + nil, + }) + + // Test where filter. + whereParts := []string{} + whereArgs := []interface{}{} + expected = []*User{} + selected := make(map[string]bool) + for i := 0; i <= 2; i++ { + ranIdx := rand.Intn(len(users)) + + email := users[ranIdx].Email + if selected[email] { + continue + } + selected[email] = true + + whereParts = append(whereParts, "email = ?") + whereArgs = append(whereArgs, email) + expected = append(expected, users[ranIdx]) + } + where := strings.Join(whereParts, " OR ") + userTests = append(userTests, userTest{"Find where", + UserFindRequest{ + Where: &where, + Args: whereArgs, + }, + expected, + nil, + }) + + t.Log("Given the need to ensure find users returns the expected results.") { - t.Log("\tWhen handling a single User.") - { - ctx := tests.Context() + for i, tt := range userTests { + t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name) + { + ctx := tests.Context() - dbConn := test.MasterDB.Copy() - defer dbConn.Close() - - nu := user.NewUser{ - Name: "Anna Walker", - Email: "anna@ardanlabs.com", - Roles: []string{auth.RoleAdmin}, - Password: "goroutines", - PasswordConfirm: "goroutines", + res, err := Find(ctx, auth.Claims{}, test.MasterDB, tt.req) + if err != nil && errors.Cause(err) != tt.error { + t.Logf("\t\tGot : %+v", err) + t.Logf("\t\tWant: %+v", tt.error) + t.Fatalf("\t%s\tFind failed.", tests.Failed) + } else if diff := cmp.Diff(res, tt.expected); diff != "" { + t.Logf("\t\tGot: %d items", len(res)) + t.Logf("\t\tWant: %d items", len(tt.expected)) + t.Fatalf("\t%s\tExpected find result to match expected. Diff:\n%s", tests.Failed, diff) + } + t.Logf("\t%s\tFind ok.", tests.Success) } - - now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC) - - u, err := user.Create(ctx, dbConn, &nu, now) - if err != nil { - t.Fatalf("\t%s\tShould be able to create user : %s.", tests.Failed, err) - } - t.Logf("\t%s\tShould be able to create user.", tests.Success) - - var tknGen mockTokenGenerator - tkn, err := user.Authenticate(ctx, dbConn, tknGen, now, "anna@ardanlabs.com", "goroutines") - if err != nil { - t.Fatalf("\t%s\tShould be able to generate a token : %s.", tests.Failed, err) - } - t.Logf("\t%s\tShould be able to generate a token.", tests.Success) - - want := fmt.Sprintf("sub:%q iss:1538352000", u.ID.Hex()) - if tkn.Token != want { - t.Log("\t\tGot :", tkn.Token) - t.Log("\t\tWant:", want) - t.Fatalf("\t%s\tToken should indicate the specified user and time were used.", tests.Failed) - } - t.Logf("\t%s\tToken should indicate the specified user and time were used.", tests.Success) - - if err := user.Delete(ctx, dbConn, u.ID.Hex()); err != nil { - t.Fatalf("\t%s\tShould be able to delete user : %s.", tests.Failed, err) - } - t.Logf("\t%s\tShould be able to delete user.", tests.Success) } } } -*/