diff --git a/example-project/internal/schema/migrations.go b/example-project/internal/schema/migrations.go index 7c4269a..301e301 100644 --- a/example-project/internal/schema/migrations.go +++ b/example-project/internal/schema/migrations.go @@ -100,7 +100,7 @@ func migrationList(db *sqlx.DB, log *log.Logger) []*sqlxmigrate.Migration { return errors.WithMessagef(err, "Query failed %s", q1) } - q2 := `CREATE TYPE user_account_status_t as enum('active','disabled')` + q2 := `CREATE TYPE user_account_status_t as enum('active', 'invited','disabled')` if _, err := tx.Exec(q2); err != nil { return errors.WithMessagef(err, "Query failed %s", q2) } diff --git a/example-project/internal/user/auth.go b/example-project/internal/user/auth.go index 5ddff4e..5f635d9 100644 --- a/example-project/internal/user/auth.go +++ b/example-project/internal/user/auth.go @@ -2,6 +2,7 @@ package user import ( "context" + "gopkg.in/go-playground/validator.v9" "time" "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth" @@ -12,23 +13,27 @@ import ( "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" ) -// TokenGenerator is the behavior we need in our Authenticate to generate -// tokens for authenticated users. +// TokenGenerator is the behavior we need in our Authenticate to generate tokens for +// authenticated users. type TokenGenerator interface { GenerateToken(auth.Claims) (string, error) + ParseClaims(string) (auth.Claims, error) } -// Authenticate finds a user by their email and verifies their password. On -// success it returns a Token that can be used to authenticate in the future. -func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, now time.Time, email, password string) (Token, error) { +// Authenticate finds a user by their email and verifies their password. On success +// it returns a Token that can be used to authenticate access to the application in +// the future. +func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, email, password string, expires time.Duration, now time.Time) (Token, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Authenticate") defer span.Finish() - // Generate sql query to select user by email address + // Generate sql query to select user by email address. query := sqlbuilder.NewSelectBuilder() query.Where(query.Equal("email", email)) - // Run the find, use empty claims to bypass ACLs + // Run the find, use empty claims to bypass ACLs since this in an internal request + // and the current user is not authenticated at this point. If the email is + // invalid, return the same error as when an invalid password is supplied. res, err := find(ctx, auth.Claims{}, dbConn, query, []interface{}{}, false) if err != nil { return Token{}, err @@ -39,43 +44,104 @@ func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, n u := res[0] // Append the salt from the user record to the supplied password. - saltedPassword := password + string(u.PasswordSalt) + saltedPassword := password + u.PasswordSalt - // Compare the provided password with the saved hash. Use the bcrypt - // comparison function so it is cryptographically secure. + // Compare the provided password with the saved hash. Use the bcrypt comparison + // function so it is cryptographically secure. Return authentication error for + // invalid password. if err := bcrypt.CompareHashAndPassword(u.PasswordHash, []byte(saltedPassword)); err != nil { err = errors.WithStack(ErrAuthenticationFailure) return Token{}, err } - // Get a list of all the account ids associated with the user. - accounts, err := FindAccountsByUserID(ctx, auth.Claims{}, dbConn, u.ID, false) + // The user is successfully authenticated with the supplied email and password. + return generateToken(ctx, dbConn, tknGen, auth.Claims{}, u.ID, "", expires, now) +} + +// Authenticate finds a user by their email and verifies their password. On success +// it returns a Token that can be used to authenticate access to the application in +// the future. +func SwitchAccount(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, accountID string, expires time.Duration, now time.Time) (Token, error) { + span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.SwitchAccount") + defer span.Finish() + + // Defines struct to apply validation for the supplied claims and account ID. + req := struct { + UserID string `validate:"required,uuid"` + AccountID string `validate:"required,uuid"` + }{ + UserID: claims.Subject, + AccountID: accountID, + } + + // Validate the request. + err := validator.New().Struct(req) if err != nil { return Token{}, err } - // Claims needs an audience, select the first account associated with - // the user. - var ( - accountId string - roles []string - ) - if len(accounts) > 0 { - accountId = accounts[0].AccountID - for _, r := range accounts[0].Roles { + // Generate a token for the user ID in supplied in claims as the Subject. Pass + // in the supplied claims as well to enforce ACLs when finding the current + // list of accounts for the user. + return generateToken(ctx, dbConn, tknGen, claims, req.UserID, req.AccountID, expires, now) +} + +// generateToken generates claims for the supplied user ID and account ID and then +// returns the token for the generated claims used for authentication. +func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, userID, accountID string, expires time.Duration, now time.Time) (Token, error) { + // Get a list of all the accounts associated with the user. + accounts, err := FindAccountsByUserID(ctx, auth.Claims{}, dbConn, userID, false) + if err != nil { + return Token{}, err + } + + // Load the user account entry for the specifed account ID. If none provided, + // choose the first. + var account *UserAccount + if accountID == "" { + // Select the first account associated with the user. For the login flow, + // users could be forced to select a specific account to override this. + if len(accounts) > 0 { + account = accounts[0] + accountID = account.AccountID + } + } else { + // Loop through all the accounts found for the user and select the specified + // account. + for _, a := range accounts { + if a.AccountID == accountID { + account = a + break + } + } + + // If no matching entry was found for the specified account ID throw an error. + if account == nil { + err = errors.WithStack(ErrAuthenticationFailure) + return Token{}, err + } + } + + // Generate list of user defined roles for accessing the account. + var roles []string + if account != nil { + for _, r := range account.Roles { roles = append(roles, r.String()) } } - // Generate a list of all the account IDs associated with the user so - // the use has the ability to switch between accounts. - accountIds := []string{} + // Generate a list of all the account IDs associated with the user so the use + // has the ability to switch between accounts. + var accountIds []string for _, a := range accounts { accountIds = append(accountIds, a.AccountID) } - // If we are this far the request is valid. Create some claims for the user. - claims := auth.NewClaims(u.ID, accountId, accountIds, roles, now, time.Hour) + // JWT claims requires both an audience and a subject. For this application: + // Subject: The ID of the user authenticated. + // Audience: The ID of the account the user is accessing. A list of account IDs + // will also be included to support the user switching between them. + claims = auth.NewClaims(userID, accountID, accountIds, roles, now, expires) // Generate a token for the user with the defined claims. tkn, err := tknGen.GenerateToken(claims) diff --git a/example-project/internal/user/auth_test.go b/example-project/internal/user/auth_test.go index 6e8019c..7ab5ddf 100644 --- a/example-project/internal/user/auth_test.go +++ b/example-project/internal/user/auth_test.go @@ -15,31 +15,33 @@ import ( // mockTokenGenerator is used for testing that Authenticate calls its provided // token generator in a specific way. -type mockTokenGenerator struct{} - -// Private key generated by GenerateToken that is need for ParseClaims -var mockTokenKey *rsa.PrivateKey +type mockTokenGenerator struct { + // Private key generated by GenerateToken that is need for ParseClaims + key *rsa.PrivateKey + // algorithm is the method used to generate the private key. + algorithm string +} // GenerateToken implements the TokenGenerator interface. It returns a "token" // that includes some information about the claims it was passed. -func (g mockTokenGenerator) GenerateToken(claims auth.Claims) (string, error) { +func (g *mockTokenGenerator) GenerateToken(claims auth.Claims) (string, error) { privateKey, err := auth.Keygen() if err != nil { return "", err } - mockTokenKey, err = jwt.ParseRSAPrivateKeyFromPEM(privateKey) + g.key, err = jwt.ParseRSAPrivateKeyFromPEM(privateKey) if err != nil { return "", err } - algorithm := "RS256" - method := jwt.GetSigningMethod(algorithm) + g.algorithm = "RS256" + method := jwt.GetSigningMethod(g.algorithm) tkn := jwt.NewWithClaims(method, claims) tkn.Header["kid"] = "1" - str, err := tkn.SignedString(mockTokenKey) + str, err := tkn.SignedString(g.key) if err != nil { return "", err } @@ -49,18 +51,17 @@ func (g mockTokenGenerator) GenerateToken(claims auth.Claims) (string, error) { // ParseClaims recreates the Claims that were used to generate a token. It // verifies that the token was signed using our key. -func (g mockTokenGenerator) ParseClaims(tknStr string) (auth.Claims, error) { - algorithm := "RS256" +func (g *mockTokenGenerator) ParseClaims(tknStr string) (auth.Claims, error) { parser := jwt.Parser{ - ValidMethods: []string{algorithm}, + ValidMethods: []string{g.algorithm}, } - if mockTokenKey == nil { - panic("key is nil") + if g.key == nil { + return auth.Claims{}, errors.New("Private key is empty.") } f := func(t *jwt.Token) (interface{}, error) { - return mockTokenKey.Public().(*rsa.PublicKey), nil + return g.key.Public().(*rsa.PublicKey), nil } var claims auth.Claims @@ -93,7 +94,7 @@ func TestAuthenticate(t *testing.T) { now := time.Now().Add(time.Hour * -1) // Try to authenticate an invalid user. - _, err := Authenticate(ctx, test.MasterDB, tknGen, now, "doesnotexist@gmail.com", "xy7") + _, err := Authenticate(ctx, test.MasterDB, tknGen, "doesnotexist@gmail.com", "xy7", time.Hour, now) if errors.Cause(err) != ErrAuthenticationFailure { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", ErrAuthenticationFailure) @@ -117,23 +118,36 @@ func TestAuthenticate(t *testing.T) { // Create a new random account and associate that with the user. // This defined role should be the claims. - accountId := uuid.NewRandom().String() - accountRole := UserAccountRole_Admin + account1Id := uuid.NewRandom().String() + account1Role := UserAccountRole_Admin _, err = AddAccount(tests.Context(), auth.Claims{}, test.MasterDB, AddAccountRequest{ UserID: user.ID, - AccountID: accountId, - Roles: []UserAccountRole{accountRole}, + AccountID: account1Id, + Roles: []UserAccountRole{account1Role}, }, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tAddAccount failed.", tests.Failed) } + // Create a second new random account and associate that with the user. + account2Id := uuid.NewRandom().String() + account2Role := UserAccountRole_User + _, err = AddAccount(tests.Context(), auth.Claims{}, test.MasterDB, AddAccountRequest{ + UserID: user.ID, + AccountID: account2Id, + Roles: []UserAccountRole{account2Role}, + }, now.Add(time.Second)) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tAddAccount failed.", tests.Failed) + } + // Add 30 minutes to now to simulate time passing. now = now.Add(time.Minute * 30) // Try to authenticate valid user with invalid password. - _, err = Authenticate(ctx, test.MasterDB, tknGen, now, user.Email, "xy7") + _, err = Authenticate(ctx, test.MasterDB, tknGen, user.Email, "xy7", time.Hour, now) if errors.Cause(err) != ErrAuthenticationFailure { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", ErrAuthenticationFailure) @@ -142,7 +156,7 @@ func TestAuthenticate(t *testing.T) { t.Logf("\t%s\tAuthenticate user w/invalid password ok.", tests.Success) // Verify that the user can be authenticated with the created user. - tkn, err := Authenticate(ctx, test.MasterDB, tknGen, now, user.Email, initPass) + tkn1, err := Authenticate(ctx, test.MasterDB, tknGen, user.Email, initPass, time.Hour, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tAuthenticate user failed.", tests.Failed) @@ -150,18 +164,40 @@ func TestAuthenticate(t *testing.T) { t.Logf("\t%s\tAuthenticate user ok.", tests.Success) // Ensure the token string was correctly generated. - claims, err := tknGen.ParseClaims(tkn.Token) + claims1, err := tknGen.ParseClaims(tkn1.Token) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) - } else if diff := cmp.Diff(claims, tkn.claims); diff != "" { + } else if diff := cmp.Diff(claims1, tkn1.claims); diff != "" { t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff) - } else if diff := cmp.Diff(claims.Roles, []string{accountRole.String()}); diff != "" { + } else if diff := cmp.Diff(claims1.Roles, []string{account1Role.String()}); diff != "" { t.Fatalf("\t%s\tExpected parsed claims roles to match user account. Diff:\n%s", tests.Failed, diff) - } else if diff := cmp.Diff(claims.AccountIds, []string{accountId}); diff != "" { + } else if diff := cmp.Diff(claims1.AccountIds, []string{account1Id, account2Id}); diff != "" { t.Fatalf("\t%s\tExpected parsed claims account IDs to match the single user account. Diff:\n%s", tests.Failed, diff) } - t.Logf("\t%s\tParse claims from token ok.", tests.Success) + t.Logf("\t%s\tAuthenticate parse claims from token ok.", tests.Success) + + // Try switching to a second account using the first set of claims. + tkn2, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims1, account2Id, time.Hour, now) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tSwitchAccount user failed.", tests.Failed) + } + t.Logf("\t%s\tSwitchAccount user ok.", tests.Success) + + // Ensure the token string was correctly generated. + claims2, err := tknGen.ParseClaims(tkn2.Token) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) + } else if diff := cmp.Diff(claims2, tkn2.claims); diff != "" { + t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff) + } else if diff := cmp.Diff(claims2.Roles, []string{account2Role.String()}); diff != "" { + t.Fatalf("\t%s\tExpected parsed claims roles to match user account. Diff:\n%s", tests.Failed, diff) + } else if diff := cmp.Diff(claims2.AccountIds, []string{account1Id, account2Id}); diff != "" { + t.Fatalf("\t%s\tExpected parsed claims account IDs to match the single user account. Diff:\n%s", tests.Failed, diff) + } + t.Logf("\t%s\tSwitchAccount parse claims from token ok.", tests.Success) } } } diff --git a/example-project/internal/user/models.go b/example-project/internal/user/models.go index 30e2bb9..b168a88 100644 --- a/example-project/internal/user/models.go +++ b/example-project/internal/user/models.go @@ -21,7 +21,7 @@ type User struct { PasswordHash []byte `db:"password_hash" json:"-"` PasswordReset sql.NullString `db:"password_reset" json:"-"` - Timezone string `db:"timezone" json:"timezone"` + Timezone string `db:"timezone" json:"timezone"` CreatedAt time.Time `db:"created_at" json:"created_at"` UpdatedAt time.Time `db:"updated_at" json:"updated_at"` @@ -30,11 +30,11 @@ type User struct { // CreateUserRequest contains information needed to create a new User. type CreateUserRequest struct { - Name string `json:"name" validate:"required"` - Email string `json:"email" validate:"required,email,unique"` - Password string `json:"password" validate:"required"` - PasswordConfirm string `json:"password_confirm" validate:"eqfield=Password"` - Timezone *string `json:"timezone" validate:"omitempty"` + Name string `json:"name" validate:"required"` + Email string `json:"email" validate:"required,email,unique"` + Password string `json:"password" validate:"required"` + PasswordConfirm string `json:"password_confirm" validate:"eqfield=Password"` + Timezone *string `json:"timezone" validate:"omitempty"` } // UpdateUserRequest defines what information may be provided to modify an existing @@ -44,20 +44,21 @@ type CreateUserRequest struct { // we do not want to use pointers to basic types but we make exceptions around // marshalling/unmarshalling. type UpdateUserRequest struct { - ID string `validate:"required,uuid"` - Name *string `json:"name" validate:"omitempty"` - Email *string `json:"email" validate:"omitempty,email,unique"` - Timezone *string `json:"timezone" validate:"omitempty"` + ID string `validate:"required,uuid"` + Name *string `json:"name" validate:"omitempty"` + Email *string `json:"email" validate:"omitempty,email,unique"` + Timezone *string `json:"timezone" validate:"omitempty"` } -// UpdatePassword defines what information may be provided to update user password. +// UpdatePassword defines what information is required to update a user password. type UpdatePasswordRequest struct { ID string `validate:"required,uuid"` Password string `json:"password" validate:"required"` PasswordConfirm string `json:"password_confirm" validate:"omitempty,eqfield=Password"` } -// UserFindRequest defines the possible options for search for users +// UserFindRequest defines the possible options to search for users. By default +// archived users will be excluded from response. type UserFindRequest struct { Where *string Args []interface{} @@ -67,53 +68,60 @@ type UserFindRequest struct { IncludedArchived bool } -// UserAccount defines the one to many relationship of an user to an account. -// Each association of an user to an account has a set of roles defined for the user -// that will be applied when accessing the account. +// UserAccount defines the one to many relationship of an user to an account. This +// will enable a single user access to multiple accounts without having duplicate +// users. Each association of a user to an account has a set of roles and a status +// defined for the user. The roles will be applied to enforce ACLs across the +// application. The status will allow users to be managed on by account with users +// being global to the application. type UserAccount struct { - ID string `db:"id" json:"id"` - UserID string `db:"user_id" json:"user_id"` - AccountID string `db:"account_id" json:"account_id"` - Roles UserAccountRoles `db:"roles" json:"roles"` - Status UserAccountStatus `db:"status" json:"status"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` - ArchivedAt pq.NullTime `db:"archived_at" json:"archived_at"` + ID string `db:"id" json:"id"` + UserID string `db:"user_id" json:"user_id"` + AccountID string `db:"account_id" json:"account_id"` + Roles UserAccountRoles `db:"roles" json:"roles"` + Status UserAccountStatus `db:"status" json:"status"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + ArchivedAt pq.NullTime `db:"archived_at" json:"archived_at"` } -// AddAccountRequest defines the information needed to add a new account to a user. +// AddAccountRequest defines the information is needed to associate a user to an +// account. Users are global to the application and each users access can be managed +// on an account level. If a current entry exists in the database but is archived, +// it will be un-archived. type AddAccountRequest struct { - UserID string `validate:"required,uuid"` - AccountID string `validate:"required,uuid"` - Roles UserAccountRoles `json:"roles" validate:"required,dive,oneof=admin user"` - Status *UserAccountStatus `json:"status" validate:"omitempty,oneof=active disabled"` + UserID string `validate:"required,uuid"` + AccountID string `validate:"required,uuid"` + Roles UserAccountRoles `json:"roles" validate:"required,dive,oneof=admin user"` + Status *UserAccountStatus `json:"status" validate:"omitempty,oneof=active invited disabled"` } -// UpdateAccountRequest defines the information needed to update the roles for -// an existing user account. +// UpdateAccountRequest defines the information needed to update the roles or the +// status for an existing user account. type UpdateAccountRequest struct { - UserID string `validate:"required,uuid"` - AccountID string `validate:"required,uuid"` - Roles *UserAccountRoles `json:"roles" validate:"required,dive,oneof=admin user"` - Status *UserAccountStatus `json:"status" validate:"omitempty,oneof=active disabled"` - unArchive bool + UserID string `validate:"required,uuid"` + AccountID string `validate:"required,uuid"` + Roles *UserAccountRoles `json:"roles" validate:"required,dive,oneof=admin user"` + Status *UserAccountStatus `json:"status" validate:"omitempty,oneof=active invited disabled"` + unArchive bool `json:"-"` // Internal use only. } -// RemoveAccountRequest defines the information needed to remove an existing -// account for a user. This will archive (soft-delete) the existing database entry. +// RemoveAccountRequest defines the information needed to remove an existing account +// for a user. This will archive (soft-delete) the existing database entry. type RemoveAccountRequest struct { UserID string `validate:"required,uuid"` AccountID string `validate:"required,uuid"` } -// DeleteAccountRequest defines the information needed to delete an existing -// account for a user. This will hard delete the existing database entry. +// DeleteAccountRequest defines the information needed to delete an existing account +// for a user. This will hard delete the existing database entry. type DeleteAccountRequest struct { UserID string `validate:"required,uuid"` AccountID string `validate:"required,uuid"` } -// UserAccountFindRequest defines the possible options for search for users accounts +// UserAccountFindRequest defines the possible options to search for users accounts. +// By default archived user accounts will be excluded from response. type UserAccountFindRequest struct { Where *string Args []interface{} @@ -123,18 +131,25 @@ type UserAccountFindRequest struct { IncludedArchived bool } -// UserAccountStatus represents the status of a user. +// UserAccountStatus represents the status of a user for an account. type UserAccountStatus string -// UserAccountStatus values +// UserAccountStatus values define the status field of a user account. const ( - UserAccountStatus_Active UserAccountStatus = "active" + // UserAccountStatus_Active defines the state when a user can access an account. + UserAccountStatus_Active UserAccountStatus = "active" + // UserAccountStatus_Invited defined the state when a user has been invited to an + // account. + UserAccountStatus_Invited UserAccountStatus = "invited" + // UserAccountStatus_Disabled defines the state when a user has been disabled from + // accessing an account. UserAccountStatus_Disabled UserAccountStatus = "disabled" ) -// UserAccountStatus_Values provides list of valid UserAccountStatus values +// UserAccountStatus_Values provides list of valid UserAccountStatus values. var UserAccountStatus_Values = []UserAccountStatus{ UserAccountStatus_Active, + UserAccountStatus_Invited, UserAccountStatus_Disabled, } @@ -152,7 +167,7 @@ func (s *UserAccountStatus) Scan(value interface{}) error { func (s UserAccountStatus) Value() (driver.Value, error) { v := validator.New() - errs := v.Var(s, "required,oneof=active disabled") + errs := v.Var(s, "required,oneof=active invited disabled") if errs != nil { return nil, errs } @@ -168,13 +183,19 @@ func (s UserAccountStatus) String() string { // UserAccountRole represents the role of a user for an account. type UserAccountRole string -// UserAccountRole values +// UserAccountRole values define the role field of a user account. const ( + // UserAccountRole_Admin defines the state of a user when they have admin + // privileges for accessing an account. This role provides a user with full + // access to an account. UserAccountRole_Admin UserAccountRole = auth.RoleAdmin - UserAccountRole_User UserAccountRole = auth.RoleUser + // UserAccountRole_User defines the state of a user when they have basic + // privileges for accessing an account. This role provies a user with the most + // limited access to an account. + UserAccountRole_User UserAccountRole = auth.RoleUser ) -// UserAccountRole_Values provides list of valid UserAccountRole values +// UserAccountRole_Values provides list of valid UserAccountRole values. var UserAccountRole_Values = []UserAccountRole{ UserAccountRole_Admin, UserAccountRole_User, diff --git a/example-project/internal/user/user.go b/example-project/internal/user/user.go index 07fadba..f3cfcea 100644 --- a/example-project/internal/user/user.go +++ b/example-project/internal/user/user.go @@ -335,7 +335,7 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Create // Build the insert SQL statement. query := sqlbuilder.NewInsertBuilder() query.InsertInto(usersTableName) - query.Cols("id", "name", "email", "password_hash", "password_salt","timezone", "created_at", "updated_at") + query.Cols("id", "name", "email", "password_hash", "password_salt", "timezone", "created_at", "updated_at") query.Values(u.ID, u.Name, u.Email, u.PasswordHash, u.PasswordSalt, u.Timezone, u.CreatedAt, u.UpdatedAt) // Execute the query with the provided context. diff --git a/example-project/internal/user/user_account.go b/example-project/internal/user/user_account.go index 1012bba..bd02539 100644 --- a/example-project/internal/user/user_account.go +++ b/example-project/internal/user/user_account.go @@ -61,7 +61,6 @@ func CanModifyUserAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx. return nil } - // applyClaimsUserAccountSelect applies a sub query to enforce ACL for // the supplied claims. If claims is empty then request must be internal and // no sub-query is applied. Else a list of user IDs is found all associated @@ -175,7 +174,7 @@ func FindAccountsByUserID(ctx context.Context, claims auth.Claims, dbConn *sqlx. // Filter base select query by ID query := sqlbuilder.NewSelectBuilder() query.Where(query.Equal("user_id", userID)) - query.OrderBy("id") + query.OrderBy("created_at") // Execute the find accounts method. res, err := findAccounts(ctx, claims, dbConn, query, []interface{}{}, includedArchived) @@ -251,13 +250,13 @@ func AddAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Ad } ua := UserAccount{ - ID: uuid.NewRandom().String(), - UserID: req.UserID, + ID: uuid.NewRandom().String(), + UserID: req.UserID, AccountID: req.AccountID, - Roles: req.Roles, - Status: UserAccountStatus_Active, - CreatedAt: now, - UpdatedAt: now, + Roles: req.Roles, + Status: UserAccountStatus_Active, + CreatedAt: now, + UpdatedAt: now, } if req.Status != nil { diff --git a/example-project/internal/user/user_account_test.go b/example-project/internal/user/user_account_test.go index 0a9fb6b..df0f7b6 100644 --- a/example-project/internal/user/user_account_test.go +++ b/example-project/internal/user/user_account_test.go @@ -8,11 +8,11 @@ import ( "time" "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth" + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/tests" "github.com/dgrijalva/jwt-go" + "github.com/google/go-cmp/cmp" "github.com/huandu/go-sqlbuilder" "github.com/pborman/uuid" - "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/tests" - "github.com/google/go-cmp/cmp" "github.com/pkg/errors" ) @@ -127,7 +127,6 @@ func TestAddAccountValidation(t *testing.T) { invalidRole := UserAccountRole("moon") invalidStatus := UserAccountStatus("moon") - var accountTests = []struct { name string req AddAccountRequest @@ -139,15 +138,15 @@ func TestAddAccountValidation(t *testing.T) { func(req AddAccountRequest, res *UserAccount) *UserAccount { return nil }, - errors.New("Key: 'AddAccountRequest.UserID' Error:Field validation for 'UserID' failed on the 'required' tag\n"+ - "Key: 'AddAccountRequest.AccountID' Error:Field validation for 'AccountID' failed on the 'required' tag\n"+ - "Key: 'AddAccountRequest.Roles' Error:Field validation for 'Roles' failed on the 'required' tag"), + errors.New("Key: 'AddAccountRequest.UserID' Error:Field validation for 'UserID' failed on the 'required' tag\n" + + "Key: 'AddAccountRequest.AccountID' Error:Field validation for 'AccountID' failed on the 'required' tag\n" + + "Key: 'AddAccountRequest.Roles' Error:Field validation for 'Roles' failed on the 'required' tag"), }, {"Valid Role", AddAccountRequest{ - UserID: uuid.NewRandom().String(), + UserID: uuid.NewRandom().String(), AccountID: uuid.NewRandom().String(), - Roles: []UserAccountRole{invalidRole}, + Roles: []UserAccountRole{invalidRole}, }, func(req AddAccountRequest, res *UserAccount) *UserAccount { return nil @@ -156,10 +155,10 @@ func TestAddAccountValidation(t *testing.T) { }, {"Valid Status", AddAccountRequest{ - UserID: uuid.NewRandom().String(), + UserID: uuid.NewRandom().String(), AccountID: uuid.NewRandom().String(), - Roles: []UserAccountRole{UserAccountRole_User}, - Status: &invalidStatus, + Roles: []UserAccountRole{UserAccountRole_User}, + Status: &invalidStatus, }, func(req AddAccountRequest, res *UserAccount) *UserAccount { return nil @@ -168,21 +167,21 @@ func TestAddAccountValidation(t *testing.T) { }, {"Default Status", AddAccountRequest{ - UserID: uuid.NewRandom().String(), + UserID: uuid.NewRandom().String(), AccountID: uuid.NewRandom().String(), - Roles: []UserAccountRole{UserAccountRole_User}, + Roles: []UserAccountRole{UserAccountRole_User}, }, func(req AddAccountRequest, res *UserAccount) *UserAccount { return &UserAccount{ - UserID: req.UserID, - AccountID: req.AccountID, - Roles: req.Roles, - Status: UserAccountStatus_Active, + UserID: req.UserID, + AccountID: req.AccountID, + Roles: req.Roles, + Status: UserAccountStatus_Active, // Copy this fields from the result. - ID: res.ID, - CreatedAt: res.CreatedAt, - UpdatedAt: res.UpdatedAt, + ID: res.ID, + CreatedAt: res.CreatedAt, + UpdatedAt: res.UpdatedAt, //ArchivedAt: nil, } }, @@ -245,9 +244,9 @@ func TestAddAccountExistingEntry(t *testing.T) { ctx := tests.Context() req1 := AddAccountRequest{ - UserID: uuid.NewRandom().String(), + UserID: uuid.NewRandom().String(), AccountID: uuid.NewRandom().String(), - Roles: []UserAccountRole{UserAccountRole_User}, + Roles: []UserAccountRole{UserAccountRole_User}, } ua1, err := AddAccount(ctx, auth.Claims{}, test.MasterDB, req1, now) if err != nil { @@ -260,9 +259,9 @@ func TestAddAccountExistingEntry(t *testing.T) { } req2 := AddAccountRequest{ - UserID: req1.UserID, + UserID: req1.UserID, AccountID: req1.AccountID, - Roles: []UserAccountRole{UserAccountRole_Admin}, + Roles: []UserAccountRole{UserAccountRole_Admin}, } ua2, err := AddAccount(ctx, auth.Claims{}, test.MasterDB, req2, now) if err != nil { @@ -285,33 +284,33 @@ func TestUpdateAccountValidation(t *testing.T) { invalidStatus := UserAccountStatus("xxxxxxxxx") var accountTests = []struct { - name string - req UpdateAccountRequest - error error + name string + req UpdateAccountRequest + error error }{ {"Required Fields", UpdateAccountRequest{}, errors.New("Key: 'UpdateAccountRequest.UserID' Error:Field validation for 'UserID' failed on the 'required' tag\n" + - "Key: 'UpdateAccountRequest.AccountID' Error:Field validation for 'AccountID' failed on the 'required' tag\n" + - "Key: 'UpdateAccountRequest.Roles' Error:Field validation for 'Roles' failed on the 'required' tag"), + "Key: 'UpdateAccountRequest.AccountID' Error:Field validation for 'AccountID' failed on the 'required' tag\n" + + "Key: 'UpdateAccountRequest.Roles' Error:Field validation for 'Roles' failed on the 'required' tag"), }, {"Valid Role", UpdateAccountRequest{ - UserID: uuid.NewRandom().String(), + UserID: uuid.NewRandom().String(), AccountID: uuid.NewRandom().String(), - Roles: &UserAccountRoles{invalidRole}, + Roles: &UserAccountRoles{invalidRole}, }, errors.New("Key: 'UpdateAccountRequest.Roles[0]' Error:Field validation for 'Roles[0]' failed on the 'oneof' tag"), }, {"Valid Status", UpdateAccountRequest{ - UserID: uuid.NewRandom().String(), + UserID: uuid.NewRandom().String(), AccountID: uuid.NewRandom().String(), - Roles: &UserAccountRoles{UserAccountRole_User}, - Status: &invalidStatus, + Roles: &UserAccountRoles{UserAccountRole_User}, + Status: &invalidStatus, }, - errors.New("Key: 'UpdateAccountRequest.Status' Error:Field validation for 'Status' failed on the 'oneof' tag"), + errors.New("Key: 'UpdateAccountRequest.Status' Error:Field validation for 'Status' failed on the 'oneof' tag"), }, } @@ -461,9 +460,9 @@ func TestAccountCrud(t *testing.T) { // Create a new random account and associate that with the user. accountID := uuid.NewRandom().String() createReq := AddAccountRequest{ - UserID: user.ID, + UserID: user.ID, AccountID: accountID, - Roles: []UserAccountRole{UserAccountRole_User}, + Roles: []UserAccountRole{UserAccountRole_User}, } ua, err := AddAccount(tests.Context(), tt.claims(user.ID, accountID), test.MasterDB, createReq, now) if err != nil && errors.Cause(err) != tt.updateErr { @@ -479,9 +478,9 @@ func TestAccountCrud(t *testing.T) { // Update the account. updateReq := UpdateAccountRequest{ - UserID: user.ID, + UserID: user.ID, AccountID: accountID, - Roles: &UserAccountRoles{UserAccountRole_Admin}, + Roles: &UserAccountRoles{UserAccountRole_Admin}, } err = UpdateAccount(tests.Context(), tt.claims(user.ID, accountID), test.MasterDB, updateReq, now) if err != nil && errors.Cause(err) != tt.updateErr { @@ -501,12 +500,12 @@ func TestAccountCrud(t *testing.T) { } else if tt.findErr == nil { expected := []*UserAccount{ &UserAccount{ - ID: ua.ID, - UserID: ua.UserID, + ID: ua.ID, + UserID: ua.UserID, AccountID: ua.AccountID, - Roles: *updateReq.Roles, - Status: ua.Status, - CreatedAt:ua.CreatedAt, + Roles: *updateReq.Roles, + Status: ua.Status, + CreatedAt: ua.CreatedAt, UpdatedAt: now, }, } @@ -518,7 +517,7 @@ func TestAccountCrud(t *testing.T) { // Archive (soft-delete) the user account. err = RemoveAccount(tests.Context(), tt.claims(user.ID, accountID), test.MasterDB, RemoveAccountRequest{ - UserID: user.ID, + UserID: user.ID, AccountID: accountID, }, now) if err != nil && errors.Cause(err) != tt.updateErr { @@ -543,14 +542,14 @@ func TestAccountCrud(t *testing.T) { expected := []*UserAccount{ &UserAccount{ - ID: ua.ID, - UserID: ua.UserID, - AccountID: ua.AccountID, - Roles: *updateReq.Roles, - Status: ua.Status, - CreatedAt:ua.CreatedAt, - UpdatedAt: now, - ArchivedAt: pq.NullTime{Time: now, Valid:true}, + ID: ua.ID, + UserID: ua.UserID, + AccountID: ua.AccountID, + Roles: *updateReq.Roles, + Status: ua.Status, + CreatedAt: ua.CreatedAt, + UpdatedAt: now, + ArchivedAt: pq.NullTime{Time: now, Valid: true}, }, } if diff := cmp.Diff(findRes, expected); diff != "" { @@ -561,7 +560,7 @@ func TestAccountCrud(t *testing.T) { // Delete (hard-delete) the user account. err = DeleteAccount(tests.Context(), tt.claims(user.ID, accountID), test.MasterDB, DeleteAccountRequest{ - UserID: user.ID, + UserID: user.ID, AccountID: accountID, }) if err != nil && errors.Cause(err) != tt.updateErr { @@ -586,23 +585,10 @@ func TestAccountCrud(t *testing.T) { // TestAccountFind validates all the request params are correctly parsed into a select query. func TestAccountFind(t *testing.T) { - // Ensure all the existing user accounts are deleted. - { - // Build the delete SQL statement. - query := sqlbuilder.NewDeleteBuilder() - query.DeleteFrom(usersAccountsTableName) + now := time.Now().Add(time.Hour * -2).UTC() - // Execute the query with the provided context. - sql, args := query.Build() - sql = test.MasterDB.Rebind(sql) - _, err := test.MasterDB.ExecContext(tests.Context(), sql, args...) - if err != nil { - t.Logf("\t\tGot : %+v", err) - t.Fatalf("\t%s\tDelete failed.", tests.Failed) - } - } - - now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC) + startTime := now.Truncate(time.Millisecond) + var endTime time.Time var userAccounts []*UserAccount for i := 0; i <= 4; i++ { @@ -620,9 +606,9 @@ func TestAccountFind(t *testing.T) { // Create a new random account and associate that with the user. accountID := uuid.NewRandom().String() ua, err := AddAccount(tests.Context(), auth.Claims{}, test.MasterDB, AddAccountRequest{ - UserID: user.ID, + UserID: user.ID, AccountID: accountID, - Roles: []UserAccountRole{UserAccountRole_User}, + Roles: []UserAccountRole{UserAccountRole_User}, }, now.Add(time.Second*time.Duration(i))) if err != nil { t.Logf("\t\tGot : %+v", err) @@ -630,6 +616,7 @@ func TestAccountFind(t *testing.T) { } userAccounts = append(userAccounts, ua) + endTime = user.CreatedAt } type accountTest struct { @@ -641,9 +628,13 @@ func TestAccountFind(t *testing.T) { var accountTests []accountTest + createdFilter := "created_at BETWEEN ? AND ?" + // Test sort users. accountTests = append(accountTests, accountTest{"Find all order by created_at asx", UserAccountFindRequest{ + Where: &createdFilter, + Args: []interface{}{startTime, endTime}, Order: []string{"created_at"}, }, userAccounts, @@ -657,6 +648,8 @@ func TestAccountFind(t *testing.T) { } accountTests = append(accountTests, accountTest{"Find all order by created_at desc", UserAccountFindRequest{ + Where: &createdFilter, + Args: []interface{}{startTime, endTime}, Order: []string{"created_at desc"}, }, expected, @@ -667,6 +660,8 @@ func TestAccountFind(t *testing.T) { var limit uint = 2 accountTests = append(accountTests, accountTest{"Find limit", UserAccountFindRequest{ + Where: &createdFilter, + Args: []interface{}{startTime, endTime}, Order: []string{"created_at"}, Limit: &limit, }, @@ -678,6 +673,8 @@ func TestAccountFind(t *testing.T) { var offset uint = 3 accountTests = append(accountTests, accountTest{"Find limit, offset", UserAccountFindRequest{ + Where: &createdFilter, + Args: []interface{}{startTime, endTime}, Order: []string{"created_at"}, Limit: &limit, Offset: &offset, @@ -688,27 +685,24 @@ func TestAccountFind(t *testing.T) { // Test where filter. whereParts := []string{} - whereArgs := []interface{}{} + whereArgs := []interface{}{startTime, endTime} expected = []*UserAccount{} - selected := make(map[string]bool) - for i := 0; i <= 2; i++ { - ranIdx := rand.Intn(len(userAccounts)) - - id := userAccounts[ranIdx].ID - if selected[id] { + for i := 0; i <= len(userAccounts); i++ { + if rand.Intn(100) < 50 { continue } - selected[id] = true + ua := *userAccounts[i] whereParts = append(whereParts, "id = ?") - whereArgs = append(whereArgs, id) - expected = append(expected, userAccounts[ranIdx]) + whereArgs = append(whereArgs, ua.ID) + expected = append(expected, &ua) } - where := strings.Join(whereParts, " OR ") + where := createdFilter + " AND (" + strings.Join(whereParts, " OR ") + ")" accountTests = append(accountTests, accountTest{"Find where", UserAccountFindRequest{ Where: &where, Args: whereArgs, + Order: []string{"created_at"}, }, expected, nil, diff --git a/example-project/internal/user/user_test.go b/example-project/internal/user/user_test.go index e63c5ee..2732a1d 100644 --- a/example-project/internal/user/user_test.go +++ b/example-project/internal/user/user_test.go @@ -137,7 +137,6 @@ func TestApplyClaimsUserSelect(t *testing.T) { // TestCreateUser ensures all the validation tags work on Create func TestCreateUserValidation(t *testing.T) { - var userTests = []struct { name string req CreateUserRequest @@ -495,7 +494,7 @@ func TestUpdateUserPassword(t *testing.T) { now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC) - var tknGen mockTokenGenerator + tknGen := &mockTokenGenerator{} // Create a new user for testing. initPass := uuid.NewRandom().String() @@ -523,7 +522,7 @@ func TestUpdateUserPassword(t *testing.T) { } // Verify that the user can be authenticated with the created user. - _, err = Authenticate(ctx, test.MasterDB, tknGen, now, user.Email, initPass) + _, err = Authenticate(ctx, test.MasterDB, tknGen, user.Email, initPass, time.Hour, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tAuthenticate failed.", tests.Failed) @@ -557,7 +556,7 @@ func TestUpdateUserPassword(t *testing.T) { t.Logf("\t%s\tUpdatePassword ok.", tests.Success) // Verify that the user can be authenticated with the updated password. - _, err = Authenticate(ctx, test.MasterDB, tknGen, now, user.Email, newPass) + _, err = Authenticate(ctx, test.MasterDB, tknGen, user.Email, newPass, time.Hour, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tAuthenticate failed.", tests.Failed) @@ -868,23 +867,10 @@ func TestUserCrud(t *testing.T) { // TestUserFind validates all the request params are correctly parsed into a select query. func TestUserFind(t *testing.T) { - // Ensure all the existing users are deleted. - { - // Build the delete SQL statement. - query := sqlbuilder.NewDeleteBuilder() - query.DeleteFrom(usersTableName) + now := time.Now().Add(time.Hour * -1).UTC() - // Execute the query with the provided context. - sql, args := query.Build() - sql = test.MasterDB.Rebind(sql) - _, err := test.MasterDB.ExecContext(tests.Context(), sql, args...) - if err != nil { - t.Logf("\t\tGot : %+v", err) - t.Fatalf("\t%s\tDelete failed.", tests.Failed) - } - } - - now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC) + startTime := now.Truncate(time.Millisecond) + var endTime time.Time var users []*User for i := 0; i <= 4; i++ { @@ -899,6 +885,7 @@ func TestUserFind(t *testing.T) { t.Fatalf("\t%s\tCreate failed.", tests.Failed) } users = append(users, user) + endTime = user.CreatedAt } type userTest struct { @@ -910,9 +897,13 @@ func TestUserFind(t *testing.T) { var userTests []userTest + createdFilter := "created_at BETWEEN ? AND ?" + // Test sort users. - userTests = append(userTests, userTest{"Find all order by created_at asx", + userTests = append(userTests, userTest{"Find all order by created_at asc", UserFindRequest{ + Where: &createdFilter, + Args: []interface{}{startTime, endTime}, Order: []string{"created_at"}, }, users, @@ -926,6 +917,8 @@ func TestUserFind(t *testing.T) { } userTests = append(userTests, userTest{"Find all order by created_at desc", UserFindRequest{ + Where: &createdFilter, + Args: []interface{}{startTime, endTime}, Order: []string{"created_at desc"}, }, expected, @@ -936,6 +929,8 @@ func TestUserFind(t *testing.T) { var limit uint = 2 userTests = append(userTests, userTest{"Find limit", UserFindRequest{ + Where: &createdFilter, + Args: []interface{}{startTime, endTime}, Order: []string{"created_at"}, Limit: &limit, }, @@ -947,6 +942,8 @@ func TestUserFind(t *testing.T) { var offset uint = 3 userTests = append(userTests, userTest{"Find limit, offset", UserFindRequest{ + Where: &createdFilter, + Args: []interface{}{startTime, endTime}, Order: []string{"created_at"}, Limit: &limit, Offset: &offset, @@ -957,27 +954,25 @@ func TestUserFind(t *testing.T) { // Test where filter. whereParts := []string{} - whereArgs := []interface{}{} + whereArgs := []interface{}{startTime, endTime} expected = []*User{} - selected := make(map[string]bool) - for i := 0; i <= 2; i++ { - ranIdx := rand.Intn(len(users)) - - email := users[ranIdx].Email - if selected[email] { + for i := 0; i <= len(users); i++ { + if rand.Intn(100) < 50 { continue } - selected[email] = true + u := *users[i] whereParts = append(whereParts, "email = ?") - whereArgs = append(whereArgs, email) - expected = append(expected, users[ranIdx]) + whereArgs = append(whereArgs, u.Email) + expected = append(expected, &u) } - where := strings.Join(whereParts, " OR ") + + where := createdFilter + " AND (" + strings.Join(whereParts, " OR ") + ")" userTests = append(userTests, userTest{"Find where", UserFindRequest{ Where: &where, Args: whereArgs, + Order: []string{"created_at"}, }, expected, nil, @@ -998,6 +993,14 @@ func TestUserFind(t *testing.T) { } else if diff := cmp.Diff(res, tt.expected); diff != "" { t.Logf("\t\tGot: %d items", len(res)) t.Logf("\t\tWant: %d items", len(tt.expected)) + + for _, u := range res { + t.Logf("\t\tGot: %s ID", u.ID) + } + for _, u := range tt.expected { + t.Logf("\t\tExpected: %s ID", u.ID) + } + t.Fatalf("\t%s\tExpected find result to match expected. Diff:\n%s", tests.Failed, diff) } t.Logf("\t%s\tFind ok.", tests.Success)