1
0
mirror of https://github.com/raseels-repos/golang-saas-starter-kit.git synced 2025-06-17 00:17:59 +02:00

Completed updating biz logic packages to use repository pattern

This commit is contained in:
Lee Brown
2019-08-14 11:40:26 -08:00
parent 3bc814a01e
commit e45dd56149
25 changed files with 530 additions and 353 deletions

View File

@ -64,6 +64,11 @@ func CanReadAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, ac
return nil return nil
} }
// CanReadAccount determines if claims has the authority to access the specified account ID.
func (repo *Repository) CanReadAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, accountID string) error {
return repo.CanReadAccount(ctx, claims, repo.DbConn, accountID)
}
// CanModifyAccount determines if claims has the authority to modify the specified account ID. // CanModifyAccount determines if claims has the authority to modify the specified account ID.
func CanModifyAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, accountID string) error { func CanModifyAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, accountID string) error {
// If the request has claims from a specific account, ensure that the claims // If the request has claims from a specific account, ensure that the claims
@ -105,6 +110,11 @@ func CanModifyAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB,
return nil return nil
} }
// CanModifyAccount determines if claims has the authority to modify the specified account ID.
func (repo *Repository) CanModifyAccount(ctx context.Context, claims auth.Claims, accountID string) error {
return CanModifyAccount(ctx, claims, repo.DbConn, accountID)
}
// applyClaimsSelect applies a sub-query to the provided query to enforce ACL based on // applyClaimsSelect applies a sub-query to the provided query to enforce ACL based on
// the claims provided. // the claims provided.
// 1. All role types can access their user ID // 1. All role types can access their user ID
@ -150,7 +160,7 @@ func selectQuery() *sqlbuilder.SelectBuilder {
// Find gets all the accounts from the database based on the request params. // Find gets all the accounts from the database based on the request params.
// TODO: Need to figure out why can't parse the args when appending the where // TODO: Need to figure out why can't parse the args when appending the where
// to the query. // to the query.
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountFindRequest) (Accounts, error) { func (repo *Repository) Find(ctx context.Context, claims auth.Claims, req AccountFindRequest) (Accounts, error) {
query := selectQuery() query := selectQuery()
if req.Where != "" { if req.Where != "" {
@ -166,7 +176,7 @@ func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountF
query.Offset(int(*req.Offset)) query.Offset(int(*req.Offset))
} }
return find(ctx, claims, dbConn, query, req.Args, req.IncludeArchived) return find(ctx, claims, repo.DbConn, query, req.Args, req.IncludeArchived)
} }
// find internal method for getting all the accounts from the database using a select query. // find internal method for getting all the accounts from the database using a select query.
@ -242,14 +252,14 @@ func UniqueName(ctx context.Context, dbConn *sqlx.DB, name, accountId string) (b
} }
// Create inserts a new account into the database. // Create inserts a new account into the database.
func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountCreateRequest, now time.Time) (*Account, error) { func (repo *Repository) Create(ctx context.Context, claims auth.Claims, req AccountCreateRequest, now time.Time) (*Account, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Create") span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Create")
defer span.Finish() defer span.Finish()
v := webcontext.Validator() v := webcontext.Validator()
// Validation account name is unique in the database. // Validation account name is unique in the database.
uniq, err := UniqueName(ctx, dbConn, req.Name, "") uniq, err := UniqueName(ctx, repo.DbConn, req.Name, "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -310,8 +320,8 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun
// Execute the query with the provided context. // Execute the query with the provided context.
sql, args := query.Build() sql, args := query.Build()
sql = dbConn.Rebind(sql) sql = repo.DbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...) _, err = repo.DbConn.ExecContext(ctx, sql, args...)
if err != nil { if err != nil {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "create account failed") err = errors.WithMessage(err, "create account failed")
@ -322,15 +332,15 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun
} }
// ReadByID gets the specified user by ID from the database. // ReadByID gets the specified user by ID from the database.
func ReadByID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string) (*Account, error) { func (repo *Repository) ReadByID(ctx context.Context, claims auth.Claims, id string) (*Account, error) {
return Read(ctx, claims, dbConn, AccountReadRequest{ return repo.Read(ctx, claims, AccountReadRequest{
ID: id, ID: id,
IncludeArchived: false, IncludeArchived: false,
}) })
} }
// Read gets the specified account from the database. // Read gets the specified account from the database.
func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountReadRequest) (*Account, error) { func (repo *Repository) Read(ctx context.Context, claims auth.Claims, req AccountReadRequest) (*Account, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Read") span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Read")
defer span.Finish() defer span.Finish()
@ -345,7 +355,7 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountR
query := sqlbuilder.NewSelectBuilder() query := sqlbuilder.NewSelectBuilder()
query.Where(query.Equal("id", req.ID)) query.Where(query.Equal("id", req.ID))
res, err := find(ctx, claims, dbConn, query, []interface{}{}, req.IncludeArchived) res, err := find(ctx, claims, repo.DbConn, query, []interface{}{}, req.IncludeArchived)
if err != nil { if err != nil {
return nil, err return nil, err
} else if res == nil || len(res) == 0 { } else if res == nil || len(res) == 0 {
@ -358,7 +368,7 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountR
} }
// Update replaces an account in the database. // Update replaces an account in the database.
func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountUpdateRequest, now time.Time) error { func (repo *Repository) Update(ctx context.Context, claims auth.Claims, req AccountUpdateRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Update") span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Update")
defer span.Finish() defer span.Finish()
@ -366,7 +376,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun
if req.Name != nil { if req.Name != nil {
// Validation account name is unique in the database. // Validation account name is unique in the database.
uniq, err := UniqueName(ctx, dbConn, *req.Name, req.ID) uniq, err := UniqueName(ctx, repo.DbConn, *req.Name, req.ID)
if err != nil { if err != nil {
return err return err
} }
@ -382,7 +392,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun
} }
// Ensure the claims can modify the account specified in the request. // Ensure the claims can modify the account specified in the request.
err = CanModifyAccount(ctx, claims, dbConn, req.ID) err = CanModifyAccount(ctx, claims, repo.DbConn, req.ID)
if err != nil { if err != nil {
return err return err
} }
@ -460,8 +470,8 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun
// Execute the query with the provided context. // Execute the query with the provided context.
sql, args := query.Build() sql, args := query.Build()
sql = dbConn.Rebind(sql) sql = repo.DbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...) _, err = repo.DbConn.ExecContext(ctx, sql, args...)
if err != nil { if err != nil {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "update account %s failed", req.ID) err = errors.WithMessagef(err, "update account %s failed", req.ID)
@ -472,7 +482,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun
} }
// Archive soft deleted the account from the database. // Archive soft deleted the account from the database.
func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountArchiveRequest, now time.Time) error { func (repo *Repository) Archive(ctx context.Context, claims auth.Claims, req AccountArchiveRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Archive") span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Archive")
defer span.Finish() defer span.Finish()
@ -484,7 +494,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accou
} }
// Ensure the claims can modify the account specified in the request. // Ensure the claims can modify the account specified in the request.
err = CanModifyAccount(ctx, claims, dbConn, req.ID) err = CanModifyAccount(ctx, claims, repo.DbConn, req.ID)
if err != nil { if err != nil {
return err return err
} }
@ -511,8 +521,8 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accou
// Execute the query with the provided context. // Execute the query with the provided context.
sql, args := query.Build() sql, args := query.Build()
sql = dbConn.Rebind(sql) sql = repo.DbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...) _, err = repo.DbConn.ExecContext(ctx, sql, args...)
if err != nil { if err != nil {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "archive account %s failed", req.ID) err = errors.WithMessagef(err, "archive account %s failed", req.ID)
@ -531,8 +541,8 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accou
// Execute the query with the provided context. // Execute the query with the provided context.
sql, args := query.Build() sql, args := query.Build()
sql = dbConn.Rebind(sql) sql = repo.DbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...) _, err = repo.DbConn.ExecContext(ctx, sql, args...)
if err != nil { if err != nil {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "archive users for account %s failed", req.ID) err = errors.WithMessagef(err, "archive users for account %s failed", req.ID)
@ -544,7 +554,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accou
} }
// Delete removes an account from the database. // Delete removes an account from the database.
func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountDeleteRequest) error { func (repo *Repository) Delete(ctx context.Context, claims auth.Claims, req AccountDeleteRequest) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Delete") span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Delete")
defer span.Finish() defer span.Finish()
@ -556,13 +566,13 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun
} }
// Ensure the claims can modify the account specified in the request. // Ensure the claims can modify the account specified in the request.
err = CanModifyAccount(ctx, claims, dbConn, req.ID) err = CanModifyAccount(ctx, claims, repo.DbConn, req.ID)
if err != nil { if err != nil {
return err return err
} }
// Start a new transaction to handle rollbacks on error. // Start a new transaction to handle rollbacks on error.
tx, err := dbConn.Begin() tx, err := repo.DbConn.Begin()
if err != nil { if err != nil {
return errors.WithStack(err) return errors.WithStack(err)
} }
@ -579,7 +589,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun
// Execute the query with the provided context. // Execute the query with the provided context.
sql, args := query.Build() sql, args := query.Build()
sql = dbConn.Rebind(sql) sql = repo.DbConn.Rebind(sql)
_, err = tx.ExecContext(ctx, sql, args...) _, err = tx.ExecContext(ctx, sql, args...)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
@ -602,7 +612,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun
// Execute the query with the provided context. // Execute the query with the provided context.
sql, args := query.Build() sql, args := query.Build()
sql = dbConn.Rebind(sql) sql = repo.DbConn.Rebind(sql)
_, err = tx.ExecContext(ctx, sql, args...) _, err = tx.ExecContext(ctx, sql, args...)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
@ -620,7 +630,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun
// Execute the query with the provided context. // Execute the query with the provided context.
sql, args := query.Build() sql, args := query.Build()
sql = dbConn.Rebind(sql) sql = repo.DbConn.Rebind(sql)
_, err = tx.ExecContext(ctx, sql, args...) _, err = tx.ExecContext(ctx, sql, args...)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
@ -642,6 +652,10 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun
func MockAccount(ctx context.Context, dbConn *sqlx.DB, now time.Time) (*Account, error) { func MockAccount(ctx context.Context, dbConn *sqlx.DB, now time.Time) (*Account, error) {
s := AccountStatus_Active s := AccountStatus_Active
repo := &Repository{
DbConn: dbConn,
}
req := AccountCreateRequest{ req := AccountCreateRequest{
Name: uuid.NewRandom().String(), Name: uuid.NewRandom().String(),
Address1: "103 East Main St", Address1: "103 East Main St",
@ -652,5 +666,5 @@ func MockAccount(ctx context.Context, dbConn *sqlx.DB, now time.Time) (*Account,
Zipcode: "99686", Zipcode: "99686",
Status: &s, Status: &s,
} }
return Create(ctx, auth.Claims{}, dbConn, req, now) return repo.Create(ctx, auth.Claims{}, req, now)
} }

View File

@ -63,7 +63,7 @@ func applyClaimsSelect(ctx context.Context, claims auth.Claims, query *sqlbuilde
// Find gets all the account preferences from the database based on the request params. // Find gets all the account preferences from the database based on the request params.
// TODO: Need to figure out why can't parse the args when appending the where to the query. // TODO: Need to figure out why can't parse the args when appending the where to the query.
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPreferenceFindRequest) ([]*AccountPreference, error) { func (repo *Repository) Find(ctx context.Context, claims auth.Claims, req AccountPreferenceFindRequest) ([]*AccountPreference, error) {
query := sqlbuilder.NewSelectBuilder() query := sqlbuilder.NewSelectBuilder()
if req.Where != "" { if req.Where != "" {
query.Where(query.And(req.Where)) query.Where(query.And(req.Where))
@ -78,11 +78,11 @@ func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountP
query.Offset(int(*req.Offset)) query.Offset(int(*req.Offset))
} }
return find(ctx, claims, dbConn, query, req.Args, req.IncludeArchived) return find(ctx, claims, repo.DbConn, query, req.Args, req.IncludeArchived)
} }
// FindByAccountID gets the specified account preferences for an account from the database. // FindByAccountID gets the specified account preferences for an account from the database.
func FindByAccountID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPreferenceFindByAccountIDRequest) ([]*AccountPreference, error) { func (repo *Repository) FindByAccountID(ctx context.Context, claims auth.Claims, req AccountPreferenceFindByAccountIDRequest) ([]*AccountPreference, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.account_preference.FindByAccountID") span, ctx := tracer.StartSpanFromContext(ctx, "internal.account_preference.FindByAccountID")
defer span.Finish() defer span.Finish()
@ -106,7 +106,7 @@ func FindByAccountID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r
query.Offset(int(*req.Offset)) query.Offset(int(*req.Offset))
} }
return find(ctx, claims, dbConn, query, []interface{}{}, req.IncludeArchived) return find(ctx, claims, repo.DbConn, query, []interface{}{}, req.IncludeArchived)
} }
// find internal method for getting all the account preferences from the database using a select query. // find internal method for getting all the account preferences from the database using a select query.
@ -157,7 +157,7 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu
} }
// Read gets the specified account preference from the database. // Read gets the specified account preference from the database.
func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPreferenceReadRequest) (*AccountPreference, error) { func (repo *Repository) Read(ctx context.Context, claims auth.Claims, req AccountPreferenceReadRequest) (*AccountPreference, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.account_preference.Read") span, ctx := tracer.StartSpanFromContext(ctx, "internal.account_preference.Read")
defer span.Finish() defer span.Finish()
@ -173,7 +173,7 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountP
query.Equal("account_id", req.AccountID)), query.Equal("account_id", req.AccountID)),
query.Equal("name", req.Name)) query.Equal("name", req.Name))
res, err := find(ctx, claims, dbConn, query, []interface{}{}, req.IncludeArchived) res, err := find(ctx, claims, repo.DbConn, query, []interface{}{}, req.IncludeArchived)
if err != nil { if err != nil {
return nil, err return nil, err
} else if res == nil || len(res) == 0 { } else if res == nil || len(res) == 0 {
@ -263,7 +263,7 @@ func Validator() *validator.Validate {
} }
// Set inserts a new account preference or updates an existing on. // Set inserts a new account preference or updates an existing on.
func Set(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPreferenceSetRequest, now time.Time) error { func (repo *Repository) Set(ctx context.Context, claims auth.Claims, req AccountPreferenceSetRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.account_preference.Set") span, ctx := tracer.StartSpanFromContext(ctx, "internal.account_preference.Set")
defer span.Finish() defer span.Finish()
@ -276,7 +276,7 @@ func Set(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPr
} }
// Ensure the claims can modify the account specified in the request. // Ensure the claims can modify the account specified in the request.
err = account.CanModifyAccount(ctx, claims, dbConn, req.AccountID) err = account.CanModifyAccount(ctx, claims, repo.DbConn, req.AccountID)
if err != nil { if err != nil {
return err return err
} }
@ -301,11 +301,11 @@ func Set(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPr
// Execute the query with the provided context. // Execute the query with the provided context.
sql, args := query.Build() sql, args := query.Build()
sql = dbConn.Rebind(sql) sql = repo.DbConn.Rebind(sql)
sql = sql + " ON CONFLICT ON CONSTRAINT account_preferences_pkey DO UPDATE set value = EXCLUDED.value " sql = sql + " ON CONFLICT ON CONSTRAINT account_preferences_pkey DO UPDATE set value = EXCLUDED.value "
_, err = dbConn.ExecContext(ctx, sql, args...) _, err = repo.DbConn.ExecContext(ctx, sql, args...)
if err != nil { if err != nil {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "set account preference failed") err = errors.WithMessage(err, "set account preference failed")
@ -316,7 +316,7 @@ func Set(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPr
} }
// Archive soft deleted the account preference from the database. // Archive soft deleted the account preference from the database.
func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPreferenceArchiveRequest, now time.Time) error { func (repo *Repository) Archive(ctx context.Context, claims auth.Claims, req AccountPreferenceArchiveRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.account_preference.Archive") span, ctx := tracer.StartSpanFromContext(ctx, "internal.account_preference.Archive")
defer span.Finish() defer span.Finish()
@ -328,7 +328,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accou
} }
// Ensure the claims can modify the account specified in the request. // Ensure the claims can modify the account specified in the request.
err = account.CanModifyAccount(ctx, claims, dbConn, req.AccountID) err = account.CanModifyAccount(ctx, claims, repo.DbConn, req.AccountID)
if err != nil { if err != nil {
return err return err
} }
@ -355,8 +355,8 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accou
// Execute the query with the provided context. // Execute the query with the provided context.
sql, args := query.Build() sql, args := query.Build()
sql = dbConn.Rebind(sql) sql = repo.DbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...) _, err = repo.DbConn.ExecContext(ctx, sql, args...)
if err != nil { if err != nil {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "archive account preference %s for account %s failed", req.Name, req.AccountID) err = errors.WithMessagef(err, "archive account preference %s for account %s failed", req.Name, req.AccountID)
@ -367,7 +367,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accou
} }
// Delete removes an account preference from the database. // Delete removes an account preference from the database.
func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPreferenceDeleteRequest) error { func (repo *Repository) Delete(ctx context.Context, claims auth.Claims, req AccountPreferenceDeleteRequest) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.account_preference.Delete") span, ctx := tracer.StartSpanFromContext(ctx, "internal.account_preference.Delete")
defer span.Finish() defer span.Finish()
@ -379,13 +379,13 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun
} }
// Ensure the claims can modify the account specified in the request. // Ensure the claims can modify the account specified in the request.
err = account.CanModifyAccount(ctx, claims, dbConn, req.AccountID) err = account.CanModifyAccount(ctx, claims, repo.DbConn, req.AccountID)
if err != nil { if err != nil {
return err return err
} }
// Start a new transaction to handle rollbacks on error. // Start a new transaction to handle rollbacks on error.
tx, err := dbConn.Begin() tx, err := repo.DbConn.Begin()
if err != nil { if err != nil {
return errors.WithStack(err) return errors.WithStack(err)
} }
@ -397,7 +397,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun
// Execute the query with the provided context. // Execute the query with the provided context.
sql, args := query.Build() sql, args := query.Build()
sql = dbConn.Rebind(sql) sql = repo.DbConn.Rebind(sql)
_, err = tx.ExecContext(ctx, sql, args...) _, err = tx.ExecContext(ctx, sql, args...)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
@ -417,10 +417,15 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun
// MockAccountPreference returns a fake AccountPreference for testing. // MockAccountPreference returns a fake AccountPreference for testing.
func MockAccountPreference(ctx context.Context, dbConn *sqlx.DB, now time.Time) error { func MockAccountPreference(ctx context.Context, dbConn *sqlx.DB, now time.Time) error {
repo := &Repository{
DbConn: dbConn,
}
req := AccountPreferenceSetRequest{ req := AccountPreferenceSetRequest{
AccountID: uuid.NewRandom().String(), AccountID: uuid.NewRandom().String(),
Name: AccountPreference_Datetime_Format, Name: AccountPreference_Datetime_Format,
Value: AccountPreference_Datetime_Format_Default, Value: AccountPreference_Datetime_Format_Default,
} }
return Set(ctx, auth.Claims{}, dbConn, req, now) return repo.Set(ctx, auth.Claims{}, req, now)
} }

View File

@ -1,13 +1,13 @@
package account_preference package account_preference
import ( import (
"geeks-accelerator/oss/saas-starter-kit/internal/account"
"math/rand" "math/rand"
"os" "os"
"strings" "strings"
"testing" "testing"
"time" "time"
"geeks-accelerator/oss/saas-starter-kit/internal/account"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/tests" "geeks-accelerator/oss/saas-starter-kit/internal/platform/tests"
"geeks-accelerator/oss/saas-starter-kit/internal/user_account" "geeks-accelerator/oss/saas-starter-kit/internal/user_account"
@ -17,7 +17,10 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
var test *tests.Test var (
test *tests.Test
repo *Repository
)
// TestMain is the entry point for testing. // TestMain is the entry point for testing.
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
@ -27,6 +30,9 @@ func TestMain(m *testing.M) {
func testMain(m *testing.M) int { func testMain(m *testing.M) int {
test = tests.New() test = tests.New()
defer test.TearDown() defer test.TearDown()
repo = NewRepository(test.MasterDB)
return m.Run() return m.Run()
} }
@ -66,7 +72,7 @@ func TestSetValidation(t *testing.T) {
{ {
ctx := tests.Context() ctx := tests.Context()
err := Set(ctx, auth.Claims{}, test.MasterDB, tt.req, now) err := repo.Set(ctx, auth.Claims{}, tt.req, now)
if err != tt.error { if err != tt.error {
// TODO: need a better way to handle validation errors as they are // TODO: need a better way to handle validation errors as they are
// of type interface validator.ValidationErrorsTranslations // of type interface validator.ValidationErrorsTranslations
@ -225,7 +231,7 @@ func TestCrud(t *testing.T) {
{ {
ctx := tests.Context() ctx := tests.Context()
err := Set(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), test.MasterDB, tt.set, now) err := repo.Set(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), tt.set, now)
if err != nil && errors.Cause(err) != tt.writeErr { if err != nil && errors.Cause(err) != tt.writeErr {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.writeErr) t.Logf("\t\tWant: %+v", tt.writeErr)
@ -234,7 +240,7 @@ func TestCrud(t *testing.T) {
// If user doesn't have access to set, create one anyways to test the other endpoints. // If user doesn't have access to set, create one anyways to test the other endpoints.
if tt.writeErr != nil { if tt.writeErr != nil {
err := Set(ctx, auth.Claims{}, test.MasterDB, tt.set, now) err := repo.Set(ctx, auth.Claims{}, tt.set, now)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate failed.", tests.Failed) t.Fatalf("\t%s\tCreate failed.", tests.Failed)
@ -242,7 +248,7 @@ func TestCrud(t *testing.T) {
} }
// Find the account and make sure the set where made. // Find the account and make sure the set where made.
readRes, err := Read(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), test.MasterDB, AccountPreferenceReadRequest{ readRes, err := repo.Read(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), AccountPreferenceReadRequest{
AccountID: tt.set.AccountID, AccountID: tt.set.AccountID,
Name: tt.set.Name, Name: tt.set.Name,
}) })
@ -266,7 +272,7 @@ func TestCrud(t *testing.T) {
} }
// Archive (soft-delete) the account. // Archive (soft-delete) the account.
err = Archive(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), test.MasterDB, AccountPreferenceArchiveRequest{ err = repo.Archive(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), AccountPreferenceArchiveRequest{
AccountID: tt.set.AccountID, AccountID: tt.set.AccountID,
Name: tt.set.Name, Name: tt.set.Name,
}, now) }, now)
@ -276,7 +282,7 @@ func TestCrud(t *testing.T) {
t.Fatalf("\t%s\tArchive failed.", tests.Failed) t.Fatalf("\t%s\tArchive failed.", tests.Failed)
} else if tt.findErr == nil { } else if tt.findErr == nil {
// Trying to find the archived account with the includeArchived false should result in not found. // Trying to find the archived account with the includeArchived false should result in not found.
_, err = Read(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), test.MasterDB, AccountPreferenceReadRequest{ _, err = repo.Read(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), AccountPreferenceReadRequest{
AccountID: tt.set.AccountID, AccountID: tt.set.AccountID,
Name: tt.set.Name, Name: tt.set.Name,
}) })
@ -287,7 +293,7 @@ func TestCrud(t *testing.T) {
} }
// Trying to find the archived account with the includeArchived true should result no error. // Trying to find the archived account with the includeArchived true should result no error.
_, err = Read(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), test.MasterDB, AccountPreferenceReadRequest{ _, err = repo.Read(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), AccountPreferenceReadRequest{
AccountID: tt.set.AccountID, AccountID: tt.set.AccountID,
Name: tt.set.Name, Name: tt.set.Name,
IncludeArchived: true, IncludeArchived: true,
@ -300,7 +306,7 @@ func TestCrud(t *testing.T) {
t.Logf("\t%s\tArchive ok.", tests.Success) t.Logf("\t%s\tArchive ok.", tests.Success)
// Delete (hard-delete) the account. // Delete (hard-delete) the account.
err = Delete(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), test.MasterDB, AccountPreferenceDeleteRequest{ err = repo.Delete(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), AccountPreferenceDeleteRequest{
AccountID: tt.set.AccountID, AccountID: tt.set.AccountID,
Name: tt.set.Name, Name: tt.set.Name,
}) })
@ -310,7 +316,7 @@ func TestCrud(t *testing.T) {
t.Fatalf("\t%s\tDelete failed.", tests.Failed) t.Fatalf("\t%s\tDelete failed.", tests.Failed)
} else if tt.writeErr == nil { } else if tt.writeErr == nil {
// Trying to find the deleted account with the includeArchived true should result in not found. // Trying to find the deleted account with the includeArchived true should result in not found.
_, err = Read(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), test.MasterDB, AccountPreferenceReadRequest{ _, err = repo.Read(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), AccountPreferenceReadRequest{
AccountID: tt.set.AccountID, AccountID: tt.set.AccountID,
Name: tt.set.Name, Name: tt.set.Name,
IncludeArchived: true, IncludeArchived: true,
@ -362,14 +368,14 @@ func TestFind(t *testing.T) {
var prefs []*AccountPreference var prefs []*AccountPreference
for idx, req := range reqs { for idx, req := range reqs {
err = Set(tests.Context(), auth.Claims{}, test.MasterDB, req, now.Add(time.Second*time.Duration(idx))) err = repo.Set(tests.Context(), auth.Claims{}, req, now.Add(time.Second*time.Duration(idx)))
if err != nil { if err != nil {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tRequest : %+v", req) t.Logf("\t\tRequest : %+v", req)
t.Fatalf("\t%s\tSet failed.", tests.Failed) t.Fatalf("\t%s\tSet failed.", tests.Failed)
} }
pref, err := Read(tests.Context(), auth.Claims{}, test.MasterDB, AccountPreferenceReadRequest{ pref, err := repo.Read(tests.Context(), auth.Claims{}, AccountPreferenceReadRequest{
AccountID: req.AccountID, AccountID: req.AccountID,
Name: req.Name, Name: req.Name,
}) })
@ -479,7 +485,7 @@ func TestFind(t *testing.T) {
{ {
ctx := tests.Context() ctx := tests.Context()
res, err := Find(ctx, auth.Claims{}, test.MasterDB, tt.req) res, err := repo.Find(ctx, auth.Claims{}, tt.req)
if errors.Cause(err) != tt.error { if errors.Cause(err) != tt.error {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.error) t.Logf("\t\tWant: %+v", tt.error)

View File

@ -2,15 +2,28 @@ package account_preference
import ( import (
"context" "context"
"github.com/pkg/errors"
"time" "time"
"database/sql/driver" "database/sql/driver"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web"
"github.com/jmoiron/sqlx"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/pkg/errors"
"gopkg.in/go-playground/validator.v9" "gopkg.in/go-playground/validator.v9"
) )
// Repository defines the required dependencies for AccountPreference.
type Repository struct {
DbConn *sqlx.DB
}
// NewRepository creates a new Repository that defines dependencies for AccountPreference.
func NewRepository(db *sqlx.DB) *Repository {
return &Repository{
DbConn: db,
}
}
// AccountPreference represents an account setting. // AccountPreference represents an account setting.
type AccountPreference struct { type AccountPreference struct {
AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"` AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`

View File

@ -17,7 +17,10 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
var test *tests.Test var (
test *tests.Test
repo *Repository
)
// TestMain is the entry point for testing. // TestMain is the entry point for testing.
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
@ -27,6 +30,9 @@ func TestMain(m *testing.M) {
func testMain(m *testing.M) int { func testMain(m *testing.M) int {
test = tests.New() test = tests.New()
defer test.TearDown() defer test.TearDown()
repo = NewRepository(test.MasterDB)
return m.Run() return m.Run()
} }
@ -184,7 +190,7 @@ func TestCreateValidation(t *testing.T) {
{ {
ctx := tests.Context() ctx := tests.Context()
res, err := Create(ctx, auth.Claims{}, test.MasterDB, tt.req, now) res, err := repo.Create(ctx, auth.Claims{}, tt.req, now)
if err != tt.error { if err != tt.error {
// TODO: need a better way to handle validation errors as they are // TODO: need a better way to handle validation errors as they are
// of type interface validator.ValidationErrorsTranslations // of type interface validator.ValidationErrorsTranslations
@ -239,7 +245,7 @@ func TestCreateValidationNameUnique(t *testing.T) {
Country: "USA", Country: "USA",
Zipcode: "99686", Zipcode: "99686",
} }
account1, err := Create(ctx, auth.Claims{}, test.MasterDB, req1, now) account1, err := repo.Create(ctx, auth.Claims{}, req1, now)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate failed.", tests.Failed) t.Fatalf("\t%s\tCreate failed.", tests.Failed)
@ -255,7 +261,7 @@ func TestCreateValidationNameUnique(t *testing.T) {
Zipcode: "99686", Zipcode: "99686",
} }
expectedErr := errors.New("Key: 'AccountCreateRequest.name' Error:Field validation for 'name' failed on the 'unique' tag") expectedErr := errors.New("Key: 'AccountCreateRequest.name' Error:Field validation for 'name' failed on the 'unique' tag")
_, err = Create(ctx, auth.Claims{}, test.MasterDB, req2, now) _, err = repo.Create(ctx, auth.Claims{}, req2, now)
if err == nil { if err == nil {
t.Logf("\t\tWant: %+v", expectedErr) t.Logf("\t\tWant: %+v", expectedErr)
t.Fatalf("\t%s\tCreate failed.", tests.Failed) t.Fatalf("\t%s\tCreate failed.", tests.Failed)
@ -349,7 +355,7 @@ func TestCreateClaims(t *testing.T) {
{ {
ctx := tests.Context() ctx := tests.Context()
_, err := Create(ctx, auth.Claims{}, test.MasterDB, tt.req, now) _, err := repo.Create(ctx, auth.Claims{}, tt.req, now)
if errors.Cause(err) != tt.error { if errors.Cause(err) != tt.error {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.error) t.Logf("\t\tWant: %+v", tt.error)
@ -396,7 +402,7 @@ func TestUpdateValidation(t *testing.T) {
{ {
ctx := tests.Context() ctx := tests.Context()
err := Update(ctx, auth.Claims{}, test.MasterDB, tt.req, now) err := repo.Update(ctx, auth.Claims{}, tt.req, now)
if err != tt.error { if err != tt.error {
// TODO: need a better way to handle validation errors as they are // TODO: need a better way to handle validation errors as they are
// of type interface validator.ValidationErrorsTranslations // of type interface validator.ValidationErrorsTranslations
@ -440,7 +446,7 @@ func TestUpdateValidationNameUnique(t *testing.T) {
Country: "USA", Country: "USA",
Zipcode: "99686", Zipcode: "99686",
} }
account1, err := Create(ctx, auth.Claims{}, test.MasterDB, req1, now) account1, err := repo.Create(ctx, auth.Claims{}, req1, now)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate failed.", tests.Failed) t.Fatalf("\t%s\tCreate failed.", tests.Failed)
@ -455,7 +461,7 @@ func TestUpdateValidationNameUnique(t *testing.T) {
Country: "USA", Country: "USA",
Zipcode: "99686", Zipcode: "99686",
} }
account2, err := Create(ctx, auth.Claims{}, test.MasterDB, req2, now) account2, err := repo.Create(ctx, auth.Claims{}, req2, now)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate failed.", tests.Failed) t.Fatalf("\t%s\tCreate failed.", tests.Failed)
@ -467,7 +473,7 @@ func TestUpdateValidationNameUnique(t *testing.T) {
Name: &account1.Name, Name: &account1.Name,
} }
expectedErr := errors.New("Key: 'AccountUpdateRequest.name' Error:Field validation for 'name' failed on the 'unique' tag") expectedErr := errors.New("Key: 'AccountUpdateRequest.name' Error:Field validation for 'name' failed on the 'unique' tag")
err = Update(ctx, auth.Claims{}, test.MasterDB, updateReq, now) err = repo.Update(ctx, auth.Claims{}, updateReq, now)
if err == nil { if err == nil {
t.Logf("\t\tWant: %+v", expectedErr) t.Logf("\t\tWant: %+v", expectedErr)
t.Fatalf("\t%s\tUpdate failed.", tests.Failed) t.Fatalf("\t%s\tUpdate failed.", tests.Failed)
@ -728,7 +734,7 @@ func TestCrud(t *testing.T) {
// Always create the new account with empty claims, testing claims for create account // Always create the new account with empty claims, testing claims for create account
// will be handled separately. // will be handled separately.
account, err := Create(ctx, auth.Claims{}, test.MasterDB, tt.create, now) account, err := repo.Create(ctx, auth.Claims{}, tt.create, now)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate failed.", tests.Failed) t.Fatalf("\t%s\tCreate failed.", tests.Failed)
@ -744,7 +750,7 @@ func TestCrud(t *testing.T) {
// Update the account. // Update the account.
updateReq := tt.update(account) updateReq := tt.update(account)
err = Update(ctx, tt.claims(account, userId), test.MasterDB, updateReq, now) err = repo.Update(ctx, tt.claims(account, userId), updateReq, now)
if err != nil && errors.Cause(err) != tt.updateErr { if err != nil && errors.Cause(err) != tt.updateErr {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.updateErr) t.Logf("\t\tWant: %+v", tt.updateErr)
@ -753,7 +759,7 @@ func TestCrud(t *testing.T) {
t.Logf("\t%s\tUpdate ok.", tests.Success) t.Logf("\t%s\tUpdate ok.", tests.Success)
// Find the account and make sure the updates where made. // Find the account and make sure the updates where made.
findRes, err := ReadByID(ctx, tt.claims(account, userId), test.MasterDB, account.ID) findRes, err := repo.ReadByID(ctx, tt.claims(account, userId), account.ID)
if err != nil && errors.Cause(err) != tt.findErr { if err != nil && errors.Cause(err) != tt.findErr {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.findErr) t.Logf("\t\tWant: %+v", tt.findErr)
@ -767,14 +773,14 @@ func TestCrud(t *testing.T) {
} }
// Archive (soft-delete) the account. // Archive (soft-delete) the account.
err = Archive(ctx, tt.claims(account, userId), test.MasterDB, AccountArchiveRequest{ID: account.ID}, now) err = repo.Archive(ctx, tt.claims(account, userId), AccountArchiveRequest{ID: account.ID}, now)
if err != nil && errors.Cause(err) != tt.updateErr { if err != nil && errors.Cause(err) != tt.updateErr {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.updateErr) t.Logf("\t\tWant: %+v", tt.updateErr)
t.Fatalf("\t%s\tArchive failed.", tests.Failed) t.Fatalf("\t%s\tArchive failed.", tests.Failed)
} else if tt.updateErr == nil { } else if tt.updateErr == nil {
// Trying to find the archived account with the includeArchived false should result in not found. // Trying to find the archived account with the includeArchived false should result in not found.
_, err = ReadByID(ctx, tt.claims(account, userId), test.MasterDB, account.ID) _, err = repo.ReadByID(ctx, tt.claims(account, userId), account.ID)
if err != nil && errors.Cause(err) != ErrNotFound { if err != nil && errors.Cause(err) != ErrNotFound {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", ErrNotFound) t.Logf("\t\tWant: %+v", ErrNotFound)
@ -782,7 +788,7 @@ func TestCrud(t *testing.T) {
} }
// Trying to find the archived account with the includeArchived true should result no error. // Trying to find the archived account with the includeArchived true should result no error.
_, err = Read(ctx, tt.claims(account, userId), test.MasterDB, _, err = repo.Read(ctx, tt.claims(account, userId),
AccountReadRequest{ID: account.ID, IncludeArchived: true}) AccountReadRequest{ID: account.ID, IncludeArchived: true})
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
@ -792,14 +798,14 @@ func TestCrud(t *testing.T) {
t.Logf("\t%s\tArchive ok.", tests.Success) t.Logf("\t%s\tArchive ok.", tests.Success)
// Delete (hard-delete) the account. // Delete (hard-delete) the account.
err = Delete(ctx, tt.claims(account, userId), test.MasterDB, AccountDeleteRequest{ID: account.ID}) err = repo.Delete(ctx, tt.claims(account, userId), AccountDeleteRequest{ID: account.ID})
if err != nil && errors.Cause(err) != tt.updateErr { if err != nil && errors.Cause(err) != tt.updateErr {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.updateErr) t.Logf("\t\tWant: %+v", tt.updateErr)
t.Fatalf("\t%s\tUpdate failed.", tests.Failed) t.Fatalf("\t%s\tUpdate failed.", tests.Failed)
} else if tt.updateErr == nil { } else if tt.updateErr == nil {
// Trying to find the deleted account with the includeArchived true should result in not found. // Trying to find the deleted account with the includeArchived true should result in not found.
_, err = ReadByID(ctx, tt.claims(account, userId), test.MasterDB, account.ID) _, err = repo.ReadByID(ctx, tt.claims(account, userId), account.ID)
if errors.Cause(err) != ErrNotFound { if errors.Cause(err) != ErrNotFound {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", ErrNotFound) t.Logf("\t\tWant: %+v", ErrNotFound)
@ -822,7 +828,7 @@ func TestFind(t *testing.T) {
var accounts []*Account var accounts []*Account
for i := 0; i <= 4; i++ { for i := 0; i <= 4; i++ {
account, err := Create(tests.Context(), auth.Claims{}, test.MasterDB, AccountCreateRequest{ account, err := repo.Create(tests.Context(), auth.Claims{}, AccountCreateRequest{
Name: uuid.NewRandom().String(), Name: uuid.NewRandom().String(),
Address1: "103 East Main St", Address1: "103 East Main St",
Address2: "Unit 546", Address2: "Unit 546",
@ -935,7 +941,7 @@ func TestFind(t *testing.T) {
{ {
ctx := tests.Context() ctx := tests.Context()
res, err := Find(ctx, auth.Claims{}, test.MasterDB, tt.req) res, err := repo.Find(ctx, auth.Claims{}, tt.req)
if errors.Cause(err) != tt.error { if errors.Cause(err) != tt.error {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.error) t.Logf("\t\tWant: %+v", tt.error)

View File

@ -5,14 +5,27 @@ import (
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"encoding/json" "encoding/json"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web"
"time" "time"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web"
"github.com/jmoiron/sqlx"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/pkg/errors" "github.com/pkg/errors"
"gopkg.in/go-playground/validator.v9" "gopkg.in/go-playground/validator.v9"
) )
// Repository defines the required dependencies for Account.
type Repository struct {
DbConn *sqlx.DB
}
// NewRepository creates a new Repository that defines dependencies for Account.
func NewRepository(db *sqlx.DB) *Repository {
return &Repository{
DbConn: db,
}
}
// Account represents someone with access to our system. // Account represents someone with access to our system.
type Account struct { type Account struct {
ID string `json:"id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"` ID string `json:"id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`

View File

@ -2,14 +2,28 @@ package project
import ( import (
"context" "context"
"time"
"database/sql/driver" "database/sql/driver"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web"
"github.com/jmoiron/sqlx"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/pkg/errors" "github.com/pkg/errors"
"gopkg.in/go-playground/validator.v9" "gopkg.in/go-playground/validator.v9"
"time"
) )
// Repository defines the required dependencies for Project.
type Repository struct {
DbConn *sqlx.DB
}
// NewRepository creates a new Repository that defines dependencies for Project.
func NewRepository(db *sqlx.DB) *Repository {
return &Repository{
DbConn: db,
}
}
// Project represents a workflow. // Project represents a workflow.
type Project struct { type Project struct {
ID string `json:"id" validate:"required,uuid" example:"985f1746-1d9f-459f-a2d9-fc53ece5ae86"` ID string `json:"id" validate:"required,uuid" example:"985f1746-1d9f-459f-a2d9-fc53ece5ae86"`

View File

@ -3,6 +3,8 @@ package project
import ( import (
"context" "context"
"database/sql" "database/sql"
"time"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"github.com/huandu/go-sqlbuilder" "github.com/huandu/go-sqlbuilder"
@ -10,7 +12,6 @@ import (
"github.com/pborman/uuid" "github.com/pborman/uuid"
"github.com/pkg/errors" "github.com/pkg/errors"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
"time"
) )
const ( const (
@ -27,7 +28,7 @@ var (
) )
// CanReadProject determines if claims has the authority to access the specified project by id. // CanReadProject determines if claims has the authority to access the specified project by id.
func CanReadProject(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string) error { func (repo *Repository) CanReadProject(ctx context.Context, claims auth.Claims, id string) error {
// If the request has claims from a specific project, ensure that the claims // If the request has claims from a specific project, ensure that the claims
// has the correct access to the project. // has the correct access to the project.
@ -40,9 +41,9 @@ func CanReadProject(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id
)) ))
queryStr, args := query.Build() queryStr, args := query.Build()
queryStr = dbConn.Rebind(queryStr) queryStr = repo.DbConn.Rebind(queryStr)
var id string var id string
err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&id) err := repo.DbConn.QueryRowContext(ctx, queryStr, args...).Scan(&id)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
return err return err
@ -60,8 +61,8 @@ func CanReadProject(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id
} }
// CanModifyProject determines if claims has the authority to modify the specified project by id. // CanModifyProject determines if claims has the authority to modify the specified project by id.
func CanModifyProject(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string) error { func (repo *Repository) CanModifyProject(ctx context.Context, claims auth.Claims, id string) error {
err := CanReadProject(ctx, claims, dbConn, id) err := repo.CanReadProject(ctx, claims, id)
if err != nil { if err != nil {
return err return err
} }
@ -124,9 +125,9 @@ func findRequestQuery(req ProjectFindRequest) (*sqlbuilder.SelectBuilder, []inte
} }
// Find gets all the projects from the database based on the request params. // Find gets all the projects from the database based on the request params.
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectFindRequest) (Projects, error) { func (repo *Repository) Find(ctx context.Context, claims auth.Claims, req ProjectFindRequest) (Projects, error) {
query, args := findRequestQuery(req) query, args := findRequestQuery(req)
return find(ctx, claims, dbConn, query, args, req.IncludeArchived) return find(ctx, claims, repo.DbConn, query, args, req.IncludeArchived)
} }
// find internal method for getting all the projects from the database using a select query. // find internal method for getting all the projects from the database using a select query.
@ -177,15 +178,15 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu
} }
// ReadByID gets the specified project by ID from the database. // ReadByID gets the specified project by ID from the database.
func ReadByID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string) (*Project, error) { func (repo *Repository) ReadByID(ctx context.Context, claims auth.Claims, id string) (*Project, error) {
return Read(ctx, claims, dbConn, ProjectReadRequest{ return repo.Read(ctx, claims, ProjectReadRequest{
ID: id, ID: id,
IncludeArchived: false, IncludeArchived: false,
}) })
} }
// Read gets the specified project from the database. // Read gets the specified project from the database.
func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectReadRequest) (*Project, error) { func (repo *Repository) Read(ctx context.Context, claims auth.Claims, req ProjectReadRequest) (*Project, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Read") span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Read")
defer span.Finish() defer span.Finish()
@ -200,7 +201,7 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectR
query := sqlbuilder.NewSelectBuilder() query := sqlbuilder.NewSelectBuilder()
query.Where(query.Equal("id", req.ID)) query.Where(query.Equal("id", req.ID))
res, err := find(ctx, claims, dbConn, query, []interface{}{}, req.IncludeArchived) res, err := find(ctx, claims, repo.DbConn, query, []interface{}{}, req.IncludeArchived)
if err != nil { if err != nil {
return nil, err return nil, err
} else if res == nil || len(res) == 0 { } else if res == nil || len(res) == 0 {
@ -213,7 +214,7 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectR
} }
// Create inserts a new project into the database. // Create inserts a new project into the database.
func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectCreateRequest, now time.Time) (*Project, error) { func (repo *Repository) Create(ctx context.Context, claims auth.Claims, req ProjectCreateRequest, now time.Time) (*Project, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Create") span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Create")
defer span.Finish() defer span.Finish()
if claims.Audience != "" { if claims.Audience != "" {
@ -290,8 +291,8 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Projec
// Execute the query with the provided context. // Execute the query with the provided context.
sql, args := query.Build() sql, args := query.Build()
sql = dbConn.Rebind(sql) sql = repo.DbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...) _, err = repo.DbConn.ExecContext(ctx, sql, args...)
if err != nil { if err != nil {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "create project failed") err = errors.WithMessage(err, "create project failed")
@ -302,7 +303,7 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Projec
} }
// Update replaces an project in the database. // Update replaces an project in the database.
func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectUpdateRequest, now time.Time) error { func (repo *Repository) Update(ctx context.Context, claims auth.Claims, req ProjectUpdateRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Update") span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Update")
defer span.Finish() defer span.Finish()
@ -314,7 +315,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Projec
} }
// Ensure the claims can modify the project specified in the request. // Ensure the claims can modify the project specified in the request.
err = CanModifyProject(ctx, claims, dbConn, req.ID) err = repo.CanModifyProject(ctx, claims, req.ID)
if err != nil { if err != nil {
return err return err
} }
@ -352,8 +353,8 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Projec
query.Where(query.Equal("id", req.ID)) query.Where(query.Equal("id", req.ID))
// Execute the query with the provided context. // Execute the query with the provided context.
sql, args := query.Build() sql, args := query.Build()
sql = dbConn.Rebind(sql) sql = repo.DbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...) _, err = repo.DbConn.ExecContext(ctx, sql, args...)
if err != nil { if err != nil {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "update project %s failed", req.ID) err = errors.WithMessagef(err, "update project %s failed", req.ID)
@ -364,7 +365,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Projec
} }
// Archive soft deleted the project from the database. // Archive soft deleted the project from the database.
func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectArchiveRequest, now time.Time) error { func (repo *Repository) Archive(ctx context.Context, claims auth.Claims, req ProjectArchiveRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Archive") span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Archive")
defer span.Finish() defer span.Finish()
@ -376,7 +377,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Proje
} }
// Ensure the claims can modify the project specified in the request. // Ensure the claims can modify the project specified in the request.
err = CanModifyProject(ctx, claims, dbConn, req.ID) err = repo.CanModifyProject(ctx, claims, req.ID)
if err != nil { if err != nil {
return err return err
} }
@ -401,8 +402,8 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Proje
query.Where(query.Equal("id", req.ID)) query.Where(query.Equal("id", req.ID))
// Execute the query with the provided context. // Execute the query with the provided context.
sql, args := query.Build() sql, args := query.Build()
sql = dbConn.Rebind(sql) sql = repo.DbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...) _, err = repo.DbConn.ExecContext(ctx, sql, args...)
if err != nil { if err != nil {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "archive project %s failed", req.ID) err = errors.WithMessagef(err, "archive project %s failed", req.ID)
@ -413,7 +414,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Proje
} }
// Delete removes an project from the database. // Delete removes an project from the database.
func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectDeleteRequest) error { func (repo *Repository) Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectDeleteRequest) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Delete") span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Delete")
defer span.Finish() defer span.Finish()
@ -425,7 +426,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Projec
} }
// Ensure the claims can modify the project specified in the request. // Ensure the claims can modify the project specified in the request.
err = CanModifyProject(ctx, claims, dbConn, req.ID) err = repo.CanModifyProject(ctx, claims, req.ID)
if err != nil { if err != nil {
return err return err
} }

View File

@ -1,15 +1,19 @@
package project package project
import ( import (
"os"
"testing"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/tests" "geeks-accelerator/oss/saas-starter-kit/internal/platform/tests"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/huandu/go-sqlbuilder" "github.com/huandu/go-sqlbuilder"
"os"
"testing"
) )
var test *tests.Test var (
test *tests.Test
repo *Repository
)
// TestMain is the entry point for testing. // TestMain is the entry point for testing.
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
@ -19,6 +23,9 @@ func TestMain(m *testing.M) {
func testMain(m *testing.M) int { func testMain(m *testing.M) int {
test = tests.New() test = tests.New()
defer test.TearDown() defer test.TearDown()
repo = NewRepository(test.MasterDB)
return m.Run() return m.Run()
} }

View File

@ -1,17 +1,17 @@
package project_routes package project_route
import ( import (
"github.com/pkg/errors" "github.com/pkg/errors"
"net/url" "net/url"
) )
type ProjectRoutes struct { type ProjectRoute struct {
webAppUrl url.URL webAppUrl url.URL
webApiUrl url.URL webApiUrl url.URL
} }
func New(apiBaseUrl, appBaseUrl string) (ProjectRoutes, error) { func New(apiBaseUrl, appBaseUrl string) (ProjectRoute, error) {
var r ProjectRoutes var r ProjectRoute
apiUrl, err := url.Parse(apiBaseUrl) apiUrl, err := url.Parse(apiBaseUrl)
if err != nil { if err != nil {
@ -28,37 +28,37 @@ func New(apiBaseUrl, appBaseUrl string) (ProjectRoutes, error) {
return r, nil return r, nil
} }
func (r ProjectRoutes) WebAppUrl(urlPath string) string { func (r ProjectRoute) WebAppUrl(urlPath string) string {
u := r.webAppUrl u := r.webAppUrl
u.Path = urlPath u.Path = urlPath
return u.String() return u.String()
} }
func (r ProjectRoutes) WebApiUrl(urlPath string) string { func (r ProjectRoute) WebApiUrl(urlPath string) string {
u := r.webApiUrl u := r.webApiUrl
u.Path = urlPath u.Path = urlPath
return u.String() return u.String()
} }
func (r ProjectRoutes) UserResetPassword(resetHash string) string { func (r ProjectRoute) UserResetPassword(resetHash string) string {
u := r.webAppUrl u := r.webAppUrl
u.Path = "/user/reset-password/" + resetHash u.Path = "/user/reset-password/" + resetHash
return u.String() return u.String()
} }
func (r ProjectRoutes) UserInviteAccept(inviteHash string) string { func (r ProjectRoute) UserInviteAccept(inviteHash string) string {
u := r.webAppUrl u := r.webAppUrl
u.Path = "/users/invite/" + inviteHash u.Path = "/users/invite/" + inviteHash
return u.String() return u.String()
} }
func (r ProjectRoutes) ApiDocs() string { func (r ProjectRoute) ApiDocs() string {
u := r.webApiUrl u := r.webApiUrl
u.Path = "/docs" u.Path = "/docs"
return u.String() return u.String()
} }
func (r ProjectRoutes) ApiDocsJson() string { func (r ProjectRoute) ApiDocsJson() string {
u := r.webApiUrl u := r.webApiUrl
u.Path = "/docs/doc.json" u.Path = "/docs/doc.json"
return u.String() return u.String()

View File

@ -2,10 +2,31 @@ package signup
import ( import (
"context" "context"
"geeks-accelerator/oss/saas-starter-kit/internal/account" "geeks-accelerator/oss/saas-starter-kit/internal/account"
"geeks-accelerator/oss/saas-starter-kit/internal/user" "geeks-accelerator/oss/saas-starter-kit/internal/user"
"geeks-accelerator/oss/saas-starter-kit/internal/user_account"
"github.com/jmoiron/sqlx"
) )
// Repository defines the required dependencies for Signup.
type Repository struct {
DbConn *sqlx.DB
User *user.Repository
UserAccount *user_account.Repository
Account *account.Repository
}
// NewRepository creates a new Repository that defines dependencies for Signup.
func NewRepository(db *sqlx.DB, user *user.Repository, userAccount *user_account.Repository, account *account.Repository) *Repository {
return &Repository{
DbConn: db,
User: user,
UserAccount: userAccount,
Account: account,
}
}
// SignupRequest contains information needed perform signup. // SignupRequest contains information needed perform signup.
type SignupRequest struct { type SignupRequest struct {
Account SignupAccount `json:"account" validate:"required"` // Account details. Account SignupAccount `json:"account" validate:"required"` // Account details.

View File

@ -9,25 +9,24 @@ import (
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"geeks-accelerator/oss/saas-starter-kit/internal/user" "geeks-accelerator/oss/saas-starter-kit/internal/user"
"geeks-accelerator/oss/saas-starter-kit/internal/user_account" "geeks-accelerator/oss/saas-starter-kit/internal/user_account"
"github.com/jmoiron/sqlx"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
) )
// Signup performs the steps needed to create a new account, new user and then associate // Signup performs the steps needed to create a new account, new user and then associate
// both records with a new user_account entry. // both records with a new user_account entry.
func Signup(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req SignupRequest, now time.Time) (*SignupResult, error) { func (repo *Repository) Signup(ctx context.Context, claims auth.Claims, req SignupRequest, now time.Time) (*SignupResult, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.signup.Signup") span, ctx := tracer.StartSpanFromContext(ctx, "internal.signup.Signup")
defer span.Finish() defer span.Finish()
// Validate the user email address is unique in the database. // Validate the user email address is unique in the database.
uniqEmail, err := user.UniqueEmail(ctx, dbConn, req.User.Email, "") uniqEmail, err := user.UniqueEmail(ctx, repo.DbConn, req.User.Email, "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
ctx = webcontext.ContextAddUniqueValue(ctx, req.User, "Email", uniqEmail) ctx = webcontext.ContextAddUniqueValue(ctx, req.User, "Email", uniqEmail)
// Validate the account name is unique in the database. // Validate the account name is unique in the database.
uniqName, err := account.UniqueName(ctx, dbConn, req.Account.Name, "") uniqName, err := account.UniqueName(ctx, repo.DbConn, req.Account.Name, "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -52,7 +51,7 @@ func Signup(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Signup
} }
// Execute user creation. // Execute user creation.
resp.User, err = user.Create(ctx, claims, dbConn, userReq, now) resp.User, err = repo.User.Create(ctx, claims, userReq, now)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -73,7 +72,7 @@ func Signup(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Signup
} }
// Execute account creation. // Execute account creation.
resp.Account, err = account.Create(ctx, claims, dbConn, accountReq, now) resp.Account, err = repo.Account.Create(ctx, claims, accountReq, now)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -87,7 +86,7 @@ func Signup(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Signup
//Status: Use default value //Status: Use default value
} }
_, err = user_account.Create(ctx, claims, dbConn, ua, now) _, err = repo.UserAccount.Create(ctx, claims, ua, now)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1,19 +1,26 @@
package signup package signup
import ( import (
"geeks-accelerator/oss/saas-starter-kit/internal/user_auth"
"os" "os"
"testing" "testing"
"time" "time"
"geeks-accelerator/oss/saas-starter-kit/internal/account"
"geeks-accelerator/oss/saas-starter-kit/internal/account/account_preference"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/tests" "geeks-accelerator/oss/saas-starter-kit/internal/platform/tests"
"geeks-accelerator/oss/saas-starter-kit/internal/user"
"geeks-accelerator/oss/saas-starter-kit/internal/user_account"
"geeks-accelerator/oss/saas-starter-kit/internal/user_auth"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/pborman/uuid" "github.com/pborman/uuid"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
var test *tests.Test var (
test *tests.Test
repo *Repository
)
// TestMain is the entry point for testing. // TestMain is the entry point for testing.
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
@ -23,6 +30,13 @@ func TestMain(m *testing.M) {
func testMain(m *testing.M) int { func testMain(m *testing.M) int {
test = tests.New() test = tests.New()
defer test.TearDown() defer test.TearDown()
userRepo := user.MockRepository(test.MasterDB)
userAccRepo := user_account.NewRepository(test.MasterDB)
accRepo := account.NewRepository(test.MasterDB)
repo = NewRepository(test.MasterDB, userRepo, userAccRepo, accRepo)
return m.Run() return m.Run()
} }
@ -63,7 +77,7 @@ func TestSignupValidation(t *testing.T) {
{ {
ctx := tests.Context() ctx := tests.Context()
res, err := Signup(ctx, auth.Claims{}, test.MasterDB, tt.req, now) res, err := repo.Signup(ctx, auth.Claims{}, tt.req, now)
if err != tt.error { if err != tt.error {
// TODO: need a better way to handle validation errors as they are // TODO: need a better way to handle validation errors as they are
// of type interface validator.ValidationErrorsTranslations // of type interface validator.ValidationErrorsTranslations
@ -127,9 +141,12 @@ func TestSignupFull(t *testing.T) {
tknGen := &auth.MockTokenGenerator{} tknGen := &auth.MockTokenGenerator{}
accPrefRepo := account_preference.NewRepository(test.MasterDB)
authRepo := user_auth.NewRepository(test.MasterDB, tknGen, repo.User, repo.UserAccount, accPrefRepo)
t.Log("Given the need to ensure signup works.") t.Log("Given the need to ensure signup works.")
{ {
res, err := Signup(ctx, auth.Claims{}, test.MasterDB, req, now) res, err := repo.Signup(ctx, auth.Claims{}, req, now)
if err != nil { if err != nil {
t.Logf("\t\tGot error : %+v", err) t.Logf("\t\tGot error : %+v", err)
t.Fatalf("\t%s\tSignup failed.", tests.Failed) t.Fatalf("\t%s\tSignup failed.", tests.Failed)
@ -162,7 +179,7 @@ func TestSignupFull(t *testing.T) {
t.Logf("\t%s\tSignup ok.", tests.Success) t.Logf("\t%s\tSignup ok.", tests.Success)
// Verify that the user can be authenticated with the updated password. // Verify that the user can be authenticated with the updated password.
_, err = user_auth.Authenticate(ctx, test.MasterDB, tknGen, user_auth.AuthenticateRequest{ _, err = authRepo.Authenticate(ctx, user_auth.AuthenticateRequest{
Email: res.User.Email, Email: res.User.Email,
Password: req.User.Password, Password: req.User.Password,
}, time.Hour, now) }, time.Hour, now)

View File

@ -4,17 +4,17 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/notify"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
"github.com/sudo-suhas/symcrypto"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/notify"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"github.com/jmoiron/sqlx"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/pkg/errors"
"github.com/sudo-suhas/symcrypto"
) )
// Repository defines the required dependencies for User. // Repository defines the required dependencies for User.
@ -22,7 +22,7 @@ type Repository struct {
DbConn *sqlx.DB DbConn *sqlx.DB
ResetUrl func(string) string ResetUrl func(string) string
Notify notify.Email Notify notify.Email
SecretKey string secretKey string
} }
// NewRepository creates a new Repository that defines dependencies for User. // NewRepository creates a new Repository that defines dependencies for User.
@ -31,7 +31,7 @@ func NewRepository(db *sqlx.DB, resetUrl func(string) string, notify notify.Emai
DbConn: db, DbConn: db,
ResetUrl: resetUrl, ResetUrl: resetUrl,
Notify: notify, Notify: notify,
SecretKey: secretKey, secretKey: secretKey,
} }
} }

View File

@ -6,6 +6,7 @@ import (
"time" "time"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/notify"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"github.com/huandu/go-sqlbuilder" "github.com/huandu/go-sqlbuilder"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
@ -200,11 +201,11 @@ func findRequestQuery(req UserFindRequest) (*sqlbuilder.SelectBuilder, []interfa
// Find gets all the users from the database based on the request params. // Find gets all the users from the database based on the request params.
func (repo *Repository) Find(ctx context.Context, claims auth.Claims, req UserFindRequest) (Users, error) { func (repo *Repository) Find(ctx context.Context, claims auth.Claims, req UserFindRequest) (Users, error) {
query, args := findRequestQuery(req) query, args := findRequestQuery(req)
return repo.find(ctx, claims, query, args, req.IncludeArchived) return find(ctx, claims, repo.DbConn, query, args, req.IncludeArchived)
} }
// find internal method for getting all the users from the database using a select query. // find internal method for getting all the users from the database using a select query.
func (repo *Repository) find(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (Users, error) { func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (Users, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Find") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Find")
defer span.Finish() defer span.Finish()
@ -221,11 +222,11 @@ func (repo *Repository) find(ctx context.Context, claims auth.Claims, query *sql
return nil, err return nil, err
} }
queryStr, queryArgs := query.Build() queryStr, queryArgs := query.Build()
queryStr = repo.DbConn.Rebind(queryStr) queryStr = dbConn.Rebind(queryStr)
args = append(args, queryArgs...) args = append(args, queryArgs...)
// fetch all places from the db // fetch all places from the db
rows, err := repo.DbConn.QueryContext(ctx, queryStr, args...) rows, err := dbConn.QueryContext(ctx, queryStr, args...)
if err != nil { if err != nil {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "find users failed") err = errors.WithMessage(err, "find users failed")
@ -247,17 +248,17 @@ func (repo *Repository) find(ctx context.Context, claims auth.Claims, query *sql
} }
// Validation an email address is unique excluding the current user ID. // Validation an email address is unique excluding the current user ID.
func (repo *Repository) UniqueEmail(ctx context.Context, email, userId string) (bool, error) { func UniqueEmail(ctx context.Context, dbConn *sqlx.DB, email, userId string) (bool, error) {
query := sqlbuilder.NewSelectBuilder().Select("id").From(userTableName) query := sqlbuilder.NewSelectBuilder().Select("id").From(userTableName)
query.Where(query.And( query.Where(query.And(
query.Equal("email", email), query.Equal("email", email),
query.NotEqual("id", userId), query.NotEqual("id", userId),
)) ))
queryStr, args := query.Build() queryStr, args := query.Build()
queryStr = repo.DbConn.Rebind(queryStr) queryStr = dbConn.Rebind(queryStr)
var existingId string var existingId string
err := repo.DbConn.QueryRowContext(ctx, queryStr, args...).Scan(&existingId) err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&existingId)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
return false, err return false, err
@ -283,7 +284,7 @@ func (repo *Repository) Create(ctx context.Context, claims auth.Claims, req User
v := webcontext.Validator() v := webcontext.Validator()
// Validation email address is unique in the database. // Validation email address is unique in the database.
uniq, err := repo.UniqueEmail(ctx, req.Email, "") uniq, err := UniqueEmail(ctx, repo.DbConn, req.Email, "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -364,7 +365,7 @@ func (repo *Repository) CreateInvite(ctx context.Context, claims auth.Claims, re
v := webcontext.Validator() v := webcontext.Validator()
// Validation email address is unique in the database. // Validation email address is unique in the database.
uniq, err := repo.UniqueEmail(ctx, req.Email, "") uniq, err := UniqueEmail(ctx, repo.DbConn, req.Email, "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -448,7 +449,7 @@ func (repo *Repository) Read(ctx context.Context, claims auth.Claims, req UserRe
query := selectQuery() query := selectQuery()
query.Where(query.Equal("id", req.ID)) query.Where(query.Equal("id", req.ID))
res, err := repo.find(ctx, claims, query, []interface{}{}, req.IncludeArchived) res, err := find(ctx, claims, repo.DbConn, query, []interface{}{}, req.IncludeArchived)
if err != nil { if err != nil {
return nil, err return nil, err
} else if res == nil || len(res) == 0 { } else if res == nil || len(res) == 0 {
@ -469,7 +470,7 @@ func (repo *Repository) ReadByEmail(ctx context.Context, claims auth.Claims, ema
query := selectQuery() query := selectQuery()
query.Where(query.Equal("email", email)) query.Where(query.Equal("email", email))
res, err := repo.find(ctx, claims, query, []interface{}{}, includedArchived) res, err := find(ctx, claims, repo.DbConn, query, []interface{}{}, includedArchived)
if err != nil { if err != nil {
return nil, err return nil, err
} else if res == nil || len(res) == 0 { } else if res == nil || len(res) == 0 {
@ -489,7 +490,7 @@ func (repo *Repository) Update(ctx context.Context, claims auth.Claims, req User
// Validation email address is unique in the database. // Validation email address is unique in the database.
if req.Email != nil { if req.Email != nil {
// Validation email address is unique in the database. // Validation email address is unique in the database.
uniq, err := repo.UniqueEmail(ctx, *req.Email, req.ID) uniq, err := UniqueEmail(ctx, repo.DbConn, *req.Email, req.ID)
if err != nil { if err != nil {
return err return err
} }
@ -844,7 +845,7 @@ func (repo *Repository) ResetPassword(ctx context.Context, req UserResetPassword
query := selectQuery() query := selectQuery()
query.Where(query.Equal("email", req.Email)) query.Where(query.Equal("email", req.Email))
res, err := repo.find(ctx, auth.Claims{}, query, []interface{}{}, false) res, err := find(ctx, auth.Claims{}, repo.DbConn, query, []interface{}{}, false)
if err != nil { if err != nil {
return "", err return "", err
} else if res == nil || len(res) == 0 { } else if res == nil || len(res) == 0 {
@ -894,7 +895,7 @@ func (repo *Repository) ResetPassword(ctx context.Context, req UserResetPassword
requestIp = vals.RequestIP requestIp = vals.RequestIP
} }
encrypted, err := NewResetHash(ctx, repo.SecretKey, resetId, requestIp, req.TTL, now) encrypted, err := NewResetHash(ctx, repo.secretKey, resetId, requestIp, req.TTL, now)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -927,7 +928,7 @@ func (repo *Repository) ResetConfirm(ctx context.Context, req UserResetConfirmRe
return nil, err return nil, err
} }
hash, err := ParseResetHash(ctx, repo.SecretKey, req.ResetHash, now) hash, err := ParseResetHash(ctx, repo.secretKey, req.ResetHash, now)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -938,7 +939,7 @@ func (repo *Repository) ResetConfirm(ctx context.Context, req UserResetConfirmRe
query := selectQuery() query := selectQuery()
query.Where(query.Equal("password_reset", hash.ResetID)) query.Where(query.Equal("password_reset", hash.ResetID))
res, err := repo.find(ctx, auth.Claims{}, query, []interface{}{}, false) res, err := find(ctx, auth.Claims{}, repo.DbConn, query, []interface{}{}, false)
if err != nil { if err != nil {
return nil, err return nil, err
} else if res == nil || len(res) == 0 { } else if res == nil || len(res) == 0 {
@ -1020,3 +1021,14 @@ func MockUser(ctx context.Context, dbConn *sqlx.DB, now time.Time) (*MockUserRes
Password: pass, Password: pass,
}, nil }, nil
} }
func MockRepository(dbConn *sqlx.DB) *Repository {
// Mock the methods needed to make a password reset.
resetUrl := func(string) string {
return ""
}
notify := &notify.MockEmail{}
secretKey := "6368616e676520746869732070617373"
return NewRepository(dbConn, resetUrl, notify, secretKey)
}

View File

@ -8,7 +8,6 @@ import (
"time" "time"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/notify"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/tests" "geeks-accelerator/oss/saas-starter-kit/internal/platform/tests"
"github.com/dgrijalva/jwt-go" "github.com/dgrijalva/jwt-go"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
@ -32,14 +31,7 @@ func testMain(m *testing.M) int {
test = tests.New() test = tests.New()
defer test.TearDown() defer test.TearDown()
// Mock the methods needed to make a password reset. repo = MockRepository(test.MasterDB)
resetUrl := func(string) string {
return ""
}
notify := &notify.MockEmail{}
secretKey := "6368616e676520746869732070617373"
repo = NewRepository(test.MasterDB, resetUrl, notify, secretKey)
return m.Run() return m.Run()
} }

View File

@ -8,11 +8,9 @@ import (
"geeks-accelerator/oss/saas-starter-kit/internal/account" "geeks-accelerator/oss/saas-starter-kit/internal/account"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/notify"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"geeks-accelerator/oss/saas-starter-kit/internal/user" "geeks-accelerator/oss/saas-starter-kit/internal/user"
"geeks-accelerator/oss/saas-starter-kit/internal/user_account" "geeks-accelerator/oss/saas-starter-kit/internal/user_account"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors" "github.com/pkg/errors"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
) )
@ -29,7 +27,7 @@ var (
) )
// SendUserInvites sends emails to the users inviting them to join an account. // SendUserInvites sends emails to the users inviting them to join an account.
func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, resetUrl func(string) string, notify notify.Email, req SendUserInvitesRequest, secretKey string, now time.Time) ([]string, error) { func (repo *Repository) SendUserInvites(ctx context.Context, claims auth.Claims, req SendUserInvitesRequest, now time.Time) ([]string, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.invite.SendUserInvites") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.invite.SendUserInvites")
defer span.Finish() defer span.Finish()
@ -42,7 +40,7 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r
} }
// Ensure the claims can modify the account specified in the request. // Ensure the claims can modify the account specified in the request.
err = user_account.CanModifyAccount(ctx, claims, dbConn, req.AccountID) err = account.CanModifyAccount(ctx, claims, repo.DbConn, req.AccountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -51,7 +49,7 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r
emailUserIDs := make(map[string]string) emailUserIDs := make(map[string]string)
{ {
// Find all users without passing in claims to search all users. // Find all users without passing in claims to search all users.
users, err := user.Find(ctx, auth.Claims{}, dbConn, user.UserFindRequest{ users, err := repo.User.Find(ctx, auth.Claims{}, user.UserFindRequest{
Where: fmt.Sprintf("email in ('%s')", Where: fmt.Sprintf("email in ('%s')",
strings.Join(req.Emails, "','")), strings.Join(req.Emails, "','")),
}) })
@ -72,7 +70,7 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r
args = append(args, userID) args = append(args, userID)
} }
userAccs, err := user_account.Find(ctx, claims, dbConn, user_account.UserAccountFindRequest{ userAccs, err := repo.UserAccount.Find(ctx, claims, user_account.UserAccountFindRequest{
Where: fmt.Sprintf("user_id in ('%s') and status = '%s'", Where: fmt.Sprintf("user_id in ('%s') and status = '%s'",
strings.Join(args, "','"), strings.Join(args, "','"),
user_account.UserAccountStatus_Active.String()), user_account.UserAccountStatus_Active.String()),
@ -99,7 +97,7 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r
continue continue
} }
u, err := user.CreateInvite(ctx, claims, dbConn, user.UserCreateInviteRequest{ u, err := repo.User.CreateInvite(ctx, claims, user.UserCreateInviteRequest{
Email: email, Email: email,
}, now) }, now)
if err != nil { if err != nil {
@ -118,7 +116,7 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r
} }
status := user_account.UserAccountStatus_Invited status := user_account.UserAccountStatus_Invited
_, err = user_account.Create(ctx, claims, dbConn, user_account.UserAccountCreateRequest{ _, err = repo.UserAccount.Create(ctx, claims, user_account.UserAccountCreateRequest{
UserID: userID, UserID: userID,
AccountID: req.AccountID, AccountID: req.AccountID,
Roles: req.Roles, Roles: req.Roles,
@ -133,12 +131,12 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r
req.TTL = time.Minute * 90 req.TTL = time.Minute * 90
} }
fromUser, err := user.ReadByID(ctx, claims, dbConn, req.UserID) fromUser, err := repo.User.ReadByID(ctx, claims, req.UserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
account, err := account.ReadByID(ctx, claims, dbConn, req.AccountID) account, err := repo.Account.ReadByID(ctx, claims, req.AccountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -151,7 +149,7 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r
var inviteHashes []string var inviteHashes []string
for email, userID := range emailUserIDs { for email, userID := range emailUserIDs {
hash, err := NewInviteHash(ctx, secretKey, userID, req.AccountID, requestIp, req.TTL, now) hash, err := NewInviteHash(ctx, repo.secretKey, userID, req.AccountID, requestIp, req.TTL, now)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -159,13 +157,13 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r
data := map[string]interface{}{ data := map[string]interface{}{
"FromUser": fromUser.Response(ctx), "FromUser": fromUser.Response(ctx),
"Account": account.Response(ctx), "Account": account.Response(ctx),
"Url": resetUrl(hash), "Url": repo.ResetUrl(hash),
"Minutes": req.TTL.Minutes(), "Minutes": req.TTL.Minutes(),
} }
subject := fmt.Sprintf("%s %s has invited you to %s", fromUser.FirstName, fromUser.LastName, account.Name) subject := fmt.Sprintf("%s %s has invited you to %s", fromUser.FirstName, fromUser.LastName, account.Name)
err = notify.Send(ctx, email, subject, "user_invite", data) err = repo.Notify.Send(ctx, email, subject, "user_invite", data)
if err != nil { if err != nil {
err = errors.WithMessagef(err, "Send invite to %s failed.", email) err = errors.WithMessagef(err, "Send invite to %s failed.", email)
return nil, err return nil, err
@ -178,7 +176,7 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r
} }
// AcceptInvite updates the user using the provided invite hash. // AcceptInvite updates the user using the provided invite hash.
func AcceptInvite(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteRequest, secretKey string, now time.Time) (*user_account.UserAccount, error) { func (repo *Repository) AcceptInvite(ctx context.Context, req AcceptInviteRequest, now time.Time) (*user_account.UserAccount, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.invite.AcceptInvite") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.invite.AcceptInvite")
defer span.Finish() defer span.Finish()
@ -190,25 +188,25 @@ func AcceptInvite(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteRequest,
return nil, err return nil, err
} }
hash, err := ParseInviteHash(ctx, req.InviteHash, secretKey, now) hash, err := ParseInviteHash(ctx, req.InviteHash, repo.secretKey, now)
if err != nil { if err != nil {
return nil, err return nil, err
} }
u, err := user.Read(ctx, auth.Claims{}, dbConn, u, err := repo.User.Read(ctx, auth.Claims{},
user.UserReadRequest{ID: hash.UserID, IncludeArchived: true}) user.UserReadRequest{ID: hash.UserID, IncludeArchived: true})
if err != nil { if err != nil {
return nil, err return nil, err
} }
if u.ArchivedAt != nil && !u.ArchivedAt.Time.IsZero() { if u.ArchivedAt != nil && !u.ArchivedAt.Time.IsZero() {
err = user.Restore(ctx, auth.Claims{}, dbConn, user.UserRestoreRequest{ID: hash.UserID}, now) err = repo.User.Restore(ctx, auth.Claims{}, user.UserRestoreRequest{ID: hash.UserID}, now)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
usrAcc, err := user_account.Read(ctx, auth.Claims{}, dbConn, user_account.UserAccountReadRequest{ usrAcc, err := repo.UserAccount.Read(ctx, auth.Claims{}, user_account.UserAccountReadRequest{
UserID: hash.UserID, UserID: hash.UserID,
AccountID: hash.AccountID, AccountID: hash.AccountID,
}) })
@ -230,7 +228,7 @@ func AcceptInvite(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteRequest,
if len(u.PasswordHash) > 0 { if len(u.PasswordHash) > 0 {
usrAcc.Status = user_account.UserAccountStatus_Active usrAcc.Status = user_account.UserAccountStatus_Active
err = user_account.Update(ctx, auth.Claims{}, dbConn, user_account.UserAccountUpdateRequest{ err = repo.UserAccount.Update(ctx, auth.Claims{}, user_account.UserAccountUpdateRequest{
UserID: usrAcc.UserID, UserID: usrAcc.UserID,
AccountID: usrAcc.AccountID, AccountID: usrAcc.AccountID,
Status: &usrAcc.Status, Status: &usrAcc.Status,
@ -244,7 +242,7 @@ func AcceptInvite(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteRequest,
} }
// AcceptInviteUser updates the user using the provided invite hash. // AcceptInviteUser updates the user using the provided invite hash.
func AcceptInviteUser(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteUserRequest, secretKey string, now time.Time) (*user_account.UserAccount, error) { func (repo *Repository) AcceptInviteUser(ctx context.Context, req AcceptInviteUserRequest, now time.Time) (*user_account.UserAccount, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.invite.AcceptInviteUser") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.invite.AcceptInviteUser")
defer span.Finish() defer span.Finish()
@ -256,25 +254,25 @@ func AcceptInviteUser(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteUser
return nil, err return nil, err
} }
hash, err := ParseInviteHash(ctx, req.InviteHash, secretKey, now) hash, err := ParseInviteHash(ctx, req.InviteHash, repo.secretKey, now)
if err != nil { if err != nil {
return nil, err return nil, err
} }
u, err := user.Read(ctx, auth.Claims{}, dbConn, u, err := repo.User.Read(ctx, auth.Claims{},
user.UserReadRequest{ID: hash.UserID, IncludeArchived: true}) user.UserReadRequest{ID: hash.UserID, IncludeArchived: true})
if err != nil { if err != nil {
return nil, err return nil, err
} }
if u.ArchivedAt != nil && !u.ArchivedAt.Time.IsZero() { if u.ArchivedAt != nil && !u.ArchivedAt.Time.IsZero() {
err = user.Restore(ctx, auth.Claims{}, dbConn, user.UserRestoreRequest{ID: hash.UserID}, now) err = repo.User.Restore(ctx, auth.Claims{}, user.UserRestoreRequest{ID: hash.UserID}, now)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
usrAcc, err := user_account.Read(ctx, auth.Claims{}, dbConn, user_account.UserAccountReadRequest{ usrAcc, err := repo.UserAccount.Read(ctx, auth.Claims{}, user_account.UserAccountReadRequest{
UserID: hash.UserID, UserID: hash.UserID,
AccountID: hash.AccountID, AccountID: hash.AccountID,
}) })
@ -293,7 +291,7 @@ func AcceptInviteUser(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteUser
// These three calls, user.Update, user.UpdatePassword, and user_account.Update // These three calls, user.Update, user.UpdatePassword, and user_account.Update
// should probably be in a transaction! // should probably be in a transaction!
err = user.Update(ctx, auth.Claims{}, dbConn, user.UserUpdateRequest{ err = repo.User.Update(ctx, auth.Claims{}, user.UserUpdateRequest{
ID: hash.UserID, ID: hash.UserID,
Email: &req.Email, Email: &req.Email,
FirstName: &req.FirstName, FirstName: &req.FirstName,
@ -304,7 +302,7 @@ func AcceptInviteUser(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteUser
return nil, err return nil, err
} }
err = user.UpdatePassword(ctx, auth.Claims{}, dbConn, user.UserUpdatePasswordRequest{ err = repo.User.UpdatePassword(ctx, auth.Claims{}, user.UserUpdatePasswordRequest{
ID: hash.UserID, ID: hash.UserID,
Password: req.Password, Password: req.Password,
PasswordConfirm: req.PasswordConfirm, PasswordConfirm: req.PasswordConfirm,
@ -314,7 +312,7 @@ func AcceptInviteUser(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteUser
} }
usrAcc.Status = user_account.UserAccountStatus_Active usrAcc.Status = user_account.UserAccountStatus_Active
err = user_account.Update(ctx, auth.Claims{}, dbConn, user_account.UserAccountUpdateRequest{ err = repo.UserAccount.Update(ctx, auth.Claims{}, user_account.UserAccountUpdateRequest{
UserID: usrAcc.UserID, UserID: usrAcc.UserID,
AccountID: usrAcc.AccountID, AccountID: usrAcc.AccountID,
Status: &usrAcc.Status, Status: &usrAcc.Status,

View File

@ -1,7 +1,6 @@
package invite package invite
import ( import (
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"os" "os"
"strings" "strings"
"testing" "testing"
@ -11,6 +10,7 @@ import (
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/notify" "geeks-accelerator/oss/saas-starter-kit/internal/platform/notify"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/tests" "geeks-accelerator/oss/saas-starter-kit/internal/platform/tests"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"geeks-accelerator/oss/saas-starter-kit/internal/user" "geeks-accelerator/oss/saas-starter-kit/internal/user"
"geeks-accelerator/oss/saas-starter-kit/internal/user_account" "geeks-accelerator/oss/saas-starter-kit/internal/user_account"
"github.com/dgrijalva/jwt-go" "github.com/dgrijalva/jwt-go"
@ -18,7 +18,10 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
var test *tests.Test var (
test *tests.Test
repo *Repository
)
// TestMain is the entry point for testing. // TestMain is the entry point for testing.
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
@ -28,6 +31,20 @@ func TestMain(m *testing.M) {
func testMain(m *testing.M) int { func testMain(m *testing.M) int {
test = tests.New() test = tests.New()
defer test.TearDown() defer test.TearDown()
userRepo := user.MockRepository(test.MasterDB)
userAccRepo := user_account.NewRepository(test.MasterDB)
accRepo := account.NewRepository(test.MasterDB)
// Mock the methods needed to make an invite.
resetUrl := func(string) string {
return ""
}
notify := &notify.MockEmail{}
secretKey := "6368616e676520746869732070613434"
repo = NewRepository(test.MasterDB, userRepo, userAccRepo, accRepo, resetUrl, notify, secretKey)
return m.Run() return m.Run()
} }
@ -42,7 +59,7 @@ func TestSendUserInvites(t *testing.T) {
// Create a new user for testing. // Create a new user for testing.
initPass := uuid.NewRandom().String() initPass := uuid.NewRandom().String()
u, err := user.Create(ctx, auth.Claims{}, test.MasterDB, user.UserCreateRequest{ u, err := repo.User.Create(ctx, auth.Claims{}, user.UserCreateRequest{
FirstName: "Lee", FirstName: "Lee",
LastName: "Brown", LastName: "Brown",
Email: uuid.NewRandom().String() + "@geeksinthewoods.com", Email: uuid.NewRandom().String() + "@geeksinthewoods.com",
@ -54,7 +71,7 @@ func TestSendUserInvites(t *testing.T) {
t.Fatalf("\t%s\tCreate user failed.", tests.Failed) t.Fatalf("\t%s\tCreate user failed.", tests.Failed)
} }
a, err := account.Create(ctx, auth.Claims{}, test.MasterDB, account.AccountCreateRequest{ a, err := repo.Account.Create(ctx, auth.Claims{}, account.AccountCreateRequest{
Name: uuid.NewRandom().String(), Name: uuid.NewRandom().String(),
Address1: "101 E Main", Address1: "101 E Main",
City: "Valdez", City: "Valdez",
@ -68,7 +85,7 @@ func TestSendUserInvites(t *testing.T) {
} }
uRoles := []user_account.UserAccountRole{user_account.UserAccountRole_Admin} uRoles := []user_account.UserAccountRole{user_account.UserAccountRole_Admin}
_, err = user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ _, err = repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{
UserID: u.ID, UserID: u.ID,
AccountID: a.ID, AccountID: a.ID,
Roles: uRoles, Roles: uRoles,
@ -91,21 +108,13 @@ func TestSendUserInvites(t *testing.T) {
claims.Roles = append(claims.Roles, r.String()) claims.Roles = append(claims.Roles, r.String())
} }
// Mock the methods needed to make a password reset.
resetUrl := func(string) string {
return ""
}
notify := &notify.MockEmail{}
secretKey := "6368616e676520746869732070617373"
// Ensure validation is working by trying ResetPassword with an empty request. // Ensure validation is working by trying ResetPassword with an empty request.
{ {
expectedErr := errors.New("Key: 'SendUserInvitesRequest.account_id' Error:Field validation for 'account_id' failed on the 'required' tag\n" + expectedErr := errors.New("Key: 'SendUserInvitesRequest.account_id' Error:Field validation for 'account_id' failed on the 'required' tag\n" +
"Key: 'SendUserInvitesRequest.user_id' Error:Field validation for 'user_id' failed on the 'required' tag\n" + "Key: 'SendUserInvitesRequest.user_id' Error:Field validation for 'user_id' failed on the 'required' tag\n" +
"Key: 'SendUserInvitesRequest.emails' Error:Field validation for 'emails' failed on the 'required' tag\n" + "Key: 'SendUserInvitesRequest.emails' Error:Field validation for 'emails' failed on the 'required' tag\n" +
"Key: 'SendUserInvitesRequest.roles' Error:Field validation for 'roles' failed on the 'required' tag") "Key: 'SendUserInvitesRequest.roles' Error:Field validation for 'roles' failed on the 'required' tag")
_, err = SendUserInvites(ctx, claims, test.MasterDB, resetUrl, notify, SendUserInvitesRequest{}, secretKey, now) _, err = repo.SendUserInvites(ctx, claims, SendUserInvitesRequest{}, now)
if err == nil { if err == nil {
t.Logf("\t\tWant: %+v", expectedErr) t.Logf("\t\tWant: %+v", expectedErr)
t.Fatalf("\t%s\tInviteUsers failed.", tests.Failed) t.Fatalf("\t%s\tInviteUsers failed.", tests.Failed)
@ -129,13 +138,13 @@ func TestSendUserInvites(t *testing.T) {
} }
// Make the reset password request. // Make the reset password request.
inviteHashes, err := SendUserInvites(ctx, claims, test.MasterDB, resetUrl, notify, SendUserInvitesRequest{ inviteHashes, err := repo.SendUserInvites(ctx, claims, SendUserInvitesRequest{
UserID: u.ID, UserID: u.ID,
AccountID: a.ID, AccountID: a.ID,
Emails: inviteEmails, Emails: inviteEmails,
Roles: []user_account.UserAccountRole{user_account.UserAccountRole_User}, Roles: []user_account.UserAccountRole{user_account.UserAccountRole_User},
TTL: ttl, TTL: ttl,
}, secretKey, now) }, now)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tInviteUsers failed.", tests.Failed) t.Fatalf("\t%s\tInviteUsers failed.", tests.Failed)
@ -154,7 +163,7 @@ func TestSendUserInvites(t *testing.T) {
"Key: 'AcceptInviteUserRequest.last_name' Error:Field validation for 'last_name' failed on the 'required' tag\n" + "Key: 'AcceptInviteUserRequest.last_name' Error:Field validation for 'last_name' failed on the 'required' tag\n" +
"Key: 'AcceptInviteUserRequest.password' Error:Field validation for 'password' failed on the 'required' tag\n" + "Key: 'AcceptInviteUserRequest.password' Error:Field validation for 'password' failed on the 'required' tag\n" +
"Key: 'AcceptInviteUserRequest.password_confirm' Error:Field validation for 'password_confirm' failed on the 'required' tag") "Key: 'AcceptInviteUserRequest.password_confirm' Error:Field validation for 'password_confirm' failed on the 'required' tag")
_, err = AcceptInviteUser(ctx, test.MasterDB, AcceptInviteUserRequest{}, secretKey, now) _, err = repo.AcceptInviteUser(ctx, AcceptInviteUserRequest{}, now)
if err == nil { if err == nil {
t.Logf("\t\tWant: %+v", expectedErr) t.Logf("\t\tWant: %+v", expectedErr)
t.Fatalf("\t%s\tResetConfirm failed.", tests.Failed) t.Fatalf("\t%s\tResetConfirm failed.", tests.Failed)
@ -174,14 +183,14 @@ func TestSendUserInvites(t *testing.T) {
// Ensure the TTL is enforced. // Ensure the TTL is enforced.
{ {
newPass := uuid.NewRandom().String() newPass := uuid.NewRandom().String()
_, err = AcceptInviteUser(ctx, test.MasterDB, AcceptInviteUserRequest{ _, err = repo.AcceptInviteUser(ctx, AcceptInviteUserRequest{
InviteHash: inviteHashes[0], InviteHash: inviteHashes[0],
Email: inviteEmails[0], Email: inviteEmails[0],
FirstName: "Foo", FirstName: "Foo",
LastName: "Bar", LastName: "Bar",
Password: newPass, Password: newPass,
PasswordConfirm: newPass, PasswordConfirm: newPass,
}, secretKey, now.UTC().Add(ttl*2)) }, now.UTC().Add(ttl*2))
if errors.Cause(err) != ErrInviteExpired { if errors.Cause(err) != ErrInviteExpired {
t.Logf("\t\tGot : %+v", errors.Cause(err)) t.Logf("\t\tGot : %+v", errors.Cause(err))
t.Logf("\t\tWant: %+v", ErrInviteExpired) t.Logf("\t\tWant: %+v", ErrInviteExpired)
@ -194,14 +203,14 @@ func TestSendUserInvites(t *testing.T) {
for idx, inviteHash := range inviteHashes { for idx, inviteHash := range inviteHashes {
newPass := uuid.NewRandom().String() newPass := uuid.NewRandom().String()
hash, err := AcceptInviteUser(ctx, test.MasterDB, AcceptInviteUserRequest{ hash, err := repo.AcceptInviteUser(ctx, AcceptInviteUserRequest{
InviteHash: inviteHash, InviteHash: inviteHash,
Email: inviteEmails[idx], Email: inviteEmails[idx],
FirstName: "Foo", FirstName: "Foo",
LastName: "Bar", LastName: "Bar",
Password: newPass, Password: newPass,
PasswordConfirm: newPass, PasswordConfirm: newPass,
}, secretKey, now) }, now)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tInviteAccept failed.", tests.Failed) t.Fatalf("\t%s\tInviteAccept failed.", tests.Failed)
@ -227,14 +236,14 @@ func TestSendUserInvites(t *testing.T) {
// Ensure the reset hash does not work after its used. // Ensure the reset hash does not work after its used.
{ {
newPass := uuid.NewRandom().String() newPass := uuid.NewRandom().String()
_, err = AcceptInviteUser(ctx, test.MasterDB, AcceptInviteUserRequest{ _, err = repo.AcceptInviteUser(ctx, AcceptInviteUserRequest{
InviteHash: inviteHashes[0], InviteHash: inviteHashes[0],
Email: inviteEmails[0], Email: inviteEmails[0],
FirstName: "Foo", FirstName: "Foo",
LastName: "Bar", LastName: "Bar",
Password: newPass, Password: newPass,
PasswordConfirm: newPass, PasswordConfirm: newPass,
}, secretKey, now) }, now)
if errors.Cause(err) != ErrUserAccountActive { if errors.Cause(err) != ErrUserAccountActive {
t.Logf("\t\tGot : %+v", errors.Cause(err)) t.Logf("\t\tGot : %+v", errors.Cause(err))
t.Logf("\t\tWant: %+v", ErrUserAccountActive) t.Logf("\t\tWant: %+v", ErrUserAccountActive)

View File

@ -6,12 +6,41 @@ import (
"strings" "strings"
"time" "time"
"geeks-accelerator/oss/saas-starter-kit/internal/account"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/notify"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"geeks-accelerator/oss/saas-starter-kit/internal/user"
"geeks-accelerator/oss/saas-starter-kit/internal/user_account" "geeks-accelerator/oss/saas-starter-kit/internal/user_account"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/sudo-suhas/symcrypto" "github.com/sudo-suhas/symcrypto"
) )
// Repository defines the required dependencies for User Invite.
type Repository struct {
DbConn *sqlx.DB
User *user.Repository
UserAccount *user_account.Repository
Account *account.Repository
ResetUrl func(string) string
Notify notify.Email
secretKey string
}
// NewRepository creates a new Repository that defines dependencies for User Invite.
func NewRepository(db *sqlx.DB, user *user.Repository, userAccount *user_account.Repository, account *account.Repository,
resetUrl func(string) string, notify notify.Email, secretKey string) *Repository {
return &Repository{
DbConn: db,
User: user,
UserAccount: userAccount,
Account: account,
ResetUrl: resetUrl,
Notify: notify,
secretKey: secretKey,
}
}
// SendUserInvitesRequest defines the data needed to make an invite request. // SendUserInvitesRequest defines the data needed to make an invite request.
type SendUserInvitesRequest struct { type SendUserInvitesRequest struct {
AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"` AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"`

View File

@ -2,13 +2,13 @@ package user_account
import ( import (
"context" "context"
"database/sql/driver"
"github.com/jmoiron/sqlx"
"strings" "strings"
"time" "time"
"database/sql/driver"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web"
"github.com/jmoiron/sqlx"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/pkg/errors" "github.com/pkg/errors"
"gopkg.in/go-playground/validator.v9" "gopkg.in/go-playground/validator.v9"

View File

@ -3,12 +3,12 @@ package user_account
import ( import (
"context" "context"
"database/sql" "database/sql"
"geeks-accelerator/oss/saas-starter-kit/internal/user"
"time" "time"
"geeks-accelerator/oss/saas-starter-kit/internal/account" "geeks-accelerator/oss/saas-starter-kit/internal/account"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"geeks-accelerator/oss/saas-starter-kit/internal/user"
"github.com/huandu/go-sqlbuilder" "github.com/huandu/go-sqlbuilder"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/pborman/uuid" "github.com/pborman/uuid"
@ -50,13 +50,13 @@ func mapRowsToUserAccount(rows *sql.Rows) (*UserAccount, error) {
// CanReadAccount determines if claims has the authority to access the specified user account by user ID. // CanReadAccount determines if claims has the authority to access the specified user account by user ID.
func (repo *Repository) CanReadAccount(ctx context.Context, claims auth.Claims, accountID string) error { func (repo *Repository) CanReadAccount(ctx context.Context, claims auth.Claims, accountID string) error {
err := account.CanReadAccount(ctx, claims, accountID) err := account.CanReadAccount(ctx, claims, repo.DbConn, accountID)
return mapAccountError(err) return mapAccountError(err)
} }
// CanModifyAccount determines if claims has the authority to modify the specified user ID. // CanModifyAccount determines if claims has the authority to modify the specified user ID.
func (repo *Repository) CanModifyAccount(ctx context.Context, claims auth.Claims, accountID string) error { func (repo *Repository) CanModifyAccount(ctx context.Context, claims auth.Claims, accountID string) error {
err := account.CanModifyAccount(ctx, claims, accountID) err := account.CanModifyAccount(ctx, claims, repo.DbConn, accountID)
return mapAccountError(err) return mapAccountError(err)
} }
@ -131,9 +131,9 @@ func findRequestQuery(req UserAccountFindRequest) (*sqlbuilder.SelectBuilder, []
} }
// Find gets all the user accounts from the database based on the request params. // Find gets all the user accounts from the database based on the request params.
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountFindRequest) (UserAccounts, error) { func (repo *Repository) Find(ctx context.Context, claims auth.Claims, req UserAccountFindRequest) (UserAccounts, error) {
query, args := findRequestQuery(req) query, args := findRequestQuery(req)
return find(ctx, claims, dbConn, query, args, req.IncludeArchived) return find(ctx, claims, repo.DbConn, query, args, req.IncludeArchived)
} }
// Find gets all the user accounts from the database based on the select query // Find gets all the user accounts from the database based on the select query
@ -180,7 +180,7 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu
} }
// Retrieve gets the specified user from the database. // Retrieve gets the specified user from the database.
func FindByUserID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string, includedArchived bool) (UserAccounts, error) { func (repo *Repository) FindByUserID(ctx context.Context, claims auth.Claims, userID string, includedArchived bool) (UserAccounts, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.FindByUserID") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.FindByUserID")
defer span.Finish() defer span.Finish()
@ -190,7 +190,7 @@ func FindByUserID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, user
query.OrderBy("created_at") query.OrderBy("created_at")
// Execute the find accounts method. // Execute the find accounts method.
res, err := find(ctx, claims, dbConn, query, []interface{}{}, includedArchived) res, err := find(ctx, claims, repo.DbConn, query, []interface{}{}, includedArchived)
if err != nil { if err != nil {
return nil, err return nil, err
} else if res == nil || len(res) == 0 { } else if res == nil || len(res) == 0 {
@ -202,7 +202,7 @@ func FindByUserID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, user
} }
// Create a user account for a given user with specified roles. // Create a user account for a given user with specified roles.
func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountCreateRequest, now time.Time) (*UserAccount, error) { func (repo *Repository) Create(ctx context.Context, claims auth.Claims, req UserAccountCreateRequest, now time.Time) (*UserAccount, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Create") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Create")
defer span.Finish() defer span.Finish()
@ -214,7 +214,7 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc
} }
// Ensure the claims can modify the account specified in the request. // Ensure the claims can modify the account specified in the request.
err = CanModifyAccount(ctx, claims, dbConn, req.AccountID) err = repo.CanModifyAccount(ctx, claims, req.AccountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -237,7 +237,7 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc
existQuery.Equal("account_id", req.AccountID), existQuery.Equal("account_id", req.AccountID),
existQuery.Equal("user_id", req.UserID), existQuery.Equal("user_id", req.UserID),
)) ))
existing, err := find(ctx, claims, dbConn, existQuery, []interface{}{}, true) existing, err := find(ctx, claims, repo.DbConn, existQuery, []interface{}{}, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -251,7 +251,7 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc
Roles: &req.Roles, Roles: &req.Roles,
unArchive: true, unArchive: true,
} }
err = Update(ctx, claims, dbConn, upReq, now) err = repo.Update(ctx, claims, upReq, now)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -285,8 +285,8 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc
// Execute the query with the provided context. // Execute the query with the provided context.
sql, args := query.Build() sql, args := query.Build()
sql = dbConn.Rebind(sql) sql = repo.DbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...) _, err = repo.DbConn.ExecContext(ctx, sql, args...)
if err != nil { if err != nil {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "add account %s to user %s failed", req.AccountID, req.UserID) err = errors.WithMessagef(err, "add account %s to user %s failed", req.AccountID, req.UserID)
@ -298,7 +298,7 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc
} }
// Read gets the specified user account from the database. // Read gets the specified user account from the database.
func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountReadRequest) (*UserAccount, error) { func (repo *Repository) Read(ctx context.Context, claims auth.Claims, req UserAccountReadRequest) (*UserAccount, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Read") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Read")
defer span.Finish() defer span.Finish()
@ -315,7 +315,7 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAcco
query.Equal("user_id", req.UserID), query.Equal("user_id", req.UserID),
query.Equal("account_id", req.AccountID))) query.Equal("account_id", req.AccountID)))
res, err := find(ctx, claims, dbConn, query, []interface{}{}, req.IncludeArchived) res, err := find(ctx, claims, repo.DbConn, query, []interface{}{}, req.IncludeArchived)
if err != nil { if err != nil {
return nil, err return nil, err
} else if res == nil || len(res) == 0 { } else if res == nil || len(res) == 0 {
@ -328,7 +328,7 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAcco
} }
// Update replaces a user account in the database. // Update replaces a user account in the database.
func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountUpdateRequest, now time.Time) error { func (repo *Repository) Update(ctx context.Context, claims auth.Claims, req UserAccountUpdateRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Update") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Update")
defer span.Finish() defer span.Finish()
@ -340,7 +340,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc
} }
// Ensure the claims can modify the user specified in the request. // Ensure the claims can modify the user specified in the request.
err = CanModifyAccount(ctx, claims, dbConn, req.AccountID) err = repo.CanModifyAccount(ctx, claims, req.AccountID)
if err != nil { if err != nil {
return err return err
} }
@ -389,8 +389,8 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc
// Execute the query with the provided context. // Execute the query with the provided context.
sql, args := query.Build() sql, args := query.Build()
sql = dbConn.Rebind(sql) sql = repo.DbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...) _, err = repo.DbConn.ExecContext(ctx, sql, args...)
if err != nil { if err != nil {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "update account %s for user %s failed", req.AccountID, req.UserID) err = errors.WithMessagef(err, "update account %s for user %s failed", req.AccountID, req.UserID)
@ -401,7 +401,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc
} }
// Archive soft deleted the user account from the database. // Archive soft deleted the user account from the database.
func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountArchiveRequest, now time.Time) error { func (repo *Repository) Archive(ctx context.Context, claims auth.Claims, req UserAccountArchiveRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Archive") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Archive")
defer span.Finish() defer span.Finish()
@ -413,7 +413,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserA
} }
// Ensure the claims can modify the user specified in the request. // Ensure the claims can modify the user specified in the request.
err = CanModifyAccount(ctx, claims, dbConn, req.AccountID) err = repo.CanModifyAccount(ctx, claims, req.AccountID)
if err != nil { if err != nil {
return err return err
} }
@ -441,8 +441,8 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserA
// Execute the query with the provided context. // Execute the query with the provided context.
sql, args := query.Build() sql, args := query.Build()
sql = dbConn.Rebind(sql) sql = repo.DbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...) _, err = repo.DbConn.ExecContext(ctx, sql, args...)
if err != nil { if err != nil {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "archive account %s from user %s failed", req.AccountID, req.UserID) err = errors.WithMessagef(err, "archive account %s from user %s failed", req.AccountID, req.UserID)
@ -453,7 +453,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserA
} }
// Delete removes a user account from the database. // Delete removes a user account from the database.
func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountDeleteRequest) error { func (repo *Repository) Delete(ctx context.Context, claims auth.Claims, req UserAccountDeleteRequest) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Delete") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Delete")
defer span.Finish() defer span.Finish()
@ -465,7 +465,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc
} }
// Ensure the claims can modify the user specified in the request. // Ensure the claims can modify the user specified in the request.
err = CanModifyAccount(ctx, claims, dbConn, req.AccountID) err = repo.CanModifyAccount(ctx, claims, req.AccountID)
if err != nil { if err != nil {
return err return err
} }
@ -480,8 +480,8 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc
// Execute the query with the provided context. // Execute the query with the provided context.
sql, args := query.Build() sql, args := query.Build()
sql = dbConn.Rebind(sql) sql = repo.DbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...) _, err = repo.DbConn.ExecContext(ctx, sql, args...)
if err != nil { if err != nil {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "delete account %s for user %s failed", req.AccountID, req.UserID) err = errors.WithMessagef(err, "delete account %s for user %s failed", req.AccountID, req.UserID)
@ -509,6 +509,10 @@ func MockUserAccount(ctx context.Context, dbConn *sqlx.DB, now time.Time, roles
return nil, err return nil, err
} }
repo := &Repository{
DbConn: dbConn,
}
status := UserAccountStatus_Active status := UserAccountStatus_Active
req := UserAccountCreateRequest{ req := UserAccountCreateRequest{
@ -517,7 +521,7 @@ func MockUserAccount(ctx context.Context, dbConn *sqlx.DB, now time.Time, roles
Status: &status, Status: &status,
Roles: roles, Roles: roles,
} }
ua, err := Create(ctx, auth.Claims{}, dbConn, req, now) ua, err := repo.Create(ctx, auth.Claims{}, req, now)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1,7 +1,6 @@
package user_account package user_account
import ( import (
"github.com/lib/pq"
"math/rand" "math/rand"
"os" "os"
"strings" "strings"
@ -13,6 +12,7 @@ import (
"github.com/dgrijalva/jwt-go" "github.com/dgrijalva/jwt-go"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/huandu/go-sqlbuilder" "github.com/huandu/go-sqlbuilder"
"github.com/lib/pq"
"github.com/pborman/uuid" "github.com/pborman/uuid"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@ -232,7 +232,7 @@ func TestCreateValidation(t *testing.T) {
t.Fatalf("\t%s\tMock account failed.", tests.Failed) t.Fatalf("\t%s\tMock account failed.", tests.Failed)
} }
res, err := Create(ctx, auth.Claims{}, test.MasterDB, tt.req, now) res, err := repo.Create(ctx, auth.Claims{}, tt.req, now)
if err != tt.error { if err != tt.error {
// TODO: need a better way to handle validation errors as they are // TODO: need a better way to handle validation errors as they are
// of type interface validator.ValidationErrorsTranslations // of type interface validator.ValidationErrorsTranslations
@ -300,7 +300,7 @@ func TestCreateExistingEntry(t *testing.T) {
AccountID: accountID, AccountID: accountID,
Roles: []UserAccountRole{UserAccountRole_User}, Roles: []UserAccountRole{UserAccountRole_User},
} }
ua1, err := Create(ctx, auth.Claims{}, test.MasterDB, req1, now) ua1, err := repo.Create(ctx, auth.Claims{}, req1, now)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed) t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
@ -313,7 +313,7 @@ func TestCreateExistingEntry(t *testing.T) {
AccountID: req1.AccountID, AccountID: req1.AccountID,
Roles: []UserAccountRole{UserAccountRole_Admin}, Roles: []UserAccountRole{UserAccountRole_Admin},
} }
ua2, err := Create(ctx, auth.Claims{}, test.MasterDB, req2, now) ua2, err := repo.Create(ctx, auth.Claims{}, req2, now)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed) t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
@ -322,7 +322,7 @@ func TestCreateExistingEntry(t *testing.T) {
} }
// Now archive the user account to test trying to create a new entry for an archived entry // Now archive the user account to test trying to create a new entry for an archived entry
err = Archive(tests.Context(), auth.Claims{}, test.MasterDB, UserAccountArchiveRequest{ err = repo.Archive(tests.Context(), auth.Claims{}, UserAccountArchiveRequest{
UserID: req1.UserID, UserID: req1.UserID,
AccountID: req1.AccountID, AccountID: req1.AccountID,
}, now) }, now)
@ -332,7 +332,7 @@ func TestCreateExistingEntry(t *testing.T) {
} }
// Find the archived user account // Find the archived user account
arcRes, err := Read(tests.Context(), auth.Claims{}, test.MasterDB, arcRes, err := repo.Read(tests.Context(), auth.Claims{},
UserAccountReadRequest{UserID: req1.UserID, AccountID: req1.AccountID, IncludeArchived: true}) UserAccountReadRequest{UserID: req1.UserID, AccountID: req1.AccountID, IncludeArchived: true})
if err != nil || arcRes == nil { if err != nil || arcRes == nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
@ -347,7 +347,7 @@ func TestCreateExistingEntry(t *testing.T) {
AccountID: req1.AccountID, AccountID: req1.AccountID,
Roles: []UserAccountRole{UserAccountRole_User}, Roles: []UserAccountRole{UserAccountRole_User},
} }
ua3, err := Create(ctx, auth.Claims{}, test.MasterDB, req3, now) ua3, err := repo.Create(ctx, auth.Claims{}, req3, now)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed) t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
@ -356,7 +356,7 @@ func TestCreateExistingEntry(t *testing.T) {
} }
// Ensure the user account has archived_at empty // Ensure the user account has archived_at empty
findRes, err := Read(tests.Context(), auth.Claims{}, test.MasterDB, findRes, err := repo.Read(tests.Context(), auth.Claims{},
UserAccountReadRequest{UserID: req1.UserID, AccountID: req1.AccountID}) UserAccountReadRequest{UserID: req1.UserID, AccountID: req1.AccountID})
if err != nil || arcRes == nil { if err != nil || arcRes == nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
@ -414,7 +414,7 @@ func TestUpdateValidation(t *testing.T) {
{ {
ctx := tests.Context() ctx := tests.Context()
err := Update(ctx, auth.Claims{}, test.MasterDB, tt.req, now) err := repo.Update(ctx, auth.Claims{}, tt.req, now)
if err != tt.error { if err != tt.error {
// TODO: need a better way to handle validation errors as they are // TODO: need a better way to handle validation errors as they are
// of type interface validator.ValidationErrorsTranslations // of type interface validator.ValidationErrorsTranslations
@ -564,7 +564,7 @@ func TestCrud(t *testing.T) {
AccountID: accountID, AccountID: accountID,
Roles: []UserAccountRole{UserAccountRole_User}, Roles: []UserAccountRole{UserAccountRole_User},
} }
ua, err := Create(tests.Context(), tt.claims(userID, accountID), test.MasterDB, createReq, now) ua, err := repo.Create(tests.Context(), tt.claims(userID, accountID), createReq, now)
if err != nil && errors.Cause(err) != tt.createErr { if err != nil && errors.Cause(err) != tt.createErr {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.createErr) t.Logf("\t\tWant: %+v", tt.createErr)
@ -577,7 +577,7 @@ func TestCrud(t *testing.T) {
} }
if tt.createErr == ErrForbidden { if tt.createErr == ErrForbidden {
ua, err = Create(tests.Context(), auth.Claims{}, test.MasterDB, createReq, now) ua, err = repo.Create(tests.Context(), auth.Claims{}, createReq, now)
if err != nil && errors.Cause(err) != tt.createErr { if err != nil && errors.Cause(err) != tt.createErr {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed) t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
@ -590,7 +590,7 @@ func TestCrud(t *testing.T) {
AccountID: accountID, AccountID: accountID,
Roles: &UserAccountRoles{UserAccountRole_Admin}, Roles: &UserAccountRoles{UserAccountRole_Admin},
} }
err = Update(tests.Context(), tt.claims(userID, accountID), test.MasterDB, updateReq, now) err = repo.Update(tests.Context(), tt.claims(userID, accountID), updateReq, now)
if err != nil { if err != nil {
if errors.Cause(err) != tt.updateErr { if errors.Cause(err) != tt.updateErr {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
@ -604,7 +604,7 @@ func TestCrud(t *testing.T) {
// Find the account for the user to verify the updates where made. There should only // Find the account for the user to verify the updates where made. There should only
// be one account associated with the user for this test. // be one account associated with the user for this test.
findRes, err := Find(tests.Context(), tt.claims(userID, accountID), test.MasterDB, UserAccountFindRequest{ findRes, err := repo.Find(tests.Context(), tt.claims(userID, accountID), UserAccountFindRequest{
Where: "user_id = ? or account_id = ?", Where: "user_id = ? or account_id = ?",
Args: []interface{}{userID, accountID}, Args: []interface{}{userID, accountID},
Order: []string{"created_at"}, Order: []string{"created_at"},
@ -632,7 +632,7 @@ func TestCrud(t *testing.T) {
} }
// Archive (soft-delete) the user account. // Archive (soft-delete) the user account.
err = Archive(tests.Context(), tt.claims(userID, accountID), test.MasterDB, UserAccountArchiveRequest{ err = repo.Archive(tests.Context(), tt.claims(userID, accountID), UserAccountArchiveRequest{
UserID: userID, UserID: userID,
AccountID: accountID, AccountID: accountID,
}, now) }, now)
@ -642,7 +642,7 @@ func TestCrud(t *testing.T) {
t.Fatalf("\t%s\tArchive user account failed.", tests.Failed) t.Fatalf("\t%s\tArchive user account failed.", tests.Failed)
} else if tt.updateErr == nil { } else if tt.updateErr == nil {
// Trying to find the archived user with the includeArchived false should result in not found. // Trying to find the archived user with the includeArchived false should result in not found.
_, err = FindByUserID(tests.Context(), tt.claims(userID, accountID), test.MasterDB, userID, false) _, err = repo.FindByUserID(tests.Context(), tt.claims(userID, accountID), userID, false)
if errors.Cause(err) != ErrNotFound { if errors.Cause(err) != ErrNotFound {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", ErrNotFound) t.Logf("\t\tWant: %+v", ErrNotFound)
@ -650,7 +650,7 @@ func TestCrud(t *testing.T) {
} }
// Trying to find the archived user with the includeArchived true should result no error. // Trying to find the archived user with the includeArchived true should result no error.
findRes, err = FindByUserID(tests.Context(), tt.claims(userID, accountID), test.MasterDB, userID, true) findRes, err = repo.FindByUserID(tests.Context(), tt.claims(userID, accountID), userID, true)
if err != nil { if err != nil {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Fatalf("\t%s\tVerify archive user account failed when including archived.", tests.Failed) t.Fatalf("\t%s\tVerify archive user account failed when including archived.", tests.Failed)
@ -675,7 +675,7 @@ func TestCrud(t *testing.T) {
t.Logf("\t%s\tArchive user account ok.", tests.Success) t.Logf("\t%s\tArchive user account ok.", tests.Success)
// Delete (hard-delete) the user account. // Delete (hard-delete) the user account.
err = Delete(tests.Context(), tt.claims(userID, accountID), test.MasterDB, UserAccountDeleteRequest{ err = repo.Delete(tests.Context(), tt.claims(userID, accountID), UserAccountDeleteRequest{
UserID: userID, UserID: userID,
AccountID: accountID, AccountID: accountID,
}) })
@ -685,7 +685,7 @@ func TestCrud(t *testing.T) {
t.Fatalf("\t%s\tDelete user account failed.", tests.Failed) t.Fatalf("\t%s\tDelete user account failed.", tests.Failed)
} else if tt.updateErr == nil { } else if tt.updateErr == nil {
// Trying to find the deleted user with the includeArchived true should result in not found. // Trying to find the deleted user with the includeArchived true should result in not found.
_, err = FindByUserID(tests.Context(), tt.claims(userID, accountID), test.MasterDB, userID, true) _, err = repo.FindByUserID(tests.Context(), tt.claims(userID, accountID), userID, true)
if errors.Cause(err) != ErrNotFound { if errors.Cause(err) != ErrNotFound {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", ErrNotFound) t.Logf("\t\tWant: %+v", ErrNotFound)
@ -725,7 +725,7 @@ func TestFind(t *testing.T) {
} }
// Execute Create that will associate the user with the account. // Execute Create that will associate the user with the account.
ua, err := Create(tests.Context(), auth.Claims{}, test.MasterDB, UserAccountCreateRequest{ ua, err := repo.Create(tests.Context(), auth.Claims{}, UserAccountCreateRequest{
UserID: userID, UserID: userID,
AccountID: accountID, AccountID: accountID,
Roles: []UserAccountRole{UserAccountRole_User}, Roles: []UserAccountRole{UserAccountRole_User},
@ -836,7 +836,7 @@ func TestFind(t *testing.T) {
{ {
ctx := tests.Context() ctx := tests.Context()
res, err := Find(ctx, auth.Claims{}, test.MasterDB, tt.req) res, err := repo.Find(ctx, auth.Claims{}, tt.req)
if errors.Cause(err) != tt.error { if errors.Cause(err) != tt.error {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.error) t.Logf("\t\tWant: %+v", tt.error)

View File

@ -3,7 +3,6 @@ package user_auth
import ( import (
"context" "context"
"database/sql" "database/sql"
"geeks-accelerator/oss/saas-starter-kit/internal/user_account"
"strings" "strings"
"time" "time"
@ -11,8 +10,8 @@ import (
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"geeks-accelerator/oss/saas-starter-kit/internal/user" "geeks-accelerator/oss/saas-starter-kit/internal/user"
"geeks-accelerator/oss/saas-starter-kit/internal/user_account"
"github.com/huandu/go-sqlbuilder" "github.com/huandu/go-sqlbuilder"
"github.com/jmoiron/sqlx"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/pkg/errors" "github.com/pkg/errors"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
@ -40,7 +39,7 @@ const (
// Authenticate finds a user by their email and verifies their password. On success // Authenticate finds a user by their email and verifies their password. On success
// it returns a Token that can be used to authenticate access to the application in // it returns a Token that can be used to authenticate access to the application in
// the future. // the future.
func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, req AuthenticateRequest, expires time.Duration, now time.Time, scopes ...string) (Token, error) { func (repo *Repository) Authenticate(ctx context.Context, req AuthenticateRequest, expires time.Duration, now time.Time, scopes ...string) (Token, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.Authenticate") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.Authenticate")
defer span.Finish() defer span.Finish()
@ -51,7 +50,7 @@ func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, r
return Token{}, err return Token{}, err
} }
u, err := user.ReadByEmail(ctx, auth.Claims{}, dbConn, req.Email, false) u, err := repo.User.ReadByEmail(ctx, auth.Claims{}, req.Email, false)
if err != nil { if err != nil {
if errors.Cause(err) == user.ErrNotFound { if errors.Cause(err) == user.ErrNotFound {
err = errors.WithStack(ErrAuthenticationFailure) err = errors.WithStack(ErrAuthenticationFailure)
@ -73,11 +72,11 @@ func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, r
} }
// The user is successfully authenticated with the supplied email and password. // The user is successfully authenticated with the supplied email and password.
return generateToken(ctx, dbConn, tknGen, auth.Claims{}, u.ID, req.AccountID, expires, now, scopes...) return repo.generateToken(ctx, auth.Claims{}, u.ID, req.AccountID, expires, now, scopes...)
} }
// SwitchAccount allows users to switch between multiple accounts, this changes the claim audience. // SwitchAccount allows users to switch between multiple accounts, this changes the claim audience.
func SwitchAccount(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, req SwitchAccountRequest, expires time.Duration, now time.Time, scopes ...string) (Token, error) { func (repo *Repository) SwitchAccount(ctx context.Context, claims auth.Claims, req SwitchAccountRequest, expires time.Duration, now time.Time, scopes ...string) (Token, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.SwitchAccount") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.SwitchAccount")
defer span.Finish() defer span.Finish()
@ -97,11 +96,11 @@ func SwitchAccount(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator,
// Generate a token for the user ID in supplied in claims as the Subject. Pass // Generate a token for the user ID in supplied in claims as the Subject. Pass
// in the supplied claims as well to enforce ACLs when finding the current // in the supplied claims as well to enforce ACLs when finding the current
// list of accounts for the user. // list of accounts for the user.
return generateToken(ctx, dbConn, tknGen, claims, claims.Subject, req.AccountID, expires, now, scopes...) return repo.generateToken(ctx, claims, claims.Subject, req.AccountID, expires, now, scopes...)
} }
// VirtualLogin allows users to mock being logged in as other users. // VirtualLogin allows users to mock being logged in as other users.
func VirtualLogin(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, req VirtualLoginRequest, expires time.Duration, now time.Time, scopes ...string) (Token, error) { func (repo *Repository) VirtualLogin(ctx context.Context, claims auth.Claims, req VirtualLoginRequest, expires time.Duration, now time.Time, scopes ...string) (Token, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.VirtualLogin") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.VirtualLogin")
defer span.Finish() defer span.Finish()
@ -113,7 +112,7 @@ func VirtualLogin(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, c
} }
// Find all the accounts that the current user has access to. // Find all the accounts that the current user has access to.
usrAccs, err := user_account.FindByUserID(ctx, claims, dbConn, claims.Subject, false) usrAccs, err := repo.UserAccount.FindByUserID(ctx, claims, claims.Subject, false)
if err != nil { if err != nil {
return Token{}, err return Token{}, err
} }
@ -142,23 +141,23 @@ func VirtualLogin(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, c
// Generate a token for the user ID in supplied in claims as the Subject. Pass // Generate a token for the user ID in supplied in claims as the Subject. Pass
// in the supplied claims as well to enforce ACLs when finding the current // in the supplied claims as well to enforce ACLs when finding the current
// list of accounts for the user. // list of accounts for the user.
return generateToken(ctx, dbConn, tknGen, claims, req.UserID, req.AccountID, expires, now, scopes...) return repo.generateToken(ctx, claims, req.UserID, req.AccountID, expires, now, scopes...)
} }
// VirtualLogout allows switch back to their root user/account. // VirtualLogout allows switch back to their root user/account.
func VirtualLogout(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, expires time.Duration, now time.Time, scopes ...string) (Token, error) { func (repo *Repository) VirtualLogout(ctx context.Context, claims auth.Claims, expires time.Duration, now time.Time, scopes ...string) (Token, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.VirtualLogout") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.VirtualLogout")
defer span.Finish() defer span.Finish()
// Generate a token for the user ID in supplied in claims as the Subject. Pass // Generate a token for the user ID in supplied in claims as the Subject. Pass
// in the supplied claims as well to enforce ACLs when finding the current // in the supplied claims as well to enforce ACLs when finding the current
// list of accounts for the user. // list of accounts for the user.
return generateToken(ctx, dbConn, tknGen, claims, claims.RootUserID, claims.RootAccountID, expires, now, scopes...) return repo.generateToken(ctx, claims, claims.RootUserID, claims.RootAccountID, expires, now, scopes...)
} }
// generateToken generates claims for the supplied user ID and account ID and then // generateToken generates claims for the supplied user ID and account ID and then
// returns the token for the generated claims used for authentication. // returns the token for the generated claims used for authentication.
func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, userID, accountID string, expires time.Duration, now time.Time, scopes ...string) (Token, error) { func (repo *Repository) generateToken(ctx context.Context, claims auth.Claims, userID, accountID string, expires time.Duration, now time.Time, scopes ...string) (Token, error) {
type userAccount struct { type userAccount struct {
AccountID string AccountID string
@ -184,8 +183,8 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator,
// fetch all places from the db // fetch all places from the db
queryStr, queryArgs := query.Build() queryStr, queryArgs := query.Build()
queryStr = dbConn.Rebind(queryStr) queryStr = repo.DbConn.Rebind(queryStr)
rows, err := dbConn.QueryContext(ctx, queryStr, queryArgs...) rows, err := repo.DbConn.QueryContext(ctx, queryStr, queryArgs...)
if err != nil { if err != nil {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
return nil, err return nil, err
@ -339,7 +338,7 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator,
tz, _ = time.LoadLocation(account.AccountTimezone.String) tz, _ = time.LoadLocation(account.AccountTimezone.String)
} }
prefs, err := account_preference.FindByAccountID(ctx, auth.Claims{}, dbConn, account_preference.AccountPreferenceFindByAccountIDRequest{ prefs, err := repo.AccountPreference.FindByAccountID(ctx, auth.Claims{}, account_preference.AccountPreferenceFindByAccountIDRequest{
AccountID: accountID, AccountID: accountID,
}) })
if err != nil { if err != nil {
@ -393,7 +392,7 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator,
newClaims.RootUserID = claims.RootUserID newClaims.RootUserID = claims.RootUserID
// Generate a token for the user with the defined claims. // Generate a token for the user with the defined claims.
tknStr, err := tknGen.GenerateToken(newClaims) tknStr, err := repo.TknGen.GenerateToken(newClaims)
if err != nil { if err != nil {
return Token{}, errors.Wrap(err, "generating token") return Token{}, errors.Wrap(err, "generating token")
} }

View File

@ -8,8 +8,8 @@ import (
"time" "time"
"geeks-accelerator/oss/saas-starter-kit/internal/account" "geeks-accelerator/oss/saas-starter-kit/internal/account"
"geeks-accelerator/oss/saas-starter-kit/internal/account/account_preference"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/notify"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/tests" "geeks-accelerator/oss/saas-starter-kit/internal/platform/tests"
"geeks-accelerator/oss/saas-starter-kit/internal/user" "geeks-accelerator/oss/saas-starter-kit/internal/user"
"geeks-accelerator/oss/saas-starter-kit/internal/user_account" "geeks-accelerator/oss/saas-starter-kit/internal/user_account"
@ -18,7 +18,10 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
var test *tests.Test var (
test *tests.Test
repo *Repository
)
// TestMain is the entry point for testing. // TestMain is the entry point for testing.
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
@ -28,6 +31,15 @@ func TestMain(m *testing.M) {
func testMain(m *testing.M) int { func testMain(m *testing.M) int {
test = tests.New() test = tests.New()
defer test.TearDown() defer test.TearDown()
tknGen := &auth.MockTokenGenerator{}
userRepo := user.MockRepository(test.MasterDB)
userAccRepo := user_account.NewRepository(test.MasterDB)
accPrefRepo := account_preference.NewRepository(test.MasterDB)
repo = NewRepository(test.MasterDB, tknGen, userRepo, userAccRepo, accPrefRepo)
return m.Run() return m.Run()
} }
@ -41,14 +53,12 @@ func TestAuthenticate(t *testing.T) {
{ {
ctx := tests.Context() ctx := tests.Context()
tknGen := &auth.MockTokenGenerator{}
// Auth tokens are valid for an our and is verified against current time. // Auth tokens are valid for an our and is verified against current time.
// Issue the token one hour ago. // Issue the token one hour ago.
now := time.Now().Add(time.Hour * -1) now := time.Now().Add(time.Hour * -1)
// Try to authenticate an invalid user. // Try to authenticate an invalid user.
_, err := Authenticate(ctx, test.MasterDB, tknGen, _, err := repo.Authenticate(ctx,
AuthenticateRequest{ AuthenticateRequest{
Email: "doesnotexist@gmail.com", Email: "doesnotexist@gmail.com",
Password: "xy7", Password: "xy7",
@ -82,7 +92,7 @@ func TestAuthenticate(t *testing.T) {
// is always greater than the first user_account entry created so it will // is always greater than the first user_account entry created so it will
// be returned consistently back in the same order, last. // be returned consistently back in the same order, last.
account2Role := auth.RoleUser account2Role := auth.RoleUser
_, err = user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ _, err = repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{
UserID: usrAcc.UserID, UserID: usrAcc.UserID,
AccountID: acc2.ID, AccountID: acc2.ID,
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(account2Role)}, Roles: []user_account.UserAccountRole{user_account.UserAccountRole(account2Role)},
@ -92,7 +102,7 @@ func TestAuthenticate(t *testing.T) {
now = now.Add(time.Minute * 5) now = now.Add(time.Minute * 5)
// Try to authenticate valid user with invalid password. // Try to authenticate valid user with invalid password.
_, err = Authenticate(ctx, test.MasterDB, tknGen, _, err = repo.Authenticate(ctx,
AuthenticateRequest{ AuthenticateRequest{
Email: usrAcc.User.Email, Email: usrAcc.User.Email,
Password: "xy7", Password: "xy7",
@ -106,7 +116,7 @@ func TestAuthenticate(t *testing.T) {
t.Logf("\t%s\tAuthenticate user w/invalid password ok.", tests.Success) t.Logf("\t%s\tAuthenticate user w/invalid password ok.", tests.Success)
// Verify that the user can be authenticated with the created user. // Verify that the user can be authenticated with the created user.
tkn1, err := Authenticate(ctx, test.MasterDB, tknGen, tkn1, err := repo.Authenticate(ctx,
AuthenticateRequest{ AuthenticateRequest{
Email: usrAcc.User.Email, Email: usrAcc.User.Email,
Password: usrAcc.User.Password, Password: usrAcc.User.Password,
@ -118,7 +128,7 @@ func TestAuthenticate(t *testing.T) {
t.Logf("\t%s\tAuthenticate user ok.", tests.Success) t.Logf("\t%s\tAuthenticate user ok.", tests.Success)
// Ensure the token string was correctly generated. // Ensure the token string was correctly generated.
claims1, err := tknGen.ParseClaims(tkn1.AccessToken) claims1, err := repo.TknGen.ParseClaims(tkn1.AccessToken)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
@ -135,7 +145,7 @@ func TestAuthenticate(t *testing.T) {
t.Logf("\t%s\tAuthenticate parse claims from token ok.", tests.Success) t.Logf("\t%s\tAuthenticate parse claims from token ok.", tests.Success)
// Try switching to a second account using the first set of claims. // Try switching to a second account using the first set of claims.
tkn2, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims1, tkn2, err := repo.SwitchAccount(ctx, claims1,
SwitchAccountRequest{AccountID: acc2.ID}, time.Hour, now) SwitchAccountRequest{AccountID: acc2.ID}, time.Hour, now)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
@ -144,7 +154,7 @@ func TestAuthenticate(t *testing.T) {
t.Logf("\t%s\tSwitchAccount user ok.", tests.Success) t.Logf("\t%s\tSwitchAccount user ok.", tests.Success)
// Ensure the token string was correctly generated. // Ensure the token string was correctly generated.
claims2, err := tknGen.ParseClaims(tkn2.AccessToken) claims2, err := repo.TknGen.ParseClaims(tkn2.AccessToken)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
@ -172,8 +182,6 @@ func TestUserUpdatePassword(t *testing.T) {
now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC) now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC)
tknGen := &auth.MockTokenGenerator{}
// Create a new user for testing. // Create a new user for testing.
usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_User) usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_User)
if err != nil { if err != nil {
@ -183,7 +191,7 @@ func TestUserUpdatePassword(t *testing.T) {
t.Logf("\t%s\tCreate user account ok.", tests.Success) t.Logf("\t%s\tCreate user account ok.", tests.Success)
// Verify that the user can be authenticated with the created user. // Verify that the user can be authenticated with the created user.
_, err = Authenticate(ctx, test.MasterDB, tknGen, _, err = repo.Authenticate(ctx,
AuthenticateRequest{ AuthenticateRequest{
Email: usrAcc.User.Email, Email: usrAcc.User.Email,
Password: usrAcc.User.Password, Password: usrAcc.User.Password,
@ -195,7 +203,7 @@ func TestUserUpdatePassword(t *testing.T) {
// Update the users password. // Update the users password.
newPass := uuid.NewRandom().String() newPass := uuid.NewRandom().String()
err = user.UpdatePassword(ctx, auth.Claims{}, test.MasterDB, user.UserUpdatePasswordRequest{ err = repo.User.UpdatePassword(ctx, auth.Claims{}, user.UserUpdatePasswordRequest{
ID: usrAcc.UserID, ID: usrAcc.UserID,
Password: newPass, Password: newPass,
PasswordConfirm: newPass, PasswordConfirm: newPass,
@ -207,7 +215,7 @@ func TestUserUpdatePassword(t *testing.T) {
t.Logf("\t%s\tUpdatePassword ok.", tests.Success) t.Logf("\t%s\tUpdatePassword ok.", tests.Success)
// Verify that the user can be authenticated with the updated password. // Verify that the user can be authenticated with the updated password.
_, err = Authenticate(ctx, test.MasterDB, tknGen, _, err = repo.Authenticate(ctx,
AuthenticateRequest{ AuthenticateRequest{
Email: usrAcc.User.Email, Email: usrAcc.User.Email,
Password: newPass, Password: newPass,
@ -229,8 +237,6 @@ func TestUserResetPassword(t *testing.T) {
now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC) now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC)
tknGen := &auth.MockTokenGenerator{}
// Create a new user for testing. // Create a new user for testing.
usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_User) usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_User)
if err != nil { if err != nil {
@ -239,21 +245,13 @@ func TestUserResetPassword(t *testing.T) {
} }
t.Logf("\t%s\tCreate user account ok.", tests.Success) t.Logf("\t%s\tCreate user account ok.", tests.Success)
// Mock the methods needed to make a password reset.
resetUrl := func(string) string {
return ""
}
notify := &notify.MockEmail{}
secretKey := "6368616e676520746869732070617373"
ttl := time.Hour ttl := time.Hour
// Make the reset password request. // Make the reset password request.
resetHash, err := user.ResetPassword(ctx, test.MasterDB, resetUrl, notify, user.UserResetPasswordRequest{ resetHash, err := repo.User.ResetPassword(ctx, user.UserResetPasswordRequest{
Email: usrAcc.User.Email, Email: usrAcc.User.Email,
TTL: ttl, TTL: ttl,
}, secretKey, now) }, now)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tResetPassword failed.", tests.Failed) t.Fatalf("\t%s\tResetPassword failed.", tests.Failed)
@ -262,11 +260,11 @@ func TestUserResetPassword(t *testing.T) {
// Assuming we have received the email and clicked the link, we now can ensure confirm works. // Assuming we have received the email and clicked the link, we now can ensure confirm works.
newPass := uuid.NewRandom().String() newPass := uuid.NewRandom().String()
reset, err := user.ResetConfirm(ctx, test.MasterDB, user.UserResetConfirmRequest{ reset, err := repo.User.ResetConfirm(ctx, user.UserResetConfirmRequest{
ResetHash: resetHash, ResetHash: resetHash,
Password: newPass, Password: newPass,
PasswordConfirm: newPass, PasswordConfirm: newPass,
}, secretKey, now) }, now)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tResetConfirm failed.", tests.Failed) t.Fatalf("\t%s\tResetConfirm failed.", tests.Failed)
@ -278,7 +276,7 @@ func TestUserResetPassword(t *testing.T) {
t.Logf("\t%s\tResetConfirm ok.", tests.Success) t.Logf("\t%s\tResetConfirm ok.", tests.Success)
// Verify that the user can be authenticated with the updated password. // Verify that the user can be authenticated with the updated password.
_, err = Authenticate(ctx, test.MasterDB, tknGen, _, err = repo.Authenticate(ctx,
AuthenticateRequest{ AuthenticateRequest{
Email: usrAcc.User.Email, Email: usrAcc.User.Email,
Password: newPass, Password: newPass,
@ -340,7 +338,7 @@ func TestSwitchAccount(t *testing.T) {
} }
// Associate the second account with root user. // Associate the second account with root user.
usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ usrAcc2, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{
UserID: usrAcc.UserID, UserID: usrAcc.UserID,
AccountID: acc2.ID, AccountID: acc2.ID,
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(roles[1])}, Roles: []user_account.UserAccountRole{user_account.UserAccountRole(roles[1])},
@ -359,7 +357,7 @@ func TestSwitchAccount(t *testing.T) {
} }
// Associate the third account with root user. // Associate the third account with root user.
usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ usrAcc3, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{
UserID: usrAcc.UserID, UserID: usrAcc.UserID,
AccountID: acc3.ID, AccountID: acc3.ID,
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(roles[2])}, Roles: []user_account.UserAccountRole{user_account.UserAccountRole(roles[2])},
@ -426,7 +424,7 @@ func TestSwitchAccount(t *testing.T) {
} }
// Associate the second account with root user. // Associate the second account with root user.
usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ usrAcc2, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{
UserID: usrAcc.UserID, UserID: usrAcc.UserID,
AccountID: acc2.ID, AccountID: acc2.ID,
Roles: []user_account.UserAccountRole{user_account.UserAccountRole_Admin}, Roles: []user_account.UserAccountRole{user_account.UserAccountRole_Admin},
@ -445,7 +443,7 @@ func TestSwitchAccount(t *testing.T) {
} }
// Associate the third account with root user. // Associate the third account with root user.
usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ usrAcc3, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{
UserID: usrAcc.UserID, UserID: usrAcc.UserID,
AccountID: acc3.ID, AccountID: acc3.ID,
Roles: []user_account.UserAccountRole{user_account.UserAccountRole_User}, Roles: []user_account.UserAccountRole{user_account.UserAccountRole_User},
@ -472,8 +470,6 @@ func TestSwitchAccount(t *testing.T) {
// Add 30 minutes to now to simulate time passing. // Add 30 minutes to now to simulate time passing.
now = now.Add(time.Minute * 5) now = now.Add(time.Minute * 5)
tknGen := &auth.MockTokenGenerator{}
t.Log("Given the need to switch accounts.") t.Log("Given the need to switch accounts.")
{ {
for i, authTest := range authTests { for i, authTest := range authTests {
@ -481,7 +477,7 @@ func TestSwitchAccount(t *testing.T) {
{ {
// Verify that the user can be authenticated with the created user. // Verify that the user can be authenticated with the created user.
var claims1 auth.Claims var claims1 auth.Claims
tkn1, err := Authenticate(ctx, test.MasterDB, tknGen, tkn1, err := repo.Authenticate(ctx,
AuthenticateRequest{ AuthenticateRequest{
Email: authTest.root.User.Email, Email: authTest.root.User.Email,
Password: authTest.root.User.Password, Password: authTest.root.User.Password,
@ -491,7 +487,7 @@ func TestSwitchAccount(t *testing.T) {
t.Fatalf("\t%s\tAuthenticate user failed.", tests.Failed) t.Fatalf("\t%s\tAuthenticate user failed.", tests.Failed)
} else { } else {
// Ensure the token string was correctly generated. // Ensure the token string was correctly generated.
claims1, err = tknGen.ParseClaims(tkn1.AccessToken) claims1, err = repo.TknGen.ParseClaims(tkn1.AccessToken)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
@ -511,7 +507,7 @@ func TestSwitchAccount(t *testing.T) {
// Try to switch to account 2. // Try to switch to account 2.
var claims2 auth.Claims var claims2 auth.Claims
tkn2, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims1, authTest.switch1Req, time.Hour, now, authTest.switch1Scopes...) tkn2, err := repo.SwitchAccount(ctx, claims1, authTest.switch1Req, time.Hour, now, authTest.switch1Scopes...)
if err != authTest.switch1Err { if err != authTest.switch1Err {
if errors.Cause(err) != authTest.switch1Err { if errors.Cause(err) != authTest.switch1Err {
t.Log("\t\tExpected :", authTest.switch1Err) t.Log("\t\tExpected :", authTest.switch1Err)
@ -520,7 +516,7 @@ func TestSwitchAccount(t *testing.T) {
} }
} else { } else {
// Ensure the token string was correctly generated. // Ensure the token string was correctly generated.
claims2, err = tknGen.ParseClaims(tkn2.AccessToken) claims2, err = repo.TknGen.ParseClaims(tkn2.AccessToken)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
@ -549,7 +545,7 @@ func TestSwitchAccount(t *testing.T) {
} }
// Try to switch to account 3. // Try to switch to account 3.
tkn3, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims2, authTest.switch2Req, time.Hour, now, authTest.switch2Scopes...) tkn3, err := repo.SwitchAccount(ctx, claims2, authTest.switch2Req, time.Hour, now, authTest.switch2Scopes...)
if err != authTest.switch2Err { if err != authTest.switch2Err {
if errors.Cause(err) != authTest.switch2Err { if errors.Cause(err) != authTest.switch2Err {
t.Log("\t\tExpected :", authTest.switch2Err) t.Log("\t\tExpected :", authTest.switch2Err)
@ -558,7 +554,7 @@ func TestSwitchAccount(t *testing.T) {
} }
} else { } else {
// Ensure the token string was correctly generated. // Ensure the token string was correctly generated.
claims3, err := tknGen.ParseClaims(tkn3.AccessToken) claims3, err := repo.TknGen.ParseClaims(tkn3.AccessToken)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
@ -610,7 +606,7 @@ func TestVirtualLogin(t *testing.T) {
var authTests []authTest var authTests []authTest
// Root admin -> role admin -> role admin // Root admin -> role admin -> role admin
if true { {
// Create a new user for testing. // Create a new user for testing.
usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_Admin) usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_Admin)
if err != nil { if err != nil {
@ -625,7 +621,7 @@ func TestVirtualLogin(t *testing.T) {
} }
// Associate second user with basic role associated with the same account. // Associate second user with basic role associated with the same account.
usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ usrAcc2, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{
UserID: usr2.ID, UserID: usr2.ID,
AccountID: usrAcc.AccountID, AccountID: usrAcc.AccountID,
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)}, Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)},
@ -642,7 +638,7 @@ func TestVirtualLogin(t *testing.T) {
} }
// Associate second user with basic role associated with the same account. // Associate second user with basic role associated with the same account.
usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ usrAcc3, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{
UserID: usr3.ID, UserID: usr3.ID,
AccountID: usrAcc.AccountID, AccountID: usrAcc.AccountID,
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)}, Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)},
@ -687,7 +683,7 @@ func TestVirtualLogin(t *testing.T) {
} }
// Associate second user with basic role associated with the same account. // Associate second user with basic role associated with the same account.
usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ usrAcc2, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{
UserID: usr2.ID, UserID: usr2.ID,
AccountID: usrAcc.AccountID, AccountID: usrAcc.AccountID,
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)}, Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)},
@ -704,7 +700,7 @@ func TestVirtualLogin(t *testing.T) {
} }
// Associate second user with basic role associated with the same account. // Associate second user with basic role associated with the same account.
usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ usrAcc3, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{
UserID: usr3.ID, UserID: usr3.ID,
AccountID: usrAcc.AccountID, AccountID: usrAcc.AccountID,
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_User)}, Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_User)},
@ -749,7 +745,7 @@ func TestVirtualLogin(t *testing.T) {
} }
// Associate second user with basic role associated with the same account. // Associate second user with basic role associated with the same account.
usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ usrAcc2, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{
UserID: usr2.ID, UserID: usr2.ID,
AccountID: usrAcc.AccountID, AccountID: usrAcc.AccountID,
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_User)}, Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_User)},
@ -766,7 +762,7 @@ func TestVirtualLogin(t *testing.T) {
} }
// Associate second user with basic role associated with the same account. // Associate second user with basic role associated with the same account.
usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ usrAcc3, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{
UserID: usr3.ID, UserID: usr3.ID,
AccountID: usrAcc.AccountID, AccountID: usrAcc.AccountID,
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)}, Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)},
@ -811,7 +807,7 @@ func TestVirtualLogin(t *testing.T) {
} }
// Associate second user with basic role associated with the same account. // Associate second user with basic role associated with the same account.
usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ usrAcc2, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{
UserID: usr2.ID, UserID: usr2.ID,
AccountID: usrAcc.AccountID, AccountID: usrAcc.AccountID,
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)}, Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)},
@ -850,7 +846,7 @@ func TestVirtualLogin(t *testing.T) {
} }
// Associate second user with basic role associated with the same account. // Associate second user with basic role associated with the same account.
usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ usrAcc2, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{
UserID: usr2.ID, UserID: usr2.ID,
AccountID: usrAcc.AccountID, AccountID: usrAcc.AccountID,
Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_User)}, Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_User)},
@ -876,8 +872,6 @@ func TestVirtualLogin(t *testing.T) {
// Add 30 minutes to now to simulate time passing. // Add 30 minutes to now to simulate time passing.
now = now.Add(time.Minute * 5) now = now.Add(time.Minute * 5)
tknGen := &auth.MockTokenGenerator{}
t.Log("Given the need to virtual login.") t.Log("Given the need to virtual login.")
{ {
for i, authTest := range authTests { for i, authTest := range authTests {
@ -885,7 +879,7 @@ func TestVirtualLogin(t *testing.T) {
{ {
// Verify that the user can be authenticated with the created user. // Verify that the user can be authenticated with the created user.
var claims1 auth.Claims var claims1 auth.Claims
tkn1, err := Authenticate(ctx, test.MasterDB, tknGen, tkn1, err := repo.Authenticate(ctx,
AuthenticateRequest{ AuthenticateRequest{
Email: authTest.root.User.Email, Email: authTest.root.User.Email,
Password: authTest.root.User.Password, Password: authTest.root.User.Password,
@ -895,7 +889,7 @@ func TestVirtualLogin(t *testing.T) {
t.Fatalf("\t%s\tAuthenticate user failed.", tests.Failed) t.Fatalf("\t%s\tAuthenticate user failed.", tests.Failed)
} else { } else {
// Ensure the token string was correctly generated. // Ensure the token string was correctly generated.
claims1, err = tknGen.ParseClaims(tkn1.AccessToken) claims1, err = repo.TknGen.ParseClaims(tkn1.AccessToken)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
@ -915,7 +909,7 @@ func TestVirtualLogin(t *testing.T) {
// Try virtual login to user 2. // Try virtual login to user 2.
var claims2 auth.Claims var claims2 auth.Claims
tkn2, err := VirtualLogin(ctx, test.MasterDB, tknGen, claims1, authTest.login1Req, time.Hour, now) tkn2, err := repo.VirtualLogin(ctx, claims1, authTest.login1Req, time.Hour, now)
if err != authTest.login1Err { if err != authTest.login1Err {
if errors.Cause(err) != authTest.login1Err { if errors.Cause(err) != authTest.login1Err {
t.Log("\t\tExpected :", authTest.login1Err) t.Log("\t\tExpected :", authTest.login1Err)
@ -924,7 +918,7 @@ func TestVirtualLogin(t *testing.T) {
} }
} else { } else {
// Ensure the token string was correctly generated. // Ensure the token string was correctly generated.
claims2, err = tknGen.ParseClaims(tkn2.AccessToken) claims2, err = repo.TknGen.ParseClaims(tkn2.AccessToken)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
@ -948,7 +942,7 @@ func TestVirtualLogin(t *testing.T) {
} }
// Try virtual login to user 3. // Try virtual login to user 3.
tkn3, err := VirtualLogin(ctx, test.MasterDB, tknGen, claims2, authTest.login2Req, time.Hour, now) tkn3, err := repo.VirtualLogin(ctx, claims2, authTest.login2Req, time.Hour, now)
if err != authTest.login2Err { if err != authTest.login2Err {
if errors.Cause(err) != authTest.login2Err { if errors.Cause(err) != authTest.login2Err {
t.Log("\t\tExpected :", authTest.login2Err) t.Log("\t\tExpected :", authTest.login2Err)
@ -957,7 +951,7 @@ func TestVirtualLogin(t *testing.T) {
} }
} else { } else {
// Ensure the token string was correctly generated. // Ensure the token string was correctly generated.
claims3, err := tknGen.ParseClaims(tkn3.AccessToken) claims3, err := repo.TknGen.ParseClaims(tkn3.AccessToken)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
@ -976,14 +970,14 @@ func TestVirtualLogin(t *testing.T) {
t.Logf("\t%s\tVirtualLogin user 2 with role %s ok.", tests.Success, authTest.login2Role) t.Logf("\t%s\tVirtualLogin user 2 with role %s ok.", tests.Success, authTest.login2Role)
if authTest.login2Logout { if authTest.login2Logout {
tknOut, err := VirtualLogout(ctx, test.MasterDB, tknGen, claims2, time.Hour, now) tknOut, err := repo.VirtualLogout(ctx, claims2, time.Hour, now)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tVirtualLogout user 2 failed.", tests.Failed) t.Fatalf("\t%s\tVirtualLogout user 2 failed.", tests.Failed)
} }
// Ensure the token string was correctly generated. // Ensure the token string was correctly generated.
claimsOut, err := tknGen.ParseClaims(tknOut.AccessToken) claimsOut, err := repo.TknGen.ParseClaims(tknOut.AccessToken)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)

View File

@ -3,9 +3,33 @@ package user_auth
import ( import (
"time" "time"
"geeks-accelerator/oss/saas-starter-kit/internal/account/account_preference"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/user"
"geeks-accelerator/oss/saas-starter-kit/internal/user_account"
"github.com/jmoiron/sqlx"
) )
// Repository defines the required dependencies for User Auth.
type Repository struct {
DbConn *sqlx.DB
TknGen TokenGenerator
User *user.Repository
UserAccount *user_account.Repository
AccountPreference *account_preference.Repository
}
// NewRepository creates a new Repository that defines dependencies for User Auth.
func NewRepository(db *sqlx.DB, tknGen TokenGenerator, user *user.Repository, usrAcc *user_account.Repository, accPref *account_preference.Repository) *Repository {
return &Repository{
DbConn: db,
TknGen: tknGen,
User: user,
UserAccount: usrAcc,
AccountPreference: accPref,
}
}
// AuthenticateRequest defines what information is required to authenticate a user. // AuthenticateRequest defines what information is required to authenticate a user.
type AuthenticateRequest struct { type AuthenticateRequest struct {
Email string `json:"email" validate:"required,email" example:"gabi.may@geeksinthewoods.com"` Email string `json:"email" validate:"required,email" example:"gabi.may@geeksinthewoods.com"`