diff --git a/internal/account/account.go b/internal/account/account.go index 5a8dc7a..19575ee 100644 --- a/internal/account/account.go +++ b/internal/account/account.go @@ -64,6 +64,11 @@ func CanReadAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, ac return nil } +// CanReadAccount determines if claims has the authority to access the specified account ID. +func (repo *Repository) CanReadAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, accountID string) error { + return repo.CanReadAccount(ctx, claims, repo.DbConn, accountID) +} + // CanModifyAccount determines if claims has the authority to modify the specified account ID. func CanModifyAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, accountID string) error { // If the request has claims from a specific account, ensure that the claims @@ -105,6 +110,11 @@ func CanModifyAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, return nil } +// CanModifyAccount determines if claims has the authority to modify the specified account ID. +func (repo *Repository) CanModifyAccount(ctx context.Context, claims auth.Claims, accountID string) error { + return CanModifyAccount(ctx, claims, repo.DbConn, accountID) +} + // applyClaimsSelect applies a sub-query to the provided query to enforce ACL based on // the claims provided. // 1. All role types can access their user ID @@ -150,7 +160,7 @@ func selectQuery() *sqlbuilder.SelectBuilder { // Find gets all the accounts from the database based on the request params. // TODO: Need to figure out why can't parse the args when appending the where // to the query. -func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountFindRequest) (Accounts, error) { +func (repo *Repository) Find(ctx context.Context, claims auth.Claims, req AccountFindRequest) (Accounts, error) { query := selectQuery() if req.Where != "" { @@ -166,7 +176,7 @@ func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountF query.Offset(int(*req.Offset)) } - return find(ctx, claims, dbConn, query, req.Args, req.IncludeArchived) + return find(ctx, claims, repo.DbConn, query, req.Args, req.IncludeArchived) } // find internal method for getting all the accounts from the database using a select query. @@ -242,14 +252,14 @@ func UniqueName(ctx context.Context, dbConn *sqlx.DB, name, accountId string) (b } // Create inserts a new account into the database. -func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountCreateRequest, now time.Time) (*Account, error) { +func (repo *Repository) Create(ctx context.Context, claims auth.Claims, req AccountCreateRequest, now time.Time) (*Account, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Create") defer span.Finish() v := webcontext.Validator() // Validation account name is unique in the database. - uniq, err := UniqueName(ctx, dbConn, req.Name, "") + uniq, err := UniqueName(ctx, repo.DbConn, req.Name, "") if err != nil { return nil, err } @@ -310,8 +320,8 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessage(err, "create account failed") @@ -322,15 +332,15 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun } // ReadByID gets the specified user by ID from the database. -func ReadByID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string) (*Account, error) { - return Read(ctx, claims, dbConn, AccountReadRequest{ +func (repo *Repository) ReadByID(ctx context.Context, claims auth.Claims, id string) (*Account, error) { + return repo.Read(ctx, claims, AccountReadRequest{ ID: id, IncludeArchived: false, }) } // Read gets the specified account from the database. -func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountReadRequest) (*Account, error) { +func (repo *Repository) Read(ctx context.Context, claims auth.Claims, req AccountReadRequest) (*Account, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Read") defer span.Finish() @@ -345,7 +355,7 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountR query := sqlbuilder.NewSelectBuilder() query.Where(query.Equal("id", req.ID)) - res, err := find(ctx, claims, dbConn, query, []interface{}{}, req.IncludeArchived) + res, err := find(ctx, claims, repo.DbConn, query, []interface{}{}, req.IncludeArchived) if err != nil { return nil, err } else if res == nil || len(res) == 0 { @@ -358,7 +368,7 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountR } // Update replaces an account in the database. -func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountUpdateRequest, now time.Time) error { +func (repo *Repository) Update(ctx context.Context, claims auth.Claims, req AccountUpdateRequest, now time.Time) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Update") defer span.Finish() @@ -366,7 +376,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun if req.Name != nil { // Validation account name is unique in the database. - uniq, err := UniqueName(ctx, dbConn, *req.Name, req.ID) + uniq, err := UniqueName(ctx, repo.DbConn, *req.Name, req.ID) if err != nil { return err } @@ -382,7 +392,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun } // Ensure the claims can modify the account specified in the request. - err = CanModifyAccount(ctx, claims, dbConn, req.ID) + err = CanModifyAccount(ctx, claims, repo.DbConn, req.ID) if err != nil { return err } @@ -460,8 +470,8 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "update account %s failed", req.ID) @@ -472,7 +482,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun } // Archive soft deleted the account from the database. -func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountArchiveRequest, now time.Time) error { +func (repo *Repository) Archive(ctx context.Context, claims auth.Claims, req AccountArchiveRequest, now time.Time) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Archive") defer span.Finish() @@ -484,7 +494,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accou } // Ensure the claims can modify the account specified in the request. - err = CanModifyAccount(ctx, claims, dbConn, req.ID) + err = CanModifyAccount(ctx, claims, repo.DbConn, req.ID) if err != nil { return err } @@ -511,8 +521,8 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accou // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "archive account %s failed", req.ID) @@ -531,8 +541,8 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accou // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "archive users for account %s failed", req.ID) @@ -544,7 +554,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accou } // Delete removes an account from the database. -func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountDeleteRequest) error { +func (repo *Repository) Delete(ctx context.Context, claims auth.Claims, req AccountDeleteRequest) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Delete") defer span.Finish() @@ -556,13 +566,13 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun } // Ensure the claims can modify the account specified in the request. - err = CanModifyAccount(ctx, claims, dbConn, req.ID) + err = CanModifyAccount(ctx, claims, repo.DbConn, req.ID) if err != nil { return err } // Start a new transaction to handle rollbacks on error. - tx, err := dbConn.Begin() + tx, err := repo.DbConn.Begin() if err != nil { return errors.WithStack(err) } @@ -579,7 +589,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) + sql = repo.DbConn.Rebind(sql) _, err = tx.ExecContext(ctx, sql, args...) if err != nil { tx.Rollback() @@ -602,7 +612,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) + sql = repo.DbConn.Rebind(sql) _, err = tx.ExecContext(ctx, sql, args...) if err != nil { tx.Rollback() @@ -620,7 +630,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) + sql = repo.DbConn.Rebind(sql) _, err = tx.ExecContext(ctx, sql, args...) if err != nil { tx.Rollback() @@ -642,6 +652,10 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun func MockAccount(ctx context.Context, dbConn *sqlx.DB, now time.Time) (*Account, error) { s := AccountStatus_Active + repo := &Repository{ + DbConn: dbConn, + } + req := AccountCreateRequest{ Name: uuid.NewRandom().String(), Address1: "103 East Main St", @@ -652,5 +666,5 @@ func MockAccount(ctx context.Context, dbConn *sqlx.DB, now time.Time) (*Account, Zipcode: "99686", Status: &s, } - return Create(ctx, auth.Claims{}, dbConn, req, now) + return repo.Create(ctx, auth.Claims{}, req, now) } diff --git a/internal/account/account_preference/account_preference.go b/internal/account/account_preference/account_preference.go index d995f75..3dd0e41 100644 --- a/internal/account/account_preference/account_preference.go +++ b/internal/account/account_preference/account_preference.go @@ -63,7 +63,7 @@ func applyClaimsSelect(ctx context.Context, claims auth.Claims, query *sqlbuilde // Find gets all the account preferences from the database based on the request params. // TODO: Need to figure out why can't parse the args when appending the where to the query. -func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPreferenceFindRequest) ([]*AccountPreference, error) { +func (repo *Repository) Find(ctx context.Context, claims auth.Claims, req AccountPreferenceFindRequest) ([]*AccountPreference, error) { query := sqlbuilder.NewSelectBuilder() if req.Where != "" { query.Where(query.And(req.Where)) @@ -78,11 +78,11 @@ func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountP query.Offset(int(*req.Offset)) } - return find(ctx, claims, dbConn, query, req.Args, req.IncludeArchived) + return find(ctx, claims, repo.DbConn, query, req.Args, req.IncludeArchived) } // FindByAccountID gets the specified account preferences for an account from the database. -func FindByAccountID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPreferenceFindByAccountIDRequest) ([]*AccountPreference, error) { +func (repo *Repository) FindByAccountID(ctx context.Context, claims auth.Claims, req AccountPreferenceFindByAccountIDRequest) ([]*AccountPreference, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.account_preference.FindByAccountID") defer span.Finish() @@ -106,7 +106,7 @@ func FindByAccountID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r query.Offset(int(*req.Offset)) } - return find(ctx, claims, dbConn, query, []interface{}{}, req.IncludeArchived) + return find(ctx, claims, repo.DbConn, query, []interface{}{}, req.IncludeArchived) } // find internal method for getting all the account preferences from the database using a select query. @@ -157,7 +157,7 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu } // Read gets the specified account preference from the database. -func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPreferenceReadRequest) (*AccountPreference, error) { +func (repo *Repository) Read(ctx context.Context, claims auth.Claims, req AccountPreferenceReadRequest) (*AccountPreference, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.account_preference.Read") defer span.Finish() @@ -173,7 +173,7 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountP query.Equal("account_id", req.AccountID)), query.Equal("name", req.Name)) - res, err := find(ctx, claims, dbConn, query, []interface{}{}, req.IncludeArchived) + res, err := find(ctx, claims, repo.DbConn, query, []interface{}{}, req.IncludeArchived) if err != nil { return nil, err } else if res == nil || len(res) == 0 { @@ -263,7 +263,7 @@ func Validator() *validator.Validate { } // Set inserts a new account preference or updates an existing on. -func Set(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPreferenceSetRequest, now time.Time) error { +func (repo *Repository) Set(ctx context.Context, claims auth.Claims, req AccountPreferenceSetRequest, now time.Time) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.account_preference.Set") defer span.Finish() @@ -276,7 +276,7 @@ func Set(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPr } // Ensure the claims can modify the account specified in the request. - err = account.CanModifyAccount(ctx, claims, dbConn, req.AccountID) + err = account.CanModifyAccount(ctx, claims, repo.DbConn, req.AccountID) if err != nil { return err } @@ -301,11 +301,11 @@ func Set(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPr // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) + sql = repo.DbConn.Rebind(sql) sql = sql + " ON CONFLICT ON CONSTRAINT account_preferences_pkey DO UPDATE set value = EXCLUDED.value " - _, err = dbConn.ExecContext(ctx, sql, args...) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessage(err, "set account preference failed") @@ -316,7 +316,7 @@ func Set(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPr } // Archive soft deleted the account preference from the database. -func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPreferenceArchiveRequest, now time.Time) error { +func (repo *Repository) Archive(ctx context.Context, claims auth.Claims, req AccountPreferenceArchiveRequest, now time.Time) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.account_preference.Archive") defer span.Finish() @@ -328,7 +328,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accou } // Ensure the claims can modify the account specified in the request. - err = account.CanModifyAccount(ctx, claims, dbConn, req.AccountID) + err = account.CanModifyAccount(ctx, claims, repo.DbConn, req.AccountID) if err != nil { return err } @@ -355,8 +355,8 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accou // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "archive account preference %s for account %s failed", req.Name, req.AccountID) @@ -367,7 +367,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accou } // Delete removes an account preference from the database. -func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPreferenceDeleteRequest) error { +func (repo *Repository) Delete(ctx context.Context, claims auth.Claims, req AccountPreferenceDeleteRequest) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.account_preference.Delete") defer span.Finish() @@ -379,13 +379,13 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun } // Ensure the claims can modify the account specified in the request. - err = account.CanModifyAccount(ctx, claims, dbConn, req.AccountID) + err = account.CanModifyAccount(ctx, claims, repo.DbConn, req.AccountID) if err != nil { return err } // Start a new transaction to handle rollbacks on error. - tx, err := dbConn.Begin() + tx, err := repo.DbConn.Begin() if err != nil { return errors.WithStack(err) } @@ -397,7 +397,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) + sql = repo.DbConn.Rebind(sql) _, err = tx.ExecContext(ctx, sql, args...) if err != nil { tx.Rollback() @@ -417,10 +417,15 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun // MockAccountPreference returns a fake AccountPreference for testing. func MockAccountPreference(ctx context.Context, dbConn *sqlx.DB, now time.Time) error { + + repo := &Repository{ + DbConn: dbConn, + } + req := AccountPreferenceSetRequest{ AccountID: uuid.NewRandom().String(), Name: AccountPreference_Datetime_Format, Value: AccountPreference_Datetime_Format_Default, } - return Set(ctx, auth.Claims{}, dbConn, req, now) + return repo.Set(ctx, auth.Claims{}, req, now) } diff --git a/internal/account/account_preference/account_preference_test.go b/internal/account/account_preference/account_preference_test.go index e8886f9..2a9ebdf 100644 --- a/internal/account/account_preference/account_preference_test.go +++ b/internal/account/account_preference/account_preference_test.go @@ -1,13 +1,13 @@ package account_preference import ( - "geeks-accelerator/oss/saas-starter-kit/internal/account" "math/rand" "os" "strings" "testing" "time" + "geeks-accelerator/oss/saas-starter-kit/internal/account" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/tests" "geeks-accelerator/oss/saas-starter-kit/internal/user_account" @@ -17,7 +17,10 @@ import ( "github.com/pkg/errors" ) -var test *tests.Test +var ( + test *tests.Test + repo *Repository +) // TestMain is the entry point for testing. func TestMain(m *testing.M) { @@ -27,6 +30,9 @@ func TestMain(m *testing.M) { func testMain(m *testing.M) int { test = tests.New() defer test.TearDown() + + repo = NewRepository(test.MasterDB) + return m.Run() } @@ -66,7 +72,7 @@ func TestSetValidation(t *testing.T) { { ctx := tests.Context() - err := Set(ctx, auth.Claims{}, test.MasterDB, tt.req, now) + err := repo.Set(ctx, auth.Claims{}, tt.req, now) if err != tt.error { // TODO: need a better way to handle validation errors as they are // of type interface validator.ValidationErrorsTranslations @@ -225,7 +231,7 @@ func TestCrud(t *testing.T) { { ctx := tests.Context() - err := Set(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), test.MasterDB, tt.set, now) + err := repo.Set(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), tt.set, now) if err != nil && errors.Cause(err) != tt.writeErr { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.writeErr) @@ -234,7 +240,7 @@ func TestCrud(t *testing.T) { // If user doesn't have access to set, create one anyways to test the other endpoints. if tt.writeErr != nil { - err := Set(ctx, auth.Claims{}, test.MasterDB, tt.set, now) + err := repo.Set(ctx, auth.Claims{}, tt.set, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tCreate failed.", tests.Failed) @@ -242,7 +248,7 @@ func TestCrud(t *testing.T) { } // Find the account and make sure the set where made. - readRes, err := Read(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), test.MasterDB, AccountPreferenceReadRequest{ + readRes, err := repo.Read(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), AccountPreferenceReadRequest{ AccountID: tt.set.AccountID, Name: tt.set.Name, }) @@ -266,7 +272,7 @@ func TestCrud(t *testing.T) { } // Archive (soft-delete) the account. - err = Archive(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), test.MasterDB, AccountPreferenceArchiveRequest{ + err = repo.Archive(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), AccountPreferenceArchiveRequest{ AccountID: tt.set.AccountID, Name: tt.set.Name, }, now) @@ -276,7 +282,7 @@ func TestCrud(t *testing.T) { t.Fatalf("\t%s\tArchive failed.", tests.Failed) } else if tt.findErr == nil { // Trying to find the archived account with the includeArchived false should result in not found. - _, err = Read(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), test.MasterDB, AccountPreferenceReadRequest{ + _, err = repo.Read(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), AccountPreferenceReadRequest{ AccountID: tt.set.AccountID, Name: tt.set.Name, }) @@ -287,7 +293,7 @@ func TestCrud(t *testing.T) { } // Trying to find the archived account with the includeArchived true should result no error. - _, err = Read(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), test.MasterDB, AccountPreferenceReadRequest{ + _, err = repo.Read(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), AccountPreferenceReadRequest{ AccountID: tt.set.AccountID, Name: tt.set.Name, IncludeArchived: true, @@ -300,7 +306,7 @@ func TestCrud(t *testing.T) { t.Logf("\t%s\tArchive ok.", tests.Success) // Delete (hard-delete) the account. - err = Delete(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), test.MasterDB, AccountPreferenceDeleteRequest{ + err = repo.Delete(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), AccountPreferenceDeleteRequest{ AccountID: tt.set.AccountID, Name: tt.set.Name, }) @@ -310,7 +316,7 @@ func TestCrud(t *testing.T) { t.Fatalf("\t%s\tDelete failed.", tests.Failed) } else if tt.writeErr == nil { // Trying to find the deleted account with the includeArchived true should result in not found. - _, err = Read(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), test.MasterDB, AccountPreferenceReadRequest{ + _, err = repo.Read(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), AccountPreferenceReadRequest{ AccountID: tt.set.AccountID, Name: tt.set.Name, IncludeArchived: true, @@ -362,14 +368,14 @@ func TestFind(t *testing.T) { var prefs []*AccountPreference for idx, req := range reqs { - err = Set(tests.Context(), auth.Claims{}, test.MasterDB, req, now.Add(time.Second*time.Duration(idx))) + err = repo.Set(tests.Context(), auth.Claims{}, req, now.Add(time.Second*time.Duration(idx))) if err != nil { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tRequest : %+v", req) t.Fatalf("\t%s\tSet failed.", tests.Failed) } - pref, err := Read(tests.Context(), auth.Claims{}, test.MasterDB, AccountPreferenceReadRequest{ + pref, err := repo.Read(tests.Context(), auth.Claims{}, AccountPreferenceReadRequest{ AccountID: req.AccountID, Name: req.Name, }) @@ -479,7 +485,7 @@ func TestFind(t *testing.T) { { ctx := tests.Context() - res, err := Find(ctx, auth.Claims{}, test.MasterDB, tt.req) + res, err := repo.Find(ctx, auth.Claims{}, tt.req) if errors.Cause(err) != tt.error { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.error) diff --git a/internal/account/account_preference/models.go b/internal/account/account_preference/models.go index 869da27..3d383b3 100644 --- a/internal/account/account_preference/models.go +++ b/internal/account/account_preference/models.go @@ -2,15 +2,28 @@ package account_preference import ( "context" - "github.com/pkg/errors" "time" "database/sql/driver" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" + "github.com/jmoiron/sqlx" "github.com/lib/pq" + "github.com/pkg/errors" "gopkg.in/go-playground/validator.v9" ) +// Repository defines the required dependencies for AccountPreference. +type Repository struct { + DbConn *sqlx.DB +} + +// NewRepository creates a new Repository that defines dependencies for AccountPreference. +func NewRepository(db *sqlx.DB) *Repository { + return &Repository{ + DbConn: db, + } +} + // AccountPreference represents an account setting. type AccountPreference struct { AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"` diff --git a/internal/account/account_test.go b/internal/account/account_test.go index 628fc54..8f14b4f 100644 --- a/internal/account/account_test.go +++ b/internal/account/account_test.go @@ -17,7 +17,10 @@ import ( "github.com/pkg/errors" ) -var test *tests.Test +var ( + test *tests.Test + repo *Repository +) // TestMain is the entry point for testing. func TestMain(m *testing.M) { @@ -27,6 +30,9 @@ func TestMain(m *testing.M) { func testMain(m *testing.M) int { test = tests.New() defer test.TearDown() + + repo = NewRepository(test.MasterDB) + return m.Run() } @@ -184,7 +190,7 @@ func TestCreateValidation(t *testing.T) { { ctx := tests.Context() - res, err := Create(ctx, auth.Claims{}, test.MasterDB, tt.req, now) + res, err := repo.Create(ctx, auth.Claims{}, tt.req, now) if err != tt.error { // TODO: need a better way to handle validation errors as they are // of type interface validator.ValidationErrorsTranslations @@ -239,7 +245,7 @@ func TestCreateValidationNameUnique(t *testing.T) { Country: "USA", Zipcode: "99686", } - account1, err := Create(ctx, auth.Claims{}, test.MasterDB, req1, now) + account1, err := repo.Create(ctx, auth.Claims{}, req1, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tCreate failed.", tests.Failed) @@ -255,7 +261,7 @@ func TestCreateValidationNameUnique(t *testing.T) { Zipcode: "99686", } expectedErr := errors.New("Key: 'AccountCreateRequest.name' Error:Field validation for 'name' failed on the 'unique' tag") - _, err = Create(ctx, auth.Claims{}, test.MasterDB, req2, now) + _, err = repo.Create(ctx, auth.Claims{}, req2, now) if err == nil { t.Logf("\t\tWant: %+v", expectedErr) t.Fatalf("\t%s\tCreate failed.", tests.Failed) @@ -349,7 +355,7 @@ func TestCreateClaims(t *testing.T) { { ctx := tests.Context() - _, err := Create(ctx, auth.Claims{}, test.MasterDB, tt.req, now) + _, err := repo.Create(ctx, auth.Claims{}, tt.req, now) if errors.Cause(err) != tt.error { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.error) @@ -396,7 +402,7 @@ func TestUpdateValidation(t *testing.T) { { ctx := tests.Context() - err := Update(ctx, auth.Claims{}, test.MasterDB, tt.req, now) + err := repo.Update(ctx, auth.Claims{}, tt.req, now) if err != tt.error { // TODO: need a better way to handle validation errors as they are // of type interface validator.ValidationErrorsTranslations @@ -440,7 +446,7 @@ func TestUpdateValidationNameUnique(t *testing.T) { Country: "USA", Zipcode: "99686", } - account1, err := Create(ctx, auth.Claims{}, test.MasterDB, req1, now) + account1, err := repo.Create(ctx, auth.Claims{}, req1, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tCreate failed.", tests.Failed) @@ -455,7 +461,7 @@ func TestUpdateValidationNameUnique(t *testing.T) { Country: "USA", Zipcode: "99686", } - account2, err := Create(ctx, auth.Claims{}, test.MasterDB, req2, now) + account2, err := repo.Create(ctx, auth.Claims{}, req2, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tCreate failed.", tests.Failed) @@ -467,7 +473,7 @@ func TestUpdateValidationNameUnique(t *testing.T) { Name: &account1.Name, } expectedErr := errors.New("Key: 'AccountUpdateRequest.name' Error:Field validation for 'name' failed on the 'unique' tag") - err = Update(ctx, auth.Claims{}, test.MasterDB, updateReq, now) + err = repo.Update(ctx, auth.Claims{}, updateReq, now) if err == nil { t.Logf("\t\tWant: %+v", expectedErr) t.Fatalf("\t%s\tUpdate failed.", tests.Failed) @@ -728,7 +734,7 @@ func TestCrud(t *testing.T) { // Always create the new account with empty claims, testing claims for create account // will be handled separately. - account, err := Create(ctx, auth.Claims{}, test.MasterDB, tt.create, now) + account, err := repo.Create(ctx, auth.Claims{}, tt.create, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tCreate failed.", tests.Failed) @@ -744,7 +750,7 @@ func TestCrud(t *testing.T) { // Update the account. updateReq := tt.update(account) - err = Update(ctx, tt.claims(account, userId), test.MasterDB, updateReq, now) + err = repo.Update(ctx, tt.claims(account, userId), updateReq, now) if err != nil && errors.Cause(err) != tt.updateErr { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.updateErr) @@ -753,7 +759,7 @@ func TestCrud(t *testing.T) { t.Logf("\t%s\tUpdate ok.", tests.Success) // Find the account and make sure the updates where made. - findRes, err := ReadByID(ctx, tt.claims(account, userId), test.MasterDB, account.ID) + findRes, err := repo.ReadByID(ctx, tt.claims(account, userId), account.ID) if err != nil && errors.Cause(err) != tt.findErr { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.findErr) @@ -767,14 +773,14 @@ func TestCrud(t *testing.T) { } // Archive (soft-delete) the account. - err = Archive(ctx, tt.claims(account, userId), test.MasterDB, AccountArchiveRequest{ID: account.ID}, now) + err = repo.Archive(ctx, tt.claims(account, userId), AccountArchiveRequest{ID: account.ID}, now) if err != nil && errors.Cause(err) != tt.updateErr { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.updateErr) t.Fatalf("\t%s\tArchive failed.", tests.Failed) } else if tt.updateErr == nil { // Trying to find the archived account with the includeArchived false should result in not found. - _, err = ReadByID(ctx, tt.claims(account, userId), test.MasterDB, account.ID) + _, err = repo.ReadByID(ctx, tt.claims(account, userId), account.ID) if err != nil && errors.Cause(err) != ErrNotFound { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", ErrNotFound) @@ -782,7 +788,7 @@ func TestCrud(t *testing.T) { } // Trying to find the archived account with the includeArchived true should result no error. - _, err = Read(ctx, tt.claims(account, userId), test.MasterDB, + _, err = repo.Read(ctx, tt.claims(account, userId), AccountReadRequest{ID: account.ID, IncludeArchived: true}) if err != nil { t.Log("\t\tGot :", err) @@ -792,14 +798,14 @@ func TestCrud(t *testing.T) { t.Logf("\t%s\tArchive ok.", tests.Success) // Delete (hard-delete) the account. - err = Delete(ctx, tt.claims(account, userId), test.MasterDB, AccountDeleteRequest{ID: account.ID}) + err = repo.Delete(ctx, tt.claims(account, userId), AccountDeleteRequest{ID: account.ID}) if err != nil && errors.Cause(err) != tt.updateErr { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.updateErr) t.Fatalf("\t%s\tUpdate failed.", tests.Failed) } else if tt.updateErr == nil { // Trying to find the deleted account with the includeArchived true should result in not found. - _, err = ReadByID(ctx, tt.claims(account, userId), test.MasterDB, account.ID) + _, err = repo.ReadByID(ctx, tt.claims(account, userId), account.ID) if errors.Cause(err) != ErrNotFound { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", ErrNotFound) @@ -822,7 +828,7 @@ func TestFind(t *testing.T) { var accounts []*Account for i := 0; i <= 4; i++ { - account, err := Create(tests.Context(), auth.Claims{}, test.MasterDB, AccountCreateRequest{ + account, err := repo.Create(tests.Context(), auth.Claims{}, AccountCreateRequest{ Name: uuid.NewRandom().String(), Address1: "103 East Main St", Address2: "Unit 546", @@ -935,7 +941,7 @@ func TestFind(t *testing.T) { { ctx := tests.Context() - res, err := Find(ctx, auth.Claims{}, test.MasterDB, tt.req) + res, err := repo.Find(ctx, auth.Claims{}, tt.req) if errors.Cause(err) != tt.error { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.error) diff --git a/internal/account/models.go b/internal/account/models.go index 0907b36..843ce7c 100644 --- a/internal/account/models.go +++ b/internal/account/models.go @@ -5,14 +5,27 @@ import ( "database/sql" "database/sql/driver" "encoding/json" - "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "time" + "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" + "github.com/jmoiron/sqlx" "github.com/lib/pq" "github.com/pkg/errors" "gopkg.in/go-playground/validator.v9" ) +// Repository defines the required dependencies for Account. +type Repository struct { + DbConn *sqlx.DB +} + +// NewRepository creates a new Repository that defines dependencies for Account. +func NewRepository(db *sqlx.DB) *Repository { + return &Repository{ + DbConn: db, + } +} + // Account represents someone with access to our system. type Account struct { ID string `json:"id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"` diff --git a/internal/project/models.go b/internal/project/models.go index ba52589..eff7cfb 100644 --- a/internal/project/models.go +++ b/internal/project/models.go @@ -2,14 +2,28 @@ package project import ( "context" + "time" + "database/sql/driver" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" + "github.com/jmoiron/sqlx" "github.com/lib/pq" "github.com/pkg/errors" "gopkg.in/go-playground/validator.v9" - "time" ) +// Repository defines the required dependencies for Project. +type Repository struct { + DbConn *sqlx.DB +} + +// NewRepository creates a new Repository that defines dependencies for Project. +func NewRepository(db *sqlx.DB) *Repository { + return &Repository{ + DbConn: db, + } +} + // Project represents a workflow. type Project struct { ID string `json:"id" validate:"required,uuid" example:"985f1746-1d9f-459f-a2d9-fc53ece5ae86"` diff --git a/internal/project/project.go b/internal/project/project.go index 990fdd8..b4fe3e2 100644 --- a/internal/project/project.go +++ b/internal/project/project.go @@ -3,6 +3,8 @@ package project import ( "context" "database/sql" + "time" + "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "github.com/huandu/go-sqlbuilder" @@ -10,7 +12,6 @@ import ( "github.com/pborman/uuid" "github.com/pkg/errors" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" - "time" ) const ( @@ -27,7 +28,7 @@ var ( ) // CanReadProject determines if claims has the authority to access the specified project by id. -func CanReadProject(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string) error { +func (repo *Repository) CanReadProject(ctx context.Context, claims auth.Claims, id string) error { // If the request has claims from a specific project, ensure that the claims // has the correct access to the project. @@ -40,9 +41,9 @@ func CanReadProject(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id )) queryStr, args := query.Build() - queryStr = dbConn.Rebind(queryStr) + queryStr = repo.DbConn.Rebind(queryStr) var id string - err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&id) + err := repo.DbConn.QueryRowContext(ctx, queryStr, args...).Scan(&id) if err != nil && err != sql.ErrNoRows { err = errors.Wrapf(err, "query - %s", query.String()) return err @@ -60,8 +61,8 @@ func CanReadProject(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id } // CanModifyProject determines if claims has the authority to modify the specified project by id. -func CanModifyProject(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string) error { - err := CanReadProject(ctx, claims, dbConn, id) +func (repo *Repository) CanModifyProject(ctx context.Context, claims auth.Claims, id string) error { + err := repo.CanReadProject(ctx, claims, id) if err != nil { return err } @@ -124,9 +125,9 @@ func findRequestQuery(req ProjectFindRequest) (*sqlbuilder.SelectBuilder, []inte } // Find gets all the projects from the database based on the request params. -func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectFindRequest) (Projects, error) { +func (repo *Repository) Find(ctx context.Context, claims auth.Claims, req ProjectFindRequest) (Projects, error) { query, args := findRequestQuery(req) - return find(ctx, claims, dbConn, query, args, req.IncludeArchived) + return find(ctx, claims, repo.DbConn, query, args, req.IncludeArchived) } // find internal method for getting all the projects from the database using a select query. @@ -177,15 +178,15 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu } // ReadByID gets the specified project by ID from the database. -func ReadByID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string) (*Project, error) { - return Read(ctx, claims, dbConn, ProjectReadRequest{ +func (repo *Repository) ReadByID(ctx context.Context, claims auth.Claims, id string) (*Project, error) { + return repo.Read(ctx, claims, ProjectReadRequest{ ID: id, IncludeArchived: false, }) } // Read gets the specified project from the database. -func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectReadRequest) (*Project, error) { +func (repo *Repository) Read(ctx context.Context, claims auth.Claims, req ProjectReadRequest) (*Project, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Read") defer span.Finish() @@ -200,7 +201,7 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectR query := sqlbuilder.NewSelectBuilder() query.Where(query.Equal("id", req.ID)) - res, err := find(ctx, claims, dbConn, query, []interface{}{}, req.IncludeArchived) + res, err := find(ctx, claims, repo.DbConn, query, []interface{}{}, req.IncludeArchived) if err != nil { return nil, err } else if res == nil || len(res) == 0 { @@ -213,7 +214,7 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectR } // Create inserts a new project into the database. -func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectCreateRequest, now time.Time) (*Project, error) { +func (repo *Repository) Create(ctx context.Context, claims auth.Claims, req ProjectCreateRequest, now time.Time) (*Project, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Create") defer span.Finish() if claims.Audience != "" { @@ -290,8 +291,8 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Projec // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessage(err, "create project failed") @@ -302,7 +303,7 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Projec } // Update replaces an project in the database. -func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectUpdateRequest, now time.Time) error { +func (repo *Repository) Update(ctx context.Context, claims auth.Claims, req ProjectUpdateRequest, now time.Time) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Update") defer span.Finish() @@ -314,7 +315,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Projec } // Ensure the claims can modify the project specified in the request. - err = CanModifyProject(ctx, claims, dbConn, req.ID) + err = repo.CanModifyProject(ctx, claims, req.ID) if err != nil { return err } @@ -352,8 +353,8 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Projec query.Where(query.Equal("id", req.ID)) // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "update project %s failed", req.ID) @@ -364,7 +365,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Projec } // Archive soft deleted the project from the database. -func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectArchiveRequest, now time.Time) error { +func (repo *Repository) Archive(ctx context.Context, claims auth.Claims, req ProjectArchiveRequest, now time.Time) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Archive") defer span.Finish() @@ -376,7 +377,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Proje } // Ensure the claims can modify the project specified in the request. - err = CanModifyProject(ctx, claims, dbConn, req.ID) + err = repo.CanModifyProject(ctx, claims, req.ID) if err != nil { return err } @@ -401,8 +402,8 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Proje query.Where(query.Equal("id", req.ID)) // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "archive project %s failed", req.ID) @@ -413,7 +414,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Proje } // Delete removes an project from the database. -func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectDeleteRequest) error { +func (repo *Repository) Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectDeleteRequest) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Delete") defer span.Finish() @@ -425,7 +426,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Projec } // Ensure the claims can modify the project specified in the request. - err = CanModifyProject(ctx, claims, dbConn, req.ID) + err = repo.CanModifyProject(ctx, claims, req.ID) if err != nil { return err } diff --git a/internal/project/project_test.go b/internal/project/project_test.go index f097c9e..8e702aa 100644 --- a/internal/project/project_test.go +++ b/internal/project/project_test.go @@ -1,15 +1,19 @@ package project import ( + "os" + "testing" + "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/tests" "github.com/google/go-cmp/cmp" "github.com/huandu/go-sqlbuilder" - "os" - "testing" ) -var test *tests.Test +var ( + test *tests.Test + repo *Repository +) // TestMain is the entry point for testing. func TestMain(m *testing.M) { @@ -19,6 +23,9 @@ func TestMain(m *testing.M) { func testMain(m *testing.M) int { test = tests.New() defer test.TearDown() + + repo = NewRepository(test.MasterDB) + return m.Run() } diff --git a/internal/project-routes/project_routes.go b/internal/project_route/project_routes.go similarity index 64% rename from internal/project-routes/project_routes.go rename to internal/project_route/project_routes.go index 14ab093..2c98003 100644 --- a/internal/project-routes/project_routes.go +++ b/internal/project_route/project_routes.go @@ -1,17 +1,17 @@ -package project_routes +package project_route import ( "github.com/pkg/errors" "net/url" ) -type ProjectRoutes struct { +type ProjectRoute struct { webAppUrl url.URL webApiUrl url.URL } -func New(apiBaseUrl, appBaseUrl string) (ProjectRoutes, error) { - var r ProjectRoutes +func New(apiBaseUrl, appBaseUrl string) (ProjectRoute, error) { + var r ProjectRoute apiUrl, err := url.Parse(apiBaseUrl) if err != nil { @@ -28,37 +28,37 @@ func New(apiBaseUrl, appBaseUrl string) (ProjectRoutes, error) { return r, nil } -func (r ProjectRoutes) WebAppUrl(urlPath string) string { +func (r ProjectRoute) WebAppUrl(urlPath string) string { u := r.webAppUrl u.Path = urlPath return u.String() } -func (r ProjectRoutes) WebApiUrl(urlPath string) string { +func (r ProjectRoute) WebApiUrl(urlPath string) string { u := r.webApiUrl u.Path = urlPath return u.String() } -func (r ProjectRoutes) UserResetPassword(resetHash string) string { +func (r ProjectRoute) UserResetPassword(resetHash string) string { u := r.webAppUrl u.Path = "/user/reset-password/" + resetHash return u.String() } -func (r ProjectRoutes) UserInviteAccept(inviteHash string) string { +func (r ProjectRoute) UserInviteAccept(inviteHash string) string { u := r.webAppUrl u.Path = "/users/invite/" + inviteHash return u.String() } -func (r ProjectRoutes) ApiDocs() string { +func (r ProjectRoute) ApiDocs() string { u := r.webApiUrl u.Path = "/docs" return u.String() } -func (r ProjectRoutes) ApiDocsJson() string { +func (r ProjectRoute) ApiDocsJson() string { u := r.webApiUrl u.Path = "/docs/doc.json" return u.String() diff --git a/internal/signup/models.go b/internal/signup/models.go index 8db28ca..a84b272 100644 --- a/internal/signup/models.go +++ b/internal/signup/models.go @@ -2,10 +2,31 @@ package signup import ( "context" + "geeks-accelerator/oss/saas-starter-kit/internal/account" "geeks-accelerator/oss/saas-starter-kit/internal/user" + "geeks-accelerator/oss/saas-starter-kit/internal/user_account" + "github.com/jmoiron/sqlx" ) +// Repository defines the required dependencies for Signup. +type Repository struct { + DbConn *sqlx.DB + User *user.Repository + UserAccount *user_account.Repository + Account *account.Repository +} + +// NewRepository creates a new Repository that defines dependencies for Signup. +func NewRepository(db *sqlx.DB, user *user.Repository, userAccount *user_account.Repository, account *account.Repository) *Repository { + return &Repository{ + DbConn: db, + User: user, + UserAccount: userAccount, + Account: account, + } +} + // SignupRequest contains information needed perform signup. type SignupRequest struct { Account SignupAccount `json:"account" validate:"required"` // Account details. diff --git a/internal/signup/signup.go b/internal/signup/signup.go index ce5e51c..4c049eb 100644 --- a/internal/signup/signup.go +++ b/internal/signup/signup.go @@ -9,25 +9,24 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/user" "geeks-accelerator/oss/saas-starter-kit/internal/user_account" - "github.com/jmoiron/sqlx" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" ) // Signup performs the steps needed to create a new account, new user and then associate // both records with a new user_account entry. -func Signup(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req SignupRequest, now time.Time) (*SignupResult, error) { +func (repo *Repository) Signup(ctx context.Context, claims auth.Claims, req SignupRequest, now time.Time) (*SignupResult, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.signup.Signup") defer span.Finish() // Validate the user email address is unique in the database. - uniqEmail, err := user.UniqueEmail(ctx, dbConn, req.User.Email, "") + uniqEmail, err := user.UniqueEmail(ctx, repo.DbConn, req.User.Email, "") if err != nil { return nil, err } ctx = webcontext.ContextAddUniqueValue(ctx, req.User, "Email", uniqEmail) // Validate the account name is unique in the database. - uniqName, err := account.UniqueName(ctx, dbConn, req.Account.Name, "") + uniqName, err := account.UniqueName(ctx, repo.DbConn, req.Account.Name, "") if err != nil { return nil, err } @@ -52,7 +51,7 @@ func Signup(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Signup } // Execute user creation. - resp.User, err = user.Create(ctx, claims, dbConn, userReq, now) + resp.User, err = repo.User.Create(ctx, claims, userReq, now) if err != nil { return nil, err } @@ -73,7 +72,7 @@ func Signup(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Signup } // Execute account creation. - resp.Account, err = account.Create(ctx, claims, dbConn, accountReq, now) + resp.Account, err = repo.Account.Create(ctx, claims, accountReq, now) if err != nil { return nil, err } @@ -87,7 +86,7 @@ func Signup(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Signup //Status: Use default value } - _, err = user_account.Create(ctx, claims, dbConn, ua, now) + _, err = repo.UserAccount.Create(ctx, claims, ua, now) if err != nil { return nil, err } diff --git a/internal/signup/signup_test.go b/internal/signup/signup_test.go index a8d7e95..369f114 100644 --- a/internal/signup/signup_test.go +++ b/internal/signup/signup_test.go @@ -1,19 +1,26 @@ package signup import ( - "geeks-accelerator/oss/saas-starter-kit/internal/user_auth" "os" "testing" "time" + "geeks-accelerator/oss/saas-starter-kit/internal/account" + "geeks-accelerator/oss/saas-starter-kit/internal/account/account_preference" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/tests" + "geeks-accelerator/oss/saas-starter-kit/internal/user" + "geeks-accelerator/oss/saas-starter-kit/internal/user_account" + "geeks-accelerator/oss/saas-starter-kit/internal/user_auth" "github.com/google/go-cmp/cmp" "github.com/pborman/uuid" "github.com/pkg/errors" ) -var test *tests.Test +var ( + test *tests.Test + repo *Repository +) // TestMain is the entry point for testing. func TestMain(m *testing.M) { @@ -23,6 +30,13 @@ func TestMain(m *testing.M) { func testMain(m *testing.M) int { test = tests.New() defer test.TearDown() + + userRepo := user.MockRepository(test.MasterDB) + userAccRepo := user_account.NewRepository(test.MasterDB) + accRepo := account.NewRepository(test.MasterDB) + + repo = NewRepository(test.MasterDB, userRepo, userAccRepo, accRepo) + return m.Run() } @@ -63,7 +77,7 @@ func TestSignupValidation(t *testing.T) { { ctx := tests.Context() - res, err := Signup(ctx, auth.Claims{}, test.MasterDB, tt.req, now) + res, err := repo.Signup(ctx, auth.Claims{}, tt.req, now) if err != tt.error { // TODO: need a better way to handle validation errors as they are // of type interface validator.ValidationErrorsTranslations @@ -127,9 +141,12 @@ func TestSignupFull(t *testing.T) { tknGen := &auth.MockTokenGenerator{} + accPrefRepo := account_preference.NewRepository(test.MasterDB) + authRepo := user_auth.NewRepository(test.MasterDB, tknGen, repo.User, repo.UserAccount, accPrefRepo) + t.Log("Given the need to ensure signup works.") { - res, err := Signup(ctx, auth.Claims{}, test.MasterDB, req, now) + res, err := repo.Signup(ctx, auth.Claims{}, req, now) if err != nil { t.Logf("\t\tGot error : %+v", err) t.Fatalf("\t%s\tSignup failed.", tests.Failed) @@ -162,7 +179,7 @@ func TestSignupFull(t *testing.T) { t.Logf("\t%s\tSignup ok.", tests.Success) // Verify that the user can be authenticated with the updated password. - _, err = user_auth.Authenticate(ctx, test.MasterDB, tknGen, user_auth.AuthenticateRequest{ + _, err = authRepo.Authenticate(ctx, user_auth.AuthenticateRequest{ Email: res.User.Email, Password: req.User.Password, }, time.Hour, now) diff --git a/internal/user/models.go b/internal/user/models.go index b4c65c5..7860dd3 100644 --- a/internal/user/models.go +++ b/internal/user/models.go @@ -4,34 +4,34 @@ import ( "context" "database/sql" "encoding/json" - "geeks-accelerator/oss/saas-starter-kit/internal/platform/notify" - "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" - "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" - "github.com/jmoiron/sqlx" - "github.com/pkg/errors" - "github.com/sudo-suhas/symcrypto" "strconv" "strings" "time" + "geeks-accelerator/oss/saas-starter-kit/internal/platform/notify" + "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" + "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" + "github.com/jmoiron/sqlx" "github.com/lib/pq" + "github.com/pkg/errors" + "github.com/sudo-suhas/symcrypto" ) // Repository defines the required dependencies for User. type Repository struct { - DbConn *sqlx.DB - ResetUrl func(string) string - Notify notify.Email - SecretKey string + DbConn *sqlx.DB + ResetUrl func(string) string + Notify notify.Email + secretKey string } // NewRepository creates a new Repository that defines dependencies for User. func NewRepository(db *sqlx.DB, resetUrl func(string) string, notify notify.Email, secretKey string) *Repository { return &Repository{ - DbConn: db, - ResetUrl: resetUrl, - Notify: notify, - SecretKey: secretKey, + DbConn: db, + ResetUrl: resetUrl, + Notify: notify, + secretKey: secretKey, } } diff --git a/internal/user/user.go b/internal/user/user.go index 81d036e..58d005c 100644 --- a/internal/user/user.go +++ b/internal/user/user.go @@ -6,6 +6,7 @@ import ( "time" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" + "geeks-accelerator/oss/saas-starter-kit/internal/platform/notify" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "github.com/huandu/go-sqlbuilder" "github.com/jmoiron/sqlx" @@ -200,11 +201,11 @@ func findRequestQuery(req UserFindRequest) (*sqlbuilder.SelectBuilder, []interfa // Find gets all the users from the database based on the request params. func (repo *Repository) Find(ctx context.Context, claims auth.Claims, req UserFindRequest) (Users, error) { query, args := findRequestQuery(req) - return repo.find(ctx, claims, query, args, req.IncludeArchived) + return find(ctx, claims, repo.DbConn, query, args, req.IncludeArchived) } // find internal method for getting all the users from the database using a select query. -func (repo *Repository) find(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (Users, error) { +func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (Users, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Find") defer span.Finish() @@ -221,11 +222,11 @@ func (repo *Repository) find(ctx context.Context, claims auth.Claims, query *sql return nil, err } queryStr, queryArgs := query.Build() - queryStr = repo.DbConn.Rebind(queryStr) + queryStr = dbConn.Rebind(queryStr) args = append(args, queryArgs...) // fetch all places from the db - rows, err := repo.DbConn.QueryContext(ctx, queryStr, args...) + rows, err := dbConn.QueryContext(ctx, queryStr, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessage(err, "find users failed") @@ -247,17 +248,17 @@ func (repo *Repository) find(ctx context.Context, claims auth.Claims, query *sql } // Validation an email address is unique excluding the current user ID. -func (repo *Repository) UniqueEmail(ctx context.Context, email, userId string) (bool, error) { +func UniqueEmail(ctx context.Context, dbConn *sqlx.DB, email, userId string) (bool, error) { query := sqlbuilder.NewSelectBuilder().Select("id").From(userTableName) query.Where(query.And( query.Equal("email", email), query.NotEqual("id", userId), )) queryStr, args := query.Build() - queryStr = repo.DbConn.Rebind(queryStr) + queryStr = dbConn.Rebind(queryStr) var existingId string - err := repo.DbConn.QueryRowContext(ctx, queryStr, args...).Scan(&existingId) + err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&existingId) if err != nil && err != sql.ErrNoRows { err = errors.Wrapf(err, "query - %s", query.String()) return false, err @@ -283,7 +284,7 @@ func (repo *Repository) Create(ctx context.Context, claims auth.Claims, req User v := webcontext.Validator() // Validation email address is unique in the database. - uniq, err := repo.UniqueEmail(ctx, req.Email, "") + uniq, err := UniqueEmail(ctx, repo.DbConn, req.Email, "") if err != nil { return nil, err } @@ -364,7 +365,7 @@ func (repo *Repository) CreateInvite(ctx context.Context, claims auth.Claims, re v := webcontext.Validator() // Validation email address is unique in the database. - uniq, err := repo.UniqueEmail(ctx, req.Email, "") + uniq, err := UniqueEmail(ctx, repo.DbConn, req.Email, "") if err != nil { return nil, err } @@ -448,7 +449,7 @@ func (repo *Repository) Read(ctx context.Context, claims auth.Claims, req UserRe query := selectQuery() query.Where(query.Equal("id", req.ID)) - res, err := repo.find(ctx, claims, query, []interface{}{}, req.IncludeArchived) + res, err := find(ctx, claims, repo.DbConn, query, []interface{}{}, req.IncludeArchived) if err != nil { return nil, err } else if res == nil || len(res) == 0 { @@ -469,7 +470,7 @@ func (repo *Repository) ReadByEmail(ctx context.Context, claims auth.Claims, ema query := selectQuery() query.Where(query.Equal("email", email)) - res, err := repo.find(ctx, claims, query, []interface{}{}, includedArchived) + res, err := find(ctx, claims, repo.DbConn, query, []interface{}{}, includedArchived) if err != nil { return nil, err } else if res == nil || len(res) == 0 { @@ -489,7 +490,7 @@ func (repo *Repository) Update(ctx context.Context, claims auth.Claims, req User // Validation email address is unique in the database. if req.Email != nil { // Validation email address is unique in the database. - uniq, err := repo.UniqueEmail(ctx, *req.Email, req.ID) + uniq, err := UniqueEmail(ctx, repo.DbConn, *req.Email, req.ID) if err != nil { return err } @@ -844,7 +845,7 @@ func (repo *Repository) ResetPassword(ctx context.Context, req UserResetPassword query := selectQuery() query.Where(query.Equal("email", req.Email)) - res, err := repo.find(ctx, auth.Claims{}, query, []interface{}{}, false) + res, err := find(ctx, auth.Claims{}, repo.DbConn, query, []interface{}{}, false) if err != nil { return "", err } else if res == nil || len(res) == 0 { @@ -894,7 +895,7 @@ func (repo *Repository) ResetPassword(ctx context.Context, req UserResetPassword requestIp = vals.RequestIP } - encrypted, err := NewResetHash(ctx, repo.SecretKey, resetId, requestIp, req.TTL, now) + encrypted, err := NewResetHash(ctx, repo.secretKey, resetId, requestIp, req.TTL, now) if err != nil { return "", err } @@ -927,7 +928,7 @@ func (repo *Repository) ResetConfirm(ctx context.Context, req UserResetConfirmRe return nil, err } - hash, err := ParseResetHash(ctx, repo.SecretKey, req.ResetHash, now) + hash, err := ParseResetHash(ctx, repo.secretKey, req.ResetHash, now) if err != nil { return nil, err } @@ -938,7 +939,7 @@ func (repo *Repository) ResetConfirm(ctx context.Context, req UserResetConfirmRe query := selectQuery() query.Where(query.Equal("password_reset", hash.ResetID)) - res, err := repo.find(ctx, auth.Claims{}, query, []interface{}{}, false) + res, err := find(ctx, auth.Claims{}, repo.DbConn, query, []interface{}{}, false) if err != nil { return nil, err } else if res == nil || len(res) == 0 { @@ -1020,3 +1021,14 @@ func MockUser(ctx context.Context, dbConn *sqlx.DB, now time.Time) (*MockUserRes Password: pass, }, nil } + +func MockRepository(dbConn *sqlx.DB) *Repository { + // Mock the methods needed to make a password reset. + resetUrl := func(string) string { + return "" + } + notify := ¬ify.MockEmail{} + secretKey := "6368616e676520746869732070617373" + + return NewRepository(dbConn, resetUrl, notify, secretKey) +} diff --git a/internal/user/user_test.go b/internal/user/user_test.go index edfaf49..fd40aa6 100644 --- a/internal/user/user_test.go +++ b/internal/user/user_test.go @@ -8,7 +8,6 @@ import ( "time" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" - "geeks-accelerator/oss/saas-starter-kit/internal/platform/notify" "geeks-accelerator/oss/saas-starter-kit/internal/platform/tests" "github.com/dgrijalva/jwt-go" "github.com/google/go-cmp/cmp" @@ -32,14 +31,7 @@ func testMain(m *testing.M) int { test = tests.New() defer test.TearDown() - // Mock the methods needed to make a password reset. - resetUrl := func(string) string { - return "" - } - notify := ¬ify.MockEmail{} - secretKey := "6368616e676520746869732070617373" - - repo = NewRepository(test.MasterDB, resetUrl, notify, secretKey) + repo = MockRepository(test.MasterDB) return m.Run() } @@ -930,7 +922,7 @@ func TestFind(t *testing.T) { var users []*User for i := 0; i <= 4; i++ { - user, err := repo.Create(tests.Context(), auth.Claims{}, UserCreateRequest{ + user, err := repo.Create(tests.Context(), auth.Claims{}, UserCreateRequest{ FirstName: "Lee", LastName: "Brown", Email: uuid.NewRandom().String() + "@geeksinthewoods.com", @@ -1042,7 +1034,7 @@ func TestFind(t *testing.T) { { ctx := tests.Context() - res, err := repo.Find(ctx, auth.Claims{}, tt.req) + res, err := repo.Find(ctx, auth.Claims{}, tt.req) if errors.Cause(err) != tt.error { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.error) diff --git a/internal/user_account/invite/invite.go b/internal/user_account/invite/invite.go index a1fe4b1..0ee8262 100644 --- a/internal/user_account/invite/invite.go +++ b/internal/user_account/invite/invite.go @@ -8,11 +8,9 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/account" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" - "geeks-accelerator/oss/saas-starter-kit/internal/platform/notify" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/user" "geeks-accelerator/oss/saas-starter-kit/internal/user_account" - "github.com/jmoiron/sqlx" "github.com/pkg/errors" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" ) @@ -29,7 +27,7 @@ var ( ) // SendUserInvites sends emails to the users inviting them to join an account. -func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, resetUrl func(string) string, notify notify.Email, req SendUserInvitesRequest, secretKey string, now time.Time) ([]string, error) { +func (repo *Repository) SendUserInvites(ctx context.Context, claims auth.Claims, req SendUserInvitesRequest, now time.Time) ([]string, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.invite.SendUserInvites") defer span.Finish() @@ -42,7 +40,7 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r } // Ensure the claims can modify the account specified in the request. - err = user_account.CanModifyAccount(ctx, claims, dbConn, req.AccountID) + err = account.CanModifyAccount(ctx, claims, repo.DbConn, req.AccountID) if err != nil { return nil, err } @@ -51,7 +49,7 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r emailUserIDs := make(map[string]string) { // Find all users without passing in claims to search all users. - users, err := user.Find(ctx, auth.Claims{}, dbConn, user.UserFindRequest{ + users, err := repo.User.Find(ctx, auth.Claims{}, user.UserFindRequest{ Where: fmt.Sprintf("email in ('%s')", strings.Join(req.Emails, "','")), }) @@ -72,7 +70,7 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r args = append(args, userID) } - userAccs, err := user_account.Find(ctx, claims, dbConn, user_account.UserAccountFindRequest{ + userAccs, err := repo.UserAccount.Find(ctx, claims, user_account.UserAccountFindRequest{ Where: fmt.Sprintf("user_id in ('%s') and status = '%s'", strings.Join(args, "','"), user_account.UserAccountStatus_Active.String()), @@ -99,7 +97,7 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r continue } - u, err := user.CreateInvite(ctx, claims, dbConn, user.UserCreateInviteRequest{ + u, err := repo.User.CreateInvite(ctx, claims, user.UserCreateInviteRequest{ Email: email, }, now) if err != nil { @@ -118,7 +116,7 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r } status := user_account.UserAccountStatus_Invited - _, err = user_account.Create(ctx, claims, dbConn, user_account.UserAccountCreateRequest{ + _, err = repo.UserAccount.Create(ctx, claims, user_account.UserAccountCreateRequest{ UserID: userID, AccountID: req.AccountID, Roles: req.Roles, @@ -133,12 +131,12 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r req.TTL = time.Minute * 90 } - fromUser, err := user.ReadByID(ctx, claims, dbConn, req.UserID) + fromUser, err := repo.User.ReadByID(ctx, claims, req.UserID) if err != nil { return nil, err } - account, err := account.ReadByID(ctx, claims, dbConn, req.AccountID) + account, err := repo.Account.ReadByID(ctx, claims, req.AccountID) if err != nil { return nil, err } @@ -151,7 +149,7 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r var inviteHashes []string for email, userID := range emailUserIDs { - hash, err := NewInviteHash(ctx, secretKey, userID, req.AccountID, requestIp, req.TTL, now) + hash, err := NewInviteHash(ctx, repo.secretKey, userID, req.AccountID, requestIp, req.TTL, now) if err != nil { return nil, err } @@ -159,13 +157,13 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r data := map[string]interface{}{ "FromUser": fromUser.Response(ctx), "Account": account.Response(ctx), - "Url": resetUrl(hash), + "Url": repo.ResetUrl(hash), "Minutes": req.TTL.Minutes(), } subject := fmt.Sprintf("%s %s has invited you to %s", fromUser.FirstName, fromUser.LastName, account.Name) - err = notify.Send(ctx, email, subject, "user_invite", data) + err = repo.Notify.Send(ctx, email, subject, "user_invite", data) if err != nil { err = errors.WithMessagef(err, "Send invite to %s failed.", email) return nil, err @@ -178,7 +176,7 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r } // AcceptInvite updates the user using the provided invite hash. -func AcceptInvite(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteRequest, secretKey string, now time.Time) (*user_account.UserAccount, error) { +func (repo *Repository) AcceptInvite(ctx context.Context, req AcceptInviteRequest, now time.Time) (*user_account.UserAccount, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.invite.AcceptInvite") defer span.Finish() @@ -190,25 +188,25 @@ func AcceptInvite(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteRequest, return nil, err } - hash, err := ParseInviteHash(ctx, req.InviteHash, secretKey, now) + hash, err := ParseInviteHash(ctx, req.InviteHash, repo.secretKey, now) if err != nil { return nil, err } - u, err := user.Read(ctx, auth.Claims{}, dbConn, + u, err := repo.User.Read(ctx, auth.Claims{}, user.UserReadRequest{ID: hash.UserID, IncludeArchived: true}) if err != nil { return nil, err } if u.ArchivedAt != nil && !u.ArchivedAt.Time.IsZero() { - err = user.Restore(ctx, auth.Claims{}, dbConn, user.UserRestoreRequest{ID: hash.UserID}, now) + err = repo.User.Restore(ctx, auth.Claims{}, user.UserRestoreRequest{ID: hash.UserID}, now) if err != nil { return nil, err } } - usrAcc, err := user_account.Read(ctx, auth.Claims{}, dbConn, user_account.UserAccountReadRequest{ + usrAcc, err := repo.UserAccount.Read(ctx, auth.Claims{}, user_account.UserAccountReadRequest{ UserID: hash.UserID, AccountID: hash.AccountID, }) @@ -230,7 +228,7 @@ func AcceptInvite(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteRequest, if len(u.PasswordHash) > 0 { usrAcc.Status = user_account.UserAccountStatus_Active - err = user_account.Update(ctx, auth.Claims{}, dbConn, user_account.UserAccountUpdateRequest{ + err = repo.UserAccount.Update(ctx, auth.Claims{}, user_account.UserAccountUpdateRequest{ UserID: usrAcc.UserID, AccountID: usrAcc.AccountID, Status: &usrAcc.Status, @@ -244,7 +242,7 @@ func AcceptInvite(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteRequest, } // AcceptInviteUser updates the user using the provided invite hash. -func AcceptInviteUser(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteUserRequest, secretKey string, now time.Time) (*user_account.UserAccount, error) { +func (repo *Repository) AcceptInviteUser(ctx context.Context, req AcceptInviteUserRequest, now time.Time) (*user_account.UserAccount, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.invite.AcceptInviteUser") defer span.Finish() @@ -256,25 +254,25 @@ func AcceptInviteUser(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteUser return nil, err } - hash, err := ParseInviteHash(ctx, req.InviteHash, secretKey, now) + hash, err := ParseInviteHash(ctx, req.InviteHash, repo.secretKey, now) if err != nil { return nil, err } - u, err := user.Read(ctx, auth.Claims{}, dbConn, + u, err := repo.User.Read(ctx, auth.Claims{}, user.UserReadRequest{ID: hash.UserID, IncludeArchived: true}) if err != nil { return nil, err } if u.ArchivedAt != nil && !u.ArchivedAt.Time.IsZero() { - err = user.Restore(ctx, auth.Claims{}, dbConn, user.UserRestoreRequest{ID: hash.UserID}, now) + err = repo.User.Restore(ctx, auth.Claims{}, user.UserRestoreRequest{ID: hash.UserID}, now) if err != nil { return nil, err } } - usrAcc, err := user_account.Read(ctx, auth.Claims{}, dbConn, user_account.UserAccountReadRequest{ + usrAcc, err := repo.UserAccount.Read(ctx, auth.Claims{}, user_account.UserAccountReadRequest{ UserID: hash.UserID, AccountID: hash.AccountID, }) @@ -293,7 +291,7 @@ func AcceptInviteUser(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteUser // These three calls, user.Update, user.UpdatePassword, and user_account.Update // should probably be in a transaction! - err = user.Update(ctx, auth.Claims{}, dbConn, user.UserUpdateRequest{ + err = repo.User.Update(ctx, auth.Claims{}, user.UserUpdateRequest{ ID: hash.UserID, Email: &req.Email, FirstName: &req.FirstName, @@ -304,7 +302,7 @@ func AcceptInviteUser(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteUser return nil, err } - err = user.UpdatePassword(ctx, auth.Claims{}, dbConn, user.UserUpdatePasswordRequest{ + err = repo.User.UpdatePassword(ctx, auth.Claims{}, user.UserUpdatePasswordRequest{ ID: hash.UserID, Password: req.Password, PasswordConfirm: req.PasswordConfirm, @@ -314,7 +312,7 @@ func AcceptInviteUser(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteUser } usrAcc.Status = user_account.UserAccountStatus_Active - err = user_account.Update(ctx, auth.Claims{}, dbConn, user_account.UserAccountUpdateRequest{ + err = repo.UserAccount.Update(ctx, auth.Claims{}, user_account.UserAccountUpdateRequest{ UserID: usrAcc.UserID, AccountID: usrAcc.AccountID, Status: &usrAcc.Status, diff --git a/internal/user_account/invite/invite_test.go b/internal/user_account/invite/invite_test.go index 032c6e7..3018733 100644 --- a/internal/user_account/invite/invite_test.go +++ b/internal/user_account/invite/invite_test.go @@ -1,7 +1,6 @@ package invite import ( - "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "os" "strings" "testing" @@ -11,6 +10,7 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/notify" "geeks-accelerator/oss/saas-starter-kit/internal/platform/tests" + "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/user" "geeks-accelerator/oss/saas-starter-kit/internal/user_account" "github.com/dgrijalva/jwt-go" @@ -18,7 +18,10 @@ import ( "github.com/pkg/errors" ) -var test *tests.Test +var ( + test *tests.Test + repo *Repository +) // TestMain is the entry point for testing. func TestMain(m *testing.M) { @@ -28,6 +31,20 @@ func TestMain(m *testing.M) { func testMain(m *testing.M) int { test = tests.New() defer test.TearDown() + + userRepo := user.MockRepository(test.MasterDB) + userAccRepo := user_account.NewRepository(test.MasterDB) + accRepo := account.NewRepository(test.MasterDB) + + // Mock the methods needed to make an invite. + resetUrl := func(string) string { + return "" + } + notify := ¬ify.MockEmail{} + secretKey := "6368616e676520746869732070613434" + + repo = NewRepository(test.MasterDB, userRepo, userAccRepo, accRepo, resetUrl, notify, secretKey) + return m.Run() } @@ -42,7 +59,7 @@ func TestSendUserInvites(t *testing.T) { // Create a new user for testing. initPass := uuid.NewRandom().String() - u, err := user.Create(ctx, auth.Claims{}, test.MasterDB, user.UserCreateRequest{ + u, err := repo.User.Create(ctx, auth.Claims{}, user.UserCreateRequest{ FirstName: "Lee", LastName: "Brown", Email: uuid.NewRandom().String() + "@geeksinthewoods.com", @@ -54,7 +71,7 @@ func TestSendUserInvites(t *testing.T) { t.Fatalf("\t%s\tCreate user failed.", tests.Failed) } - a, err := account.Create(ctx, auth.Claims{}, test.MasterDB, account.AccountCreateRequest{ + a, err := repo.Account.Create(ctx, auth.Claims{}, account.AccountCreateRequest{ Name: uuid.NewRandom().String(), Address1: "101 E Main", City: "Valdez", @@ -68,7 +85,7 @@ func TestSendUserInvites(t *testing.T) { } uRoles := []user_account.UserAccountRole{user_account.UserAccountRole_Admin} - _, err = user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + _, err = repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{ UserID: u.ID, AccountID: a.ID, Roles: uRoles, @@ -91,21 +108,13 @@ func TestSendUserInvites(t *testing.T) { claims.Roles = append(claims.Roles, r.String()) } - // Mock the methods needed to make a password reset. - resetUrl := func(string) string { - return "" - } - notify := ¬ify.MockEmail{} - - secretKey := "6368616e676520746869732070617373" - // Ensure validation is working by trying ResetPassword with an empty request. { expectedErr := errors.New("Key: 'SendUserInvitesRequest.account_id' Error:Field validation for 'account_id' failed on the 'required' tag\n" + "Key: 'SendUserInvitesRequest.user_id' Error:Field validation for 'user_id' failed on the 'required' tag\n" + "Key: 'SendUserInvitesRequest.emails' Error:Field validation for 'emails' failed on the 'required' tag\n" + "Key: 'SendUserInvitesRequest.roles' Error:Field validation for 'roles' failed on the 'required' tag") - _, err = SendUserInvites(ctx, claims, test.MasterDB, resetUrl, notify, SendUserInvitesRequest{}, secretKey, now) + _, err = repo.SendUserInvites(ctx, claims, SendUserInvitesRequest{}, now) if err == nil { t.Logf("\t\tWant: %+v", expectedErr) t.Fatalf("\t%s\tInviteUsers failed.", tests.Failed) @@ -129,13 +138,13 @@ func TestSendUserInvites(t *testing.T) { } // Make the reset password request. - inviteHashes, err := SendUserInvites(ctx, claims, test.MasterDB, resetUrl, notify, SendUserInvitesRequest{ + inviteHashes, err := repo.SendUserInvites(ctx, claims, SendUserInvitesRequest{ UserID: u.ID, AccountID: a.ID, Emails: inviteEmails, Roles: []user_account.UserAccountRole{user_account.UserAccountRole_User}, TTL: ttl, - }, secretKey, now) + }, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tInviteUsers failed.", tests.Failed) @@ -154,7 +163,7 @@ func TestSendUserInvites(t *testing.T) { "Key: 'AcceptInviteUserRequest.last_name' Error:Field validation for 'last_name' failed on the 'required' tag\n" + "Key: 'AcceptInviteUserRequest.password' Error:Field validation for 'password' failed on the 'required' tag\n" + "Key: 'AcceptInviteUserRequest.password_confirm' Error:Field validation for 'password_confirm' failed on the 'required' tag") - _, err = AcceptInviteUser(ctx, test.MasterDB, AcceptInviteUserRequest{}, secretKey, now) + _, err = repo.AcceptInviteUser(ctx, AcceptInviteUserRequest{}, now) if err == nil { t.Logf("\t\tWant: %+v", expectedErr) t.Fatalf("\t%s\tResetConfirm failed.", tests.Failed) @@ -174,14 +183,14 @@ func TestSendUserInvites(t *testing.T) { // Ensure the TTL is enforced. { newPass := uuid.NewRandom().String() - _, err = AcceptInviteUser(ctx, test.MasterDB, AcceptInviteUserRequest{ + _, err = repo.AcceptInviteUser(ctx, AcceptInviteUserRequest{ InviteHash: inviteHashes[0], Email: inviteEmails[0], FirstName: "Foo", LastName: "Bar", Password: newPass, PasswordConfirm: newPass, - }, secretKey, now.UTC().Add(ttl*2)) + }, now.UTC().Add(ttl*2)) if errors.Cause(err) != ErrInviteExpired { t.Logf("\t\tGot : %+v", errors.Cause(err)) t.Logf("\t\tWant: %+v", ErrInviteExpired) @@ -194,14 +203,14 @@ func TestSendUserInvites(t *testing.T) { for idx, inviteHash := range inviteHashes { newPass := uuid.NewRandom().String() - hash, err := AcceptInviteUser(ctx, test.MasterDB, AcceptInviteUserRequest{ + hash, err := repo.AcceptInviteUser(ctx, AcceptInviteUserRequest{ InviteHash: inviteHash, Email: inviteEmails[idx], FirstName: "Foo", LastName: "Bar", Password: newPass, PasswordConfirm: newPass, - }, secretKey, now) + }, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tInviteAccept failed.", tests.Failed) @@ -227,14 +236,14 @@ func TestSendUserInvites(t *testing.T) { // Ensure the reset hash does not work after its used. { newPass := uuid.NewRandom().String() - _, err = AcceptInviteUser(ctx, test.MasterDB, AcceptInviteUserRequest{ + _, err = repo.AcceptInviteUser(ctx, AcceptInviteUserRequest{ InviteHash: inviteHashes[0], Email: inviteEmails[0], FirstName: "Foo", LastName: "Bar", Password: newPass, PasswordConfirm: newPass, - }, secretKey, now) + }, now) if errors.Cause(err) != ErrUserAccountActive { t.Logf("\t\tGot : %+v", errors.Cause(err)) t.Logf("\t\tWant: %+v", ErrUserAccountActive) diff --git a/internal/user_account/invite/models.go b/internal/user_account/invite/models.go index ca87007..5c7231c 100644 --- a/internal/user_account/invite/models.go +++ b/internal/user_account/invite/models.go @@ -6,12 +6,41 @@ import ( "strings" "time" + "geeks-accelerator/oss/saas-starter-kit/internal/account" + "geeks-accelerator/oss/saas-starter-kit/internal/platform/notify" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" + "geeks-accelerator/oss/saas-starter-kit/internal/user" "geeks-accelerator/oss/saas-starter-kit/internal/user_account" + "github.com/jmoiron/sqlx" "github.com/pkg/errors" "github.com/sudo-suhas/symcrypto" ) +// Repository defines the required dependencies for User Invite. +type Repository struct { + DbConn *sqlx.DB + User *user.Repository + UserAccount *user_account.Repository + Account *account.Repository + ResetUrl func(string) string + Notify notify.Email + secretKey string +} + +// NewRepository creates a new Repository that defines dependencies for User Invite. +func NewRepository(db *sqlx.DB, user *user.Repository, userAccount *user_account.Repository, account *account.Repository, + resetUrl func(string) string, notify notify.Email, secretKey string) *Repository { + return &Repository{ + DbConn: db, + User: user, + UserAccount: userAccount, + Account: account, + ResetUrl: resetUrl, + Notify: notify, + secretKey: secretKey, + } +} + // SendUserInvitesRequest defines the data needed to make an invite request. type SendUserInvitesRequest struct { AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"` diff --git a/internal/user_account/models.go b/internal/user_account/models.go index e2131a4..df16106 100644 --- a/internal/user_account/models.go +++ b/internal/user_account/models.go @@ -2,13 +2,13 @@ package user_account import ( "context" - "database/sql/driver" - "github.com/jmoiron/sqlx" "strings" "time" + "database/sql/driver" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" + "github.com/jmoiron/sqlx" "github.com/lib/pq" "github.com/pkg/errors" "gopkg.in/go-playground/validator.v9" diff --git a/internal/user_account/user_account.go b/internal/user_account/user_account.go index 4fde70f..e436f6e 100644 --- a/internal/user_account/user_account.go +++ b/internal/user_account/user_account.go @@ -3,12 +3,12 @@ package user_account import ( "context" "database/sql" - "geeks-accelerator/oss/saas-starter-kit/internal/user" "time" "geeks-accelerator/oss/saas-starter-kit/internal/account" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" + "geeks-accelerator/oss/saas-starter-kit/internal/user" "github.com/huandu/go-sqlbuilder" "github.com/jmoiron/sqlx" "github.com/pborman/uuid" @@ -50,13 +50,13 @@ func mapRowsToUserAccount(rows *sql.Rows) (*UserAccount, error) { // CanReadAccount determines if claims has the authority to access the specified user account by user ID. func (repo *Repository) CanReadAccount(ctx context.Context, claims auth.Claims, accountID string) error { - err := account.CanReadAccount(ctx, claims, accountID) + err := account.CanReadAccount(ctx, claims, repo.DbConn, accountID) return mapAccountError(err) } // CanModifyAccount determines if claims has the authority to modify the specified user ID. func (repo *Repository) CanModifyAccount(ctx context.Context, claims auth.Claims, accountID string) error { - err := account.CanModifyAccount(ctx, claims, accountID) + err := account.CanModifyAccount(ctx, claims, repo.DbConn, accountID) return mapAccountError(err) } @@ -131,9 +131,9 @@ func findRequestQuery(req UserAccountFindRequest) (*sqlbuilder.SelectBuilder, [] } // Find gets all the user accounts from the database based on the request params. -func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountFindRequest) (UserAccounts, error) { +func (repo *Repository) Find(ctx context.Context, claims auth.Claims, req UserAccountFindRequest) (UserAccounts, error) { query, args := findRequestQuery(req) - return find(ctx, claims, dbConn, query, args, req.IncludeArchived) + return find(ctx, claims, repo.DbConn, query, args, req.IncludeArchived) } // Find gets all the user accounts from the database based on the select query @@ -180,7 +180,7 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu } // Retrieve gets the specified user from the database. -func FindByUserID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string, includedArchived bool) (UserAccounts, error) { +func (repo *Repository) FindByUserID(ctx context.Context, claims auth.Claims, userID string, includedArchived bool) (UserAccounts, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.FindByUserID") defer span.Finish() @@ -190,7 +190,7 @@ func FindByUserID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, user query.OrderBy("created_at") // Execute the find accounts method. - res, err := find(ctx, claims, dbConn, query, []interface{}{}, includedArchived) + res, err := find(ctx, claims, repo.DbConn, query, []interface{}{}, includedArchived) if err != nil { return nil, err } else if res == nil || len(res) == 0 { @@ -202,7 +202,7 @@ func FindByUserID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, user } // Create a user account for a given user with specified roles. -func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountCreateRequest, now time.Time) (*UserAccount, error) { +func (repo *Repository) Create(ctx context.Context, claims auth.Claims, req UserAccountCreateRequest, now time.Time) (*UserAccount, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Create") defer span.Finish() @@ -214,7 +214,7 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc } // Ensure the claims can modify the account specified in the request. - err = CanModifyAccount(ctx, claims, dbConn, req.AccountID) + err = repo.CanModifyAccount(ctx, claims, req.AccountID) if err != nil { return nil, err } @@ -237,7 +237,7 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc existQuery.Equal("account_id", req.AccountID), existQuery.Equal("user_id", req.UserID), )) - existing, err := find(ctx, claims, dbConn, existQuery, []interface{}{}, true) + existing, err := find(ctx, claims, repo.DbConn, existQuery, []interface{}{}, true) if err != nil { return nil, err } @@ -251,7 +251,7 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc Roles: &req.Roles, unArchive: true, } - err = Update(ctx, claims, dbConn, upReq, now) + err = repo.Update(ctx, claims, upReq, now) if err != nil { return nil, err } @@ -285,8 +285,8 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "add account %s to user %s failed", req.AccountID, req.UserID) @@ -298,7 +298,7 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc } // Read gets the specified user account from the database. -func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountReadRequest) (*UserAccount, error) { +func (repo *Repository) Read(ctx context.Context, claims auth.Claims, req UserAccountReadRequest) (*UserAccount, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Read") defer span.Finish() @@ -315,7 +315,7 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAcco query.Equal("user_id", req.UserID), query.Equal("account_id", req.AccountID))) - res, err := find(ctx, claims, dbConn, query, []interface{}{}, req.IncludeArchived) + res, err := find(ctx, claims, repo.DbConn, query, []interface{}{}, req.IncludeArchived) if err != nil { return nil, err } else if res == nil || len(res) == 0 { @@ -328,7 +328,7 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAcco } // Update replaces a user account in the database. -func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountUpdateRequest, now time.Time) error { +func (repo *Repository) Update(ctx context.Context, claims auth.Claims, req UserAccountUpdateRequest, now time.Time) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Update") defer span.Finish() @@ -340,7 +340,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc } // Ensure the claims can modify the user specified in the request. - err = CanModifyAccount(ctx, claims, dbConn, req.AccountID) + err = repo.CanModifyAccount(ctx, claims, req.AccountID) if err != nil { return err } @@ -389,8 +389,8 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "update account %s for user %s failed", req.AccountID, req.UserID) @@ -401,7 +401,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc } // Archive soft deleted the user account from the database. -func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountArchiveRequest, now time.Time) error { +func (repo *Repository) Archive(ctx context.Context, claims auth.Claims, req UserAccountArchiveRequest, now time.Time) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Archive") defer span.Finish() @@ -413,7 +413,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserA } // Ensure the claims can modify the user specified in the request. - err = CanModifyAccount(ctx, claims, dbConn, req.AccountID) + err = repo.CanModifyAccount(ctx, claims, req.AccountID) if err != nil { return err } @@ -441,8 +441,8 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserA // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "archive account %s from user %s failed", req.AccountID, req.UserID) @@ -453,7 +453,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserA } // Delete removes a user account from the database. -func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountDeleteRequest) error { +func (repo *Repository) Delete(ctx context.Context, claims auth.Claims, req UserAccountDeleteRequest) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Delete") defer span.Finish() @@ -465,7 +465,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc } // Ensure the claims can modify the user specified in the request. - err = CanModifyAccount(ctx, claims, dbConn, req.AccountID) + err = repo.CanModifyAccount(ctx, claims, req.AccountID) if err != nil { return err } @@ -480,8 +480,8 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "delete account %s for user %s failed", req.AccountID, req.UserID) @@ -509,6 +509,10 @@ func MockUserAccount(ctx context.Context, dbConn *sqlx.DB, now time.Time, roles return nil, err } + repo := &Repository{ + DbConn: dbConn, + } + status := UserAccountStatus_Active req := UserAccountCreateRequest{ @@ -517,7 +521,7 @@ func MockUserAccount(ctx context.Context, dbConn *sqlx.DB, now time.Time, roles Status: &status, Roles: roles, } - ua, err := Create(ctx, auth.Claims{}, dbConn, req, now) + ua, err := repo.Create(ctx, auth.Claims{}, req, now) if err != nil { return nil, err } diff --git a/internal/user_account/user_account_test.go b/internal/user_account/user_account_test.go index 2728273..eb88466 100644 --- a/internal/user_account/user_account_test.go +++ b/internal/user_account/user_account_test.go @@ -1,7 +1,6 @@ package user_account import ( - "github.com/lib/pq" "math/rand" "os" "strings" @@ -13,6 +12,7 @@ import ( "github.com/dgrijalva/jwt-go" "github.com/google/go-cmp/cmp" "github.com/huandu/go-sqlbuilder" + "github.com/lib/pq" "github.com/pborman/uuid" "github.com/pkg/errors" ) @@ -232,7 +232,7 @@ func TestCreateValidation(t *testing.T) { t.Fatalf("\t%s\tMock account failed.", tests.Failed) } - res, err := Create(ctx, auth.Claims{}, test.MasterDB, tt.req, now) + res, err := repo.Create(ctx, auth.Claims{}, tt.req, now) if err != tt.error { // TODO: need a better way to handle validation errors as they are // of type interface validator.ValidationErrorsTranslations @@ -300,7 +300,7 @@ func TestCreateExistingEntry(t *testing.T) { AccountID: accountID, Roles: []UserAccountRole{UserAccountRole_User}, } - ua1, err := Create(ctx, auth.Claims{}, test.MasterDB, req1, now) + ua1, err := repo.Create(ctx, auth.Claims{}, req1, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tCreate user account failed.", tests.Failed) @@ -313,7 +313,7 @@ func TestCreateExistingEntry(t *testing.T) { AccountID: req1.AccountID, Roles: []UserAccountRole{UserAccountRole_Admin}, } - ua2, err := Create(ctx, auth.Claims{}, test.MasterDB, req2, now) + ua2, err := repo.Create(ctx, auth.Claims{}, req2, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tCreate user account failed.", tests.Failed) @@ -322,7 +322,7 @@ func TestCreateExistingEntry(t *testing.T) { } // Now archive the user account to test trying to create a new entry for an archived entry - err = Archive(tests.Context(), auth.Claims{}, test.MasterDB, UserAccountArchiveRequest{ + err = repo.Archive(tests.Context(), auth.Claims{}, UserAccountArchiveRequest{ UserID: req1.UserID, AccountID: req1.AccountID, }, now) @@ -332,7 +332,7 @@ func TestCreateExistingEntry(t *testing.T) { } // Find the archived user account - arcRes, err := Read(tests.Context(), auth.Claims{}, test.MasterDB, + arcRes, err := repo.Read(tests.Context(), auth.Claims{}, UserAccountReadRequest{UserID: req1.UserID, AccountID: req1.AccountID, IncludeArchived: true}) if err != nil || arcRes == nil { t.Log("\t\tGot :", err) @@ -347,7 +347,7 @@ func TestCreateExistingEntry(t *testing.T) { AccountID: req1.AccountID, Roles: []UserAccountRole{UserAccountRole_User}, } - ua3, err := Create(ctx, auth.Claims{}, test.MasterDB, req3, now) + ua3, err := repo.Create(ctx, auth.Claims{}, req3, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tCreate user account failed.", tests.Failed) @@ -356,7 +356,7 @@ func TestCreateExistingEntry(t *testing.T) { } // Ensure the user account has archived_at empty - findRes, err := Read(tests.Context(), auth.Claims{}, test.MasterDB, + findRes, err := repo.Read(tests.Context(), auth.Claims{}, UserAccountReadRequest{UserID: req1.UserID, AccountID: req1.AccountID}) if err != nil || arcRes == nil { t.Log("\t\tGot :", err) @@ -414,7 +414,7 @@ func TestUpdateValidation(t *testing.T) { { ctx := tests.Context() - err := Update(ctx, auth.Claims{}, test.MasterDB, tt.req, now) + err := repo.Update(ctx, auth.Claims{}, tt.req, now) if err != tt.error { // TODO: need a better way to handle validation errors as they are // of type interface validator.ValidationErrorsTranslations @@ -564,7 +564,7 @@ func TestCrud(t *testing.T) { AccountID: accountID, Roles: []UserAccountRole{UserAccountRole_User}, } - ua, err := Create(tests.Context(), tt.claims(userID, accountID), test.MasterDB, createReq, now) + ua, err := repo.Create(tests.Context(), tt.claims(userID, accountID), createReq, now) if err != nil && errors.Cause(err) != tt.createErr { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.createErr) @@ -577,7 +577,7 @@ func TestCrud(t *testing.T) { } if tt.createErr == ErrForbidden { - ua, err = Create(tests.Context(), auth.Claims{}, test.MasterDB, createReq, now) + ua, err = repo.Create(tests.Context(), auth.Claims{}, createReq, now) if err != nil && errors.Cause(err) != tt.createErr { t.Logf("\t\tGot : %+v", err) t.Fatalf("\t%s\tCreate user account failed.", tests.Failed) @@ -590,7 +590,7 @@ func TestCrud(t *testing.T) { AccountID: accountID, Roles: &UserAccountRoles{UserAccountRole_Admin}, } - err = Update(tests.Context(), tt.claims(userID, accountID), test.MasterDB, updateReq, now) + err = repo.Update(tests.Context(), tt.claims(userID, accountID), updateReq, now) if err != nil { if errors.Cause(err) != tt.updateErr { t.Logf("\t\tGot : %+v", err) @@ -604,7 +604,7 @@ func TestCrud(t *testing.T) { // Find the account for the user to verify the updates where made. There should only // be one account associated with the user for this test. - findRes, err := Find(tests.Context(), tt.claims(userID, accountID), test.MasterDB, UserAccountFindRequest{ + findRes, err := repo.Find(tests.Context(), tt.claims(userID, accountID), UserAccountFindRequest{ Where: "user_id = ? or account_id = ?", Args: []interface{}{userID, accountID}, Order: []string{"created_at"}, @@ -632,7 +632,7 @@ func TestCrud(t *testing.T) { } // Archive (soft-delete) the user account. - err = Archive(tests.Context(), tt.claims(userID, accountID), test.MasterDB, UserAccountArchiveRequest{ + err = repo.Archive(tests.Context(), tt.claims(userID, accountID), UserAccountArchiveRequest{ UserID: userID, AccountID: accountID, }, now) @@ -642,7 +642,7 @@ func TestCrud(t *testing.T) { t.Fatalf("\t%s\tArchive user account failed.", tests.Failed) } else if tt.updateErr == nil { // Trying to find the archived user with the includeArchived false should result in not found. - _, err = FindByUserID(tests.Context(), tt.claims(userID, accountID), test.MasterDB, userID, false) + _, err = repo.FindByUserID(tests.Context(), tt.claims(userID, accountID), userID, false) if errors.Cause(err) != ErrNotFound { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", ErrNotFound) @@ -650,7 +650,7 @@ func TestCrud(t *testing.T) { } // Trying to find the archived user with the includeArchived true should result no error. - findRes, err = FindByUserID(tests.Context(), tt.claims(userID, accountID), test.MasterDB, userID, true) + findRes, err = repo.FindByUserID(tests.Context(), tt.claims(userID, accountID), userID, true) if err != nil { t.Logf("\t\tGot : %+v", err) t.Fatalf("\t%s\tVerify archive user account failed when including archived.", tests.Failed) @@ -675,7 +675,7 @@ func TestCrud(t *testing.T) { t.Logf("\t%s\tArchive user account ok.", tests.Success) // Delete (hard-delete) the user account. - err = Delete(tests.Context(), tt.claims(userID, accountID), test.MasterDB, UserAccountDeleteRequest{ + err = repo.Delete(tests.Context(), tt.claims(userID, accountID), UserAccountDeleteRequest{ UserID: userID, AccountID: accountID, }) @@ -685,7 +685,7 @@ func TestCrud(t *testing.T) { t.Fatalf("\t%s\tDelete user account failed.", tests.Failed) } else if tt.updateErr == nil { // Trying to find the deleted user with the includeArchived true should result in not found. - _, err = FindByUserID(tests.Context(), tt.claims(userID, accountID), test.MasterDB, userID, true) + _, err = repo.FindByUserID(tests.Context(), tt.claims(userID, accountID), userID, true) if errors.Cause(err) != ErrNotFound { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", ErrNotFound) @@ -725,7 +725,7 @@ func TestFind(t *testing.T) { } // Execute Create that will associate the user with the account. - ua, err := Create(tests.Context(), auth.Claims{}, test.MasterDB, UserAccountCreateRequest{ + ua, err := repo.Create(tests.Context(), auth.Claims{}, UserAccountCreateRequest{ UserID: userID, AccountID: accountID, Roles: []UserAccountRole{UserAccountRole_User}, @@ -836,7 +836,7 @@ func TestFind(t *testing.T) { { ctx := tests.Context() - res, err := Find(ctx, auth.Claims{}, test.MasterDB, tt.req) + res, err := repo.Find(ctx, auth.Claims{}, tt.req) if errors.Cause(err) != tt.error { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.error) diff --git a/internal/user_auth/auth.go b/internal/user_auth/auth.go index 34bf1ca..4e90b10 100644 --- a/internal/user_auth/auth.go +++ b/internal/user_auth/auth.go @@ -3,7 +3,6 @@ package user_auth import ( "context" "database/sql" - "geeks-accelerator/oss/saas-starter-kit/internal/user_account" "strings" "time" @@ -11,8 +10,8 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/user" + "geeks-accelerator/oss/saas-starter-kit/internal/user_account" "github.com/huandu/go-sqlbuilder" - "github.com/jmoiron/sqlx" "github.com/lib/pq" "github.com/pkg/errors" "golang.org/x/crypto/bcrypt" @@ -40,7 +39,7 @@ const ( // Authenticate finds a user by their email and verifies their password. On success // it returns a Token that can be used to authenticate access to the application in // the future. -func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, req AuthenticateRequest, expires time.Duration, now time.Time, scopes ...string) (Token, error) { +func (repo *Repository) Authenticate(ctx context.Context, req AuthenticateRequest, expires time.Duration, now time.Time, scopes ...string) (Token, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.Authenticate") defer span.Finish() @@ -51,7 +50,7 @@ func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, r return Token{}, err } - u, err := user.ReadByEmail(ctx, auth.Claims{}, dbConn, req.Email, false) + u, err := repo.User.ReadByEmail(ctx, auth.Claims{}, req.Email, false) if err != nil { if errors.Cause(err) == user.ErrNotFound { err = errors.WithStack(ErrAuthenticationFailure) @@ -73,11 +72,11 @@ func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, r } // The user is successfully authenticated with the supplied email and password. - return generateToken(ctx, dbConn, tknGen, auth.Claims{}, u.ID, req.AccountID, expires, now, scopes...) + return repo.generateToken(ctx, auth.Claims{}, u.ID, req.AccountID, expires, now, scopes...) } // SwitchAccount allows users to switch between multiple accounts, this changes the claim audience. -func SwitchAccount(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, req SwitchAccountRequest, expires time.Duration, now time.Time, scopes ...string) (Token, error) { +func (repo *Repository) SwitchAccount(ctx context.Context, claims auth.Claims, req SwitchAccountRequest, expires time.Duration, now time.Time, scopes ...string) (Token, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.SwitchAccount") defer span.Finish() @@ -97,11 +96,11 @@ func SwitchAccount(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, // Generate a token for the user ID in supplied in claims as the Subject. Pass // in the supplied claims as well to enforce ACLs when finding the current // list of accounts for the user. - return generateToken(ctx, dbConn, tknGen, claims, claims.Subject, req.AccountID, expires, now, scopes...) + return repo.generateToken(ctx, claims, claims.Subject, req.AccountID, expires, now, scopes...) } // VirtualLogin allows users to mock being logged in as other users. -func VirtualLogin(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, req VirtualLoginRequest, expires time.Duration, now time.Time, scopes ...string) (Token, error) { +func (repo *Repository) VirtualLogin(ctx context.Context, claims auth.Claims, req VirtualLoginRequest, expires time.Duration, now time.Time, scopes ...string) (Token, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.VirtualLogin") defer span.Finish() @@ -113,7 +112,7 @@ func VirtualLogin(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, c } // Find all the accounts that the current user has access to. - usrAccs, err := user_account.FindByUserID(ctx, claims, dbConn, claims.Subject, false) + usrAccs, err := repo.UserAccount.FindByUserID(ctx, claims, claims.Subject, false) if err != nil { return Token{}, err } @@ -142,23 +141,23 @@ func VirtualLogin(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, c // Generate a token for the user ID in supplied in claims as the Subject. Pass // in the supplied claims as well to enforce ACLs when finding the current // list of accounts for the user. - return generateToken(ctx, dbConn, tknGen, claims, req.UserID, req.AccountID, expires, now, scopes...) + return repo.generateToken(ctx, claims, req.UserID, req.AccountID, expires, now, scopes...) } // VirtualLogout allows switch back to their root user/account. -func VirtualLogout(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, expires time.Duration, now time.Time, scopes ...string) (Token, error) { +func (repo *Repository) VirtualLogout(ctx context.Context, claims auth.Claims, expires time.Duration, now time.Time, scopes ...string) (Token, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.VirtualLogout") defer span.Finish() // Generate a token for the user ID in supplied in claims as the Subject. Pass // in the supplied claims as well to enforce ACLs when finding the current // list of accounts for the user. - return generateToken(ctx, dbConn, tknGen, claims, claims.RootUserID, claims.RootAccountID, expires, now, scopes...) + return repo.generateToken(ctx, claims, claims.RootUserID, claims.RootAccountID, expires, now, scopes...) } // generateToken generates claims for the supplied user ID and account ID and then // returns the token for the generated claims used for authentication. -func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, userID, accountID string, expires time.Duration, now time.Time, scopes ...string) (Token, error) { +func (repo *Repository) generateToken(ctx context.Context, claims auth.Claims, userID, accountID string, expires time.Duration, now time.Time, scopes ...string) (Token, error) { type userAccount struct { AccountID string @@ -184,8 +183,8 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, // fetch all places from the db queryStr, queryArgs := query.Build() - queryStr = dbConn.Rebind(queryStr) - rows, err := dbConn.QueryContext(ctx, queryStr, queryArgs...) + queryStr = repo.DbConn.Rebind(queryStr) + rows, err := repo.DbConn.QueryContext(ctx, queryStr, queryArgs...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) return nil, err @@ -339,7 +338,7 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, tz, _ = time.LoadLocation(account.AccountTimezone.String) } - prefs, err := account_preference.FindByAccountID(ctx, auth.Claims{}, dbConn, account_preference.AccountPreferenceFindByAccountIDRequest{ + prefs, err := repo.AccountPreference.FindByAccountID(ctx, auth.Claims{}, account_preference.AccountPreferenceFindByAccountIDRequest{ AccountID: accountID, }) if err != nil { @@ -393,7 +392,7 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, newClaims.RootUserID = claims.RootUserID // Generate a token for the user with the defined claims. - tknStr, err := tknGen.GenerateToken(newClaims) + tknStr, err := repo.TknGen.GenerateToken(newClaims) if err != nil { return Token{}, errors.Wrap(err, "generating token") } diff --git a/internal/user_auth/auth_test.go b/internal/user_auth/auth_test.go index db837b6..e7fb61a 100644 --- a/internal/user_auth/auth_test.go +++ b/internal/user_auth/auth_test.go @@ -8,8 +8,8 @@ import ( "time" "geeks-accelerator/oss/saas-starter-kit/internal/account" + "geeks-accelerator/oss/saas-starter-kit/internal/account/account_preference" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" - "geeks-accelerator/oss/saas-starter-kit/internal/platform/notify" "geeks-accelerator/oss/saas-starter-kit/internal/platform/tests" "geeks-accelerator/oss/saas-starter-kit/internal/user" "geeks-accelerator/oss/saas-starter-kit/internal/user_account" @@ -18,7 +18,10 @@ import ( "github.com/pkg/errors" ) -var test *tests.Test +var ( + test *tests.Test + repo *Repository +) // TestMain is the entry point for testing. func TestMain(m *testing.M) { @@ -28,6 +31,15 @@ func TestMain(m *testing.M) { func testMain(m *testing.M) int { test = tests.New() defer test.TearDown() + + tknGen := &auth.MockTokenGenerator{} + + userRepo := user.MockRepository(test.MasterDB) + userAccRepo := user_account.NewRepository(test.MasterDB) + accPrefRepo := account_preference.NewRepository(test.MasterDB) + + repo = NewRepository(test.MasterDB, tknGen, userRepo, userAccRepo, accPrefRepo) + return m.Run() } @@ -41,14 +53,12 @@ func TestAuthenticate(t *testing.T) { { ctx := tests.Context() - tknGen := &auth.MockTokenGenerator{} - // Auth tokens are valid for an our and is verified against current time. // Issue the token one hour ago. now := time.Now().Add(time.Hour * -1) // Try to authenticate an invalid user. - _, err := Authenticate(ctx, test.MasterDB, tknGen, + _, err := repo.Authenticate(ctx, AuthenticateRequest{ Email: "doesnotexist@gmail.com", Password: "xy7", @@ -82,7 +92,7 @@ func TestAuthenticate(t *testing.T) { // is always greater than the first user_account entry created so it will // be returned consistently back in the same order, last. account2Role := auth.RoleUser - _, err = user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + _, err = repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{ UserID: usrAcc.UserID, AccountID: acc2.ID, Roles: []user_account.UserAccountRole{user_account.UserAccountRole(account2Role)}, @@ -92,7 +102,7 @@ func TestAuthenticate(t *testing.T) { now = now.Add(time.Minute * 5) // Try to authenticate valid user with invalid password. - _, err = Authenticate(ctx, test.MasterDB, tknGen, + _, err = repo.Authenticate(ctx, AuthenticateRequest{ Email: usrAcc.User.Email, Password: "xy7", @@ -106,7 +116,7 @@ func TestAuthenticate(t *testing.T) { t.Logf("\t%s\tAuthenticate user w/invalid password ok.", tests.Success) // Verify that the user can be authenticated with the created user. - tkn1, err := Authenticate(ctx, test.MasterDB, tknGen, + tkn1, err := repo.Authenticate(ctx, AuthenticateRequest{ Email: usrAcc.User.Email, Password: usrAcc.User.Password, @@ -118,7 +128,7 @@ func TestAuthenticate(t *testing.T) { t.Logf("\t%s\tAuthenticate user ok.", tests.Success) // Ensure the token string was correctly generated. - claims1, err := tknGen.ParseClaims(tkn1.AccessToken) + claims1, err := repo.TknGen.ParseClaims(tkn1.AccessToken) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) @@ -135,7 +145,7 @@ func TestAuthenticate(t *testing.T) { t.Logf("\t%s\tAuthenticate parse claims from token ok.", tests.Success) // Try switching to a second account using the first set of claims. - tkn2, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims1, + tkn2, err := repo.SwitchAccount(ctx, claims1, SwitchAccountRequest{AccountID: acc2.ID}, time.Hour, now) if err != nil { t.Log("\t\tGot :", err) @@ -144,7 +154,7 @@ func TestAuthenticate(t *testing.T) { t.Logf("\t%s\tSwitchAccount user ok.", tests.Success) // Ensure the token string was correctly generated. - claims2, err := tknGen.ParseClaims(tkn2.AccessToken) + claims2, err := repo.TknGen.ParseClaims(tkn2.AccessToken) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) @@ -172,8 +182,6 @@ func TestUserUpdatePassword(t *testing.T) { now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC) - tknGen := &auth.MockTokenGenerator{} - // Create a new user for testing. usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_User) if err != nil { @@ -183,7 +191,7 @@ func TestUserUpdatePassword(t *testing.T) { t.Logf("\t%s\tCreate user account ok.", tests.Success) // Verify that the user can be authenticated with the created user. - _, err = Authenticate(ctx, test.MasterDB, tknGen, + _, err = repo.Authenticate(ctx, AuthenticateRequest{ Email: usrAcc.User.Email, Password: usrAcc.User.Password, @@ -195,7 +203,7 @@ func TestUserUpdatePassword(t *testing.T) { // Update the users password. newPass := uuid.NewRandom().String() - err = user.UpdatePassword(ctx, auth.Claims{}, test.MasterDB, user.UserUpdatePasswordRequest{ + err = repo.User.UpdatePassword(ctx, auth.Claims{}, user.UserUpdatePasswordRequest{ ID: usrAcc.UserID, Password: newPass, PasswordConfirm: newPass, @@ -207,7 +215,7 @@ func TestUserUpdatePassword(t *testing.T) { t.Logf("\t%s\tUpdatePassword ok.", tests.Success) // Verify that the user can be authenticated with the updated password. - _, err = Authenticate(ctx, test.MasterDB, tknGen, + _, err = repo.Authenticate(ctx, AuthenticateRequest{ Email: usrAcc.User.Email, Password: newPass, @@ -229,8 +237,6 @@ func TestUserResetPassword(t *testing.T) { now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC) - tknGen := &auth.MockTokenGenerator{} - // Create a new user for testing. usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_User) if err != nil { @@ -239,21 +245,13 @@ func TestUserResetPassword(t *testing.T) { } t.Logf("\t%s\tCreate user account ok.", tests.Success) - // Mock the methods needed to make a password reset. - resetUrl := func(string) string { - return "" - } - notify := ¬ify.MockEmail{} - - secretKey := "6368616e676520746869732070617373" - ttl := time.Hour // Make the reset password request. - resetHash, err := user.ResetPassword(ctx, test.MasterDB, resetUrl, notify, user.UserResetPasswordRequest{ + resetHash, err := repo.User.ResetPassword(ctx, user.UserResetPasswordRequest{ Email: usrAcc.User.Email, TTL: ttl, - }, secretKey, now) + }, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tResetPassword failed.", tests.Failed) @@ -262,11 +260,11 @@ func TestUserResetPassword(t *testing.T) { // Assuming we have received the email and clicked the link, we now can ensure confirm works. newPass := uuid.NewRandom().String() - reset, err := user.ResetConfirm(ctx, test.MasterDB, user.UserResetConfirmRequest{ + reset, err := repo.User.ResetConfirm(ctx, user.UserResetConfirmRequest{ ResetHash: resetHash, Password: newPass, PasswordConfirm: newPass, - }, secretKey, now) + }, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tResetConfirm failed.", tests.Failed) @@ -278,7 +276,7 @@ func TestUserResetPassword(t *testing.T) { t.Logf("\t%s\tResetConfirm ok.", tests.Success) // Verify that the user can be authenticated with the updated password. - _, err = Authenticate(ctx, test.MasterDB, tknGen, + _, err = repo.Authenticate(ctx, AuthenticateRequest{ Email: usrAcc.User.Email, Password: newPass, @@ -340,7 +338,7 @@ func TestSwitchAccount(t *testing.T) { } // Associate the second account with root user. - usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + usrAcc2, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{ UserID: usrAcc.UserID, AccountID: acc2.ID, Roles: []user_account.UserAccountRole{user_account.UserAccountRole(roles[1])}, @@ -359,7 +357,7 @@ func TestSwitchAccount(t *testing.T) { } // Associate the third account with root user. - usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + usrAcc3, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{ UserID: usrAcc.UserID, AccountID: acc3.ID, Roles: []user_account.UserAccountRole{user_account.UserAccountRole(roles[2])}, @@ -426,7 +424,7 @@ func TestSwitchAccount(t *testing.T) { } // Associate the second account with root user. - usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + usrAcc2, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{ UserID: usrAcc.UserID, AccountID: acc2.ID, Roles: []user_account.UserAccountRole{user_account.UserAccountRole_Admin}, @@ -445,7 +443,7 @@ func TestSwitchAccount(t *testing.T) { } // Associate the third account with root user. - usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + usrAcc3, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{ UserID: usrAcc.UserID, AccountID: acc3.ID, Roles: []user_account.UserAccountRole{user_account.UserAccountRole_User}, @@ -472,8 +470,6 @@ func TestSwitchAccount(t *testing.T) { // Add 30 minutes to now to simulate time passing. now = now.Add(time.Minute * 5) - tknGen := &auth.MockTokenGenerator{} - t.Log("Given the need to switch accounts.") { for i, authTest := range authTests { @@ -481,7 +477,7 @@ func TestSwitchAccount(t *testing.T) { { // Verify that the user can be authenticated with the created user. var claims1 auth.Claims - tkn1, err := Authenticate(ctx, test.MasterDB, tknGen, + tkn1, err := repo.Authenticate(ctx, AuthenticateRequest{ Email: authTest.root.User.Email, Password: authTest.root.User.Password, @@ -491,7 +487,7 @@ func TestSwitchAccount(t *testing.T) { t.Fatalf("\t%s\tAuthenticate user failed.", tests.Failed) } else { // Ensure the token string was correctly generated. - claims1, err = tknGen.ParseClaims(tkn1.AccessToken) + claims1, err = repo.TknGen.ParseClaims(tkn1.AccessToken) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) @@ -511,7 +507,7 @@ func TestSwitchAccount(t *testing.T) { // Try to switch to account 2. var claims2 auth.Claims - tkn2, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims1, authTest.switch1Req, time.Hour, now, authTest.switch1Scopes...) + tkn2, err := repo.SwitchAccount(ctx, claims1, authTest.switch1Req, time.Hour, now, authTest.switch1Scopes...) if err != authTest.switch1Err { if errors.Cause(err) != authTest.switch1Err { t.Log("\t\tExpected :", authTest.switch1Err) @@ -520,7 +516,7 @@ func TestSwitchAccount(t *testing.T) { } } else { // Ensure the token string was correctly generated. - claims2, err = tknGen.ParseClaims(tkn2.AccessToken) + claims2, err = repo.TknGen.ParseClaims(tkn2.AccessToken) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) @@ -549,7 +545,7 @@ func TestSwitchAccount(t *testing.T) { } // Try to switch to account 3. - tkn3, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims2, authTest.switch2Req, time.Hour, now, authTest.switch2Scopes...) + tkn3, err := repo.SwitchAccount(ctx, claims2, authTest.switch2Req, time.Hour, now, authTest.switch2Scopes...) if err != authTest.switch2Err { if errors.Cause(err) != authTest.switch2Err { t.Log("\t\tExpected :", authTest.switch2Err) @@ -558,7 +554,7 @@ func TestSwitchAccount(t *testing.T) { } } else { // Ensure the token string was correctly generated. - claims3, err := tknGen.ParseClaims(tkn3.AccessToken) + claims3, err := repo.TknGen.ParseClaims(tkn3.AccessToken) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) @@ -610,7 +606,7 @@ func TestVirtualLogin(t *testing.T) { var authTests []authTest // Root admin -> role admin -> role admin - if true { + { // Create a new user for testing. usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_Admin) if err != nil { @@ -625,7 +621,7 @@ func TestVirtualLogin(t *testing.T) { } // Associate second user with basic role associated with the same account. - usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + usrAcc2, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{ UserID: usr2.ID, AccountID: usrAcc.AccountID, Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)}, @@ -642,7 +638,7 @@ func TestVirtualLogin(t *testing.T) { } // Associate second user with basic role associated with the same account. - usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + usrAcc3, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{ UserID: usr3.ID, AccountID: usrAcc.AccountID, Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)}, @@ -687,7 +683,7 @@ func TestVirtualLogin(t *testing.T) { } // Associate second user with basic role associated with the same account. - usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + usrAcc2, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{ UserID: usr2.ID, AccountID: usrAcc.AccountID, Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)}, @@ -704,7 +700,7 @@ func TestVirtualLogin(t *testing.T) { } // Associate second user with basic role associated with the same account. - usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + usrAcc3, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{ UserID: usr3.ID, AccountID: usrAcc.AccountID, Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_User)}, @@ -749,7 +745,7 @@ func TestVirtualLogin(t *testing.T) { } // Associate second user with basic role associated with the same account. - usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + usrAcc2, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{ UserID: usr2.ID, AccountID: usrAcc.AccountID, Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_User)}, @@ -766,7 +762,7 @@ func TestVirtualLogin(t *testing.T) { } // Associate second user with basic role associated with the same account. - usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + usrAcc3, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{ UserID: usr3.ID, AccountID: usrAcc.AccountID, Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)}, @@ -811,7 +807,7 @@ func TestVirtualLogin(t *testing.T) { } // Associate second user with basic role associated with the same account. - usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + usrAcc2, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{ UserID: usr2.ID, AccountID: usrAcc.AccountID, Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)}, @@ -850,7 +846,7 @@ func TestVirtualLogin(t *testing.T) { } // Associate second user with basic role associated with the same account. - usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + usrAcc2, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{ UserID: usr2.ID, AccountID: usrAcc.AccountID, Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_User)}, @@ -876,8 +872,6 @@ func TestVirtualLogin(t *testing.T) { // Add 30 minutes to now to simulate time passing. now = now.Add(time.Minute * 5) - tknGen := &auth.MockTokenGenerator{} - t.Log("Given the need to virtual login.") { for i, authTest := range authTests { @@ -885,7 +879,7 @@ func TestVirtualLogin(t *testing.T) { { // Verify that the user can be authenticated with the created user. var claims1 auth.Claims - tkn1, err := Authenticate(ctx, test.MasterDB, tknGen, + tkn1, err := repo.Authenticate(ctx, AuthenticateRequest{ Email: authTest.root.User.Email, Password: authTest.root.User.Password, @@ -895,7 +889,7 @@ func TestVirtualLogin(t *testing.T) { t.Fatalf("\t%s\tAuthenticate user failed.", tests.Failed) } else { // Ensure the token string was correctly generated. - claims1, err = tknGen.ParseClaims(tkn1.AccessToken) + claims1, err = repo.TknGen.ParseClaims(tkn1.AccessToken) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) @@ -915,7 +909,7 @@ func TestVirtualLogin(t *testing.T) { // Try virtual login to user 2. var claims2 auth.Claims - tkn2, err := VirtualLogin(ctx, test.MasterDB, tknGen, claims1, authTest.login1Req, time.Hour, now) + tkn2, err := repo.VirtualLogin(ctx, claims1, authTest.login1Req, time.Hour, now) if err != authTest.login1Err { if errors.Cause(err) != authTest.login1Err { t.Log("\t\tExpected :", authTest.login1Err) @@ -924,7 +918,7 @@ func TestVirtualLogin(t *testing.T) { } } else { // Ensure the token string was correctly generated. - claims2, err = tknGen.ParseClaims(tkn2.AccessToken) + claims2, err = repo.TknGen.ParseClaims(tkn2.AccessToken) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) @@ -948,7 +942,7 @@ func TestVirtualLogin(t *testing.T) { } // Try virtual login to user 3. - tkn3, err := VirtualLogin(ctx, test.MasterDB, tknGen, claims2, authTest.login2Req, time.Hour, now) + tkn3, err := repo.VirtualLogin(ctx, claims2, authTest.login2Req, time.Hour, now) if err != authTest.login2Err { if errors.Cause(err) != authTest.login2Err { t.Log("\t\tExpected :", authTest.login2Err) @@ -957,7 +951,7 @@ func TestVirtualLogin(t *testing.T) { } } else { // Ensure the token string was correctly generated. - claims3, err := tknGen.ParseClaims(tkn3.AccessToken) + claims3, err := repo.TknGen.ParseClaims(tkn3.AccessToken) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) @@ -976,14 +970,14 @@ func TestVirtualLogin(t *testing.T) { t.Logf("\t%s\tVirtualLogin user 2 with role %s ok.", tests.Success, authTest.login2Role) if authTest.login2Logout { - tknOut, err := VirtualLogout(ctx, test.MasterDB, tknGen, claims2, time.Hour, now) + tknOut, err := repo.VirtualLogout(ctx, claims2, time.Hour, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tVirtualLogout user 2 failed.", tests.Failed) } // Ensure the token string was correctly generated. - claimsOut, err := tknGen.ParseClaims(tknOut.AccessToken) + claimsOut, err := repo.TknGen.ParseClaims(tknOut.AccessToken) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) diff --git a/internal/user_auth/models.go b/internal/user_auth/models.go index 5990253..12e41d2 100644 --- a/internal/user_auth/models.go +++ b/internal/user_auth/models.go @@ -3,9 +3,33 @@ package user_auth import ( "time" + "geeks-accelerator/oss/saas-starter-kit/internal/account/account_preference" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" + "geeks-accelerator/oss/saas-starter-kit/internal/user" + "geeks-accelerator/oss/saas-starter-kit/internal/user_account" + "github.com/jmoiron/sqlx" ) +// Repository defines the required dependencies for User Auth. +type Repository struct { + DbConn *sqlx.DB + TknGen TokenGenerator + User *user.Repository + UserAccount *user_account.Repository + AccountPreference *account_preference.Repository +} + +// NewRepository creates a new Repository that defines dependencies for User Auth. +func NewRepository(db *sqlx.DB, tknGen TokenGenerator, user *user.Repository, usrAcc *user_account.Repository, accPref *account_preference.Repository) *Repository { + return &Repository{ + DbConn: db, + TknGen: tknGen, + User: user, + UserAccount: usrAcc, + AccountPreference: accPref, + } +} + // AuthenticateRequest defines what information is required to authenticate a user. type AuthenticateRequest struct { Email string `json:"email" validate:"required,email" example:"gabi.may@geeksinthewoods.com"`