From 3bc814a01ef6d85799224b7ecdc155972c788550 Mon Sep 17 00:00:00 2001 From: Lee Brown Date: Tue, 13 Aug 2019 23:41:06 -0800 Subject: [PATCH] WIP: not sure how to solve user_account calling account.CanModifyAccount --- cmd/web-api/handlers/routes.go | 44 +++++-- internal/user/models.go | 20 +++ internal/user/user.go | 137 +++++++++++---------- internal/user/user_test.go | 93 +++++++------- internal/user_account/models.go | 13 ++ internal/user_account/user.go | 7 +- internal/user_account/user_account.go | 8 +- internal/user_account/user_account_test.go | 8 +- 8 files changed, 198 insertions(+), 132 deletions(-) diff --git a/cmd/web-api/handlers/routes.go b/cmd/web-api/handlers/routes.go index c056074..b68460f 100644 --- a/cmd/web-api/handlers/routes.go +++ b/cmd/web-api/handlers/routes.go @@ -2,6 +2,7 @@ package handlers import ( "context" + "geeks-accelerator/oss/saas-starter-kit/internal/user" "log" "net/http" "os" @@ -20,33 +21,52 @@ import ( "gopkg.in/DataDog/dd-trace-go.v1/contrib/go-redis/redis" ) + +type AppContext struct { + Log *log.Logger + Env webcontext.Env + Repo *user.Repository + MasterDB *sqlx.DB + Redis *redis.Client + Authenticator *auth.Authenticator + PreAppMiddleware []web.Middleware + PostAppMiddleware []web.Middleware +} + + // API returns a handler for a set of routes. -func API(shutdown chan os.Signal, log *log.Logger, env webcontext.Env, masterDB *sqlx.DB, redis *redis.Client, authenticator *auth.Authenticator, globalMids ...web.Middleware) http.Handler { +func API(shutdown chan os.Signal, appContext *AppContext ) http.Handler { - // Define base middlewares applied to all requests. - middlewares := []web.Middleware{ - mid.Trace(), mid.Logger(log), mid.Errors(log, nil), mid.Metrics(), mid.Panics(), - } + // Include the pre middlewares first. + middlewares := appContext.PreAppMiddleware - // Append any global middlewares if they were included. - if len(globalMids) > 0 { - middlewares = append(middlewares, globalMids...) + // Define app middlewares applied to all requests. + middlewares = append(middlewares, + mid.Trace(), + mid.Logger(appContext.Log), + mid.Errors(appContext.Log, nil), + mid.Metrics(), + mid.Panics()) + + // Append any global middlewares that should be included after the app middlewares. + if len(appContext.PostAppMiddleware) > 0 { + middlewares = append(middlewares, appContext.PostAppMiddleware...) } // Construct the web.App which holds all routes as well as common Middleware. - app := web.NewApp(shutdown, log, env, middlewares...) + app := web.NewApp(shutdown, appContext.Log, appContext.Env, middlewares...) // Register health check endpoint. This route is not authenticated. check := Check{ - MasterDB: masterDB, - Redis: redis, + MasterDB: appContext.MasterDB, + Redis: appContext.Redis, } app.Handle("GET", "/v1/health", check.Health) app.Handle("GET", "/ping", check.Ping) // Register user management and authentication endpoints. u := User{ - MasterDB: masterDB, + MasterDB: appContext.MasterDB, TokenGenerator: authenticator, } app.Handle("GET", "/v1/users", u.Find, mid.AuthenticateHeader(authenticator)) diff --git a/internal/user/models.go b/internal/user/models.go index 2aca963..b4c65c5 100644 --- a/internal/user/models.go +++ b/internal/user/models.go @@ -4,8 +4,10 @@ import ( "context" "database/sql" "encoding/json" + "geeks-accelerator/oss/saas-starter-kit/internal/platform/notify" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" + "github.com/jmoiron/sqlx" "github.com/pkg/errors" "github.com/sudo-suhas/symcrypto" "strconv" @@ -15,6 +17,24 @@ import ( "github.com/lib/pq" ) +// Repository defines the required dependencies for User. +type Repository struct { + DbConn *sqlx.DB + ResetUrl func(string) string + Notify notify.Email + SecretKey string +} + +// NewRepository creates a new Repository that defines dependencies for User. +func NewRepository(db *sqlx.DB, resetUrl func(string) string, notify notify.Email, secretKey string) *Repository { + return &Repository{ + DbConn: db, + ResetUrl: resetUrl, + Notify: notify, + SecretKey: secretKey, + } +} + // User represents someone with access to our system. type User struct { ID string `json:"id" validate:"required,uuid" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"` diff --git a/internal/user/user.go b/internal/user/user.go index 6354c19..81d036e 100644 --- a/internal/user/user.go +++ b/internal/user/user.go @@ -6,7 +6,6 @@ import ( "time" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" - "geeks-accelerator/oss/saas-starter-kit/internal/platform/notify" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "github.com/huandu/go-sqlbuilder" "github.com/jmoiron/sqlx" @@ -55,7 +54,7 @@ func mapRowsToUser(rows *sql.Rows) (*User, error) { } // CanReadUser determines if claims has the authority to access the specified user ID. -func CanReadUser(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string) error { +func (repo *Repository) CanReadUser(ctx context.Context, claims auth.Claims, userID string) error { // If the request has claims from a specific user, ensure that the user // has the correct access to the user. if claims.Subject != "" && claims.Subject != userID { @@ -68,10 +67,10 @@ func CanReadUser(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userI query.Equal("user_id", userID), )) queryStr, args := query.Build() - queryStr = dbConn.Rebind(queryStr) + queryStr = repo.DbConn.Rebind(queryStr) var userAccountId string - err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&userAccountId) + err := repo.DbConn.QueryRowContext(ctx, queryStr, args...).Scan(&userAccountId) if err != nil && err != sql.ErrNoRows { err = errors.Wrapf(err, "query - %s", query.String()) return err @@ -88,7 +87,7 @@ func CanReadUser(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userI } // CanModifyUser determines if claims has the authority to modify the specified user ID. -func CanModifyUser(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string) error { +func (repo *Repository) CanModifyUser(ctx context.Context, claims auth.Claims, userID string) error { // If the request has claims from a specific user, ensure that the user // has the correct role for creating a new user. if claims.Subject != "" && claims.Subject != userID { @@ -99,7 +98,7 @@ func CanModifyUser(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, use } } - if err := CanReadUser(ctx, claims, dbConn, userID); err != nil { + if err := repo.CanReadUser(ctx, claims, userID); err != nil { return err } @@ -118,10 +117,10 @@ func CanModifyUser(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, use "'"+auth.RoleAdmin+"' = ANY (roles)", )) queryStr, args := query.Build() - queryStr = dbConn.Rebind(queryStr) + queryStr = repo.DbConn.Rebind(queryStr) var userAccountId string - err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&userAccountId) + err := repo.DbConn.QueryRowContext(ctx, queryStr, args...).Scan(&userAccountId) if err != nil && err != sql.ErrNoRows { err = errors.Wrapf(err, "query - %s", query.String()) return err @@ -199,13 +198,13 @@ func findRequestQuery(req UserFindRequest) (*sqlbuilder.SelectBuilder, []interfa } // Find gets all the users from the database based on the request params. -func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindRequest) (Users, error) { +func (repo *Repository) Find(ctx context.Context, claims auth.Claims, req UserFindRequest) (Users, error) { query, args := findRequestQuery(req) - return find(ctx, claims, dbConn, query, args, req.IncludeArchived) + return repo.find(ctx, claims, query, args, req.IncludeArchived) } // find internal method for getting all the users from the database using a select query. -func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (Users, error) { +func (repo *Repository) find(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (Users, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Find") defer span.Finish() @@ -222,11 +221,11 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu return nil, err } queryStr, queryArgs := query.Build() - queryStr = dbConn.Rebind(queryStr) + queryStr = repo.DbConn.Rebind(queryStr) args = append(args, queryArgs...) // fetch all places from the db - rows, err := dbConn.QueryContext(ctx, queryStr, args...) + rows, err := repo.DbConn.QueryContext(ctx, queryStr, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessage(err, "find users failed") @@ -248,17 +247,17 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu } // Validation an email address is unique excluding the current user ID. -func UniqueEmail(ctx context.Context, dbConn *sqlx.DB, email, userId string) (bool, error) { +func (repo *Repository) UniqueEmail(ctx context.Context, email, userId string) (bool, error) { query := sqlbuilder.NewSelectBuilder().Select("id").From(userTableName) query.Where(query.And( query.Equal("email", email), query.NotEqual("id", userId), )) queryStr, args := query.Build() - queryStr = dbConn.Rebind(queryStr) + queryStr = repo.DbConn.Rebind(queryStr) var existingId string - err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&existingId) + err := repo.DbConn.QueryRowContext(ctx, queryStr, args...).Scan(&existingId) if err != nil && err != sql.ErrNoRows { err = errors.Wrapf(err, "query - %s", query.String()) return false, err @@ -273,7 +272,7 @@ func UniqueEmail(ctx context.Context, dbConn *sqlx.DB, email, userId string) (bo } // Create inserts a new user into the database. -func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserCreateRequest, now time.Time) (*User, error) { +func (repo *Repository) Create(ctx context.Context, claims auth.Claims, req UserCreateRequest, now time.Time) (*User, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Create") defer span.Finish() @@ -284,7 +283,7 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserCr v := webcontext.Validator() // Validation email address is unique in the database. - uniq, err := UniqueEmail(ctx, dbConn, req.Email, "") + uniq, err := repo.UniqueEmail(ctx, req.Email, "") if err != nil { return nil, err } @@ -346,8 +345,8 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserCr // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessage(err, "create user failed") @@ -358,14 +357,14 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserCr } // Create invite inserts a new user into the database. -func CreateInvite(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserCreateInviteRequest, now time.Time) (*User, error) { +func (repo *Repository) CreateInvite(ctx context.Context, claims auth.Claims, req UserCreateInviteRequest, now time.Time) (*User, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.CreateInvite") defer span.Finish() v := webcontext.Validator() // Validation email address is unique in the database. - uniq, err := UniqueEmail(ctx, dbConn, req.Email, "") + uniq, err := repo.UniqueEmail(ctx, req.Email, "") if err != nil { return nil, err } @@ -414,8 +413,8 @@ func CreateInvite(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessage(err, "create user failed") @@ -426,15 +425,15 @@ func CreateInvite(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req } // ReadByID gets the specified user by ID from the database. -func ReadByID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string) (*User, error) { - return Read(ctx, claims, dbConn, UserReadRequest{ +func (repo *Repository) ReadByID(ctx context.Context, claims auth.Claims, id string) (*User, error) { + return repo.Read(ctx, claims, UserReadRequest{ ID: id, IncludeArchived: false, }) } // Read gets the specified user from the database. -func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserReadRequest) (*User, error) { +func (repo *Repository) Read(ctx context.Context, claims auth.Claims, req UserReadRequest) (*User, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Read") defer span.Finish() @@ -449,7 +448,7 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserRead query := selectQuery() query.Where(query.Equal("id", req.ID)) - res, err := find(ctx, claims, dbConn, query, []interface{}{}, req.IncludeArchived) + res, err := repo.find(ctx, claims, query, []interface{}{}, req.IncludeArchived) if err != nil { return nil, err } else if res == nil || len(res) == 0 { @@ -462,7 +461,7 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserRead } // ReadByEmail gets the specified user from the database. -func ReadByEmail(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, email string, includedArchived bool) (*User, error) { +func (repo *Repository) ReadByEmail(ctx context.Context, claims auth.Claims, email string, includedArchived bool) (*User, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.ReadByEmail") defer span.Finish() @@ -470,7 +469,7 @@ func ReadByEmail(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, email query := selectQuery() query.Where(query.Equal("email", email)) - res, err := find(ctx, claims, dbConn, query, []interface{}{}, includedArchived) + res, err := repo.find(ctx, claims, query, []interface{}{}, includedArchived) if err != nil { return nil, err } else if res == nil || len(res) == 0 { @@ -483,14 +482,14 @@ func ReadByEmail(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, email } // Update replaces a user in the database. -func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserUpdateRequest, now time.Time) error { +func (repo *Repository) Update(ctx context.Context, claims auth.Claims, req UserUpdateRequest, now time.Time) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Update") defer span.Finish() // Validation email address is unique in the database. if req.Email != nil { // Validation email address is unique in the database. - uniq, err := UniqueEmail(ctx, dbConn, *req.Email, req.ID) + uniq, err := repo.UniqueEmail(ctx, *req.Email, req.ID) if err != nil { return err } @@ -507,7 +506,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserUp } // Ensure the claims can modify the user specified in the request. - err = CanModifyUser(ctx, claims, dbConn, req.ID) + err = repo.CanModifyUser(ctx, claims, req.ID) if err != nil { return err } @@ -555,8 +554,8 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserUp // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "update user %s failed", req.ID) @@ -567,7 +566,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserUp } // Update changes the password for a user in the database. -func UpdatePassword(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserUpdatePasswordRequest, now time.Time) error { +func (repo *Repository) UpdatePassword(ctx context.Context, claims auth.Claims, req UserUpdatePasswordRequest, now time.Time) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.UpdatePassword") defer span.Finish() @@ -579,7 +578,7 @@ func UpdatePassword(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, re } // Ensure the claims can modify the user specified in the request. - err = CanModifyUser(ctx, claims, dbConn, req.ID) + err = repo.CanModifyUser(ctx, claims, req.ID) if err != nil { return err } @@ -616,8 +615,8 @@ func UpdatePassword(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, re // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "update password for user %s failed", req.ID) @@ -628,7 +627,7 @@ func UpdatePassword(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, re } // Archive soft deleted the user from the database. -func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserArchiveRequest, now time.Time) error { +func (repo *Repository) Archive(ctx context.Context, claims auth.Claims, req UserArchiveRequest, now time.Time) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Archive") defer span.Finish() @@ -640,7 +639,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserA } // Ensure the claims can modify the user specified in the request. - err = CanModifyUser(ctx, claims, dbConn, req.ID) + err = repo.CanModifyUser(ctx, claims, req.ID) if err != nil { return err } else if claims.Subject != "" && claims.Subject == req.ID && !req.force { @@ -669,8 +668,8 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserA // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "archive user %s failed", req.ID) @@ -689,8 +688,8 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserA // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "archive accounts for user %s failed", req.ID) @@ -702,7 +701,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserA } // Restore undeletes the user from the database. -func Restore(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserRestoreRequest, now time.Time) error { +func (repo *Repository) Restore(ctx context.Context, claims auth.Claims, req UserRestoreRequest, now time.Time) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Restore") defer span.Finish() @@ -714,7 +713,7 @@ func Restore(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserR } // Ensure the claims can modify the user specified in the request. - err = CanModifyUser(ctx, claims, dbConn, req.ID) + err = repo.CanModifyUser(ctx, claims, req.ID) if err != nil { return err } @@ -741,8 +740,8 @@ func Restore(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserR // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "unarchive user %s failed", req.ID) @@ -753,7 +752,7 @@ func Restore(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserR } // Delete removes a user from the database. -func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserDeleteRequest) error { +func (repo *Repository) Delete(ctx context.Context, claims auth.Claims, req UserDeleteRequest) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Delete") defer span.Finish() @@ -765,7 +764,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserDe } // Ensure the claims can modify the user specified in the request. - err = CanModifyUser(ctx, claims, dbConn, req.ID) + err = repo.CanModifyUser(ctx, claims, req.ID) if err != nil { return err } else if claims.Subject != "" && claims.Subject == req.ID && !req.force { @@ -773,7 +772,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserDe } // Start a new transaction to handle rollbacks on error. - tx, err := dbConn.Begin() + tx, err := repo.DbConn.Begin() if err != nil { return errors.WithStack(err) } @@ -790,7 +789,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserDe // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) + sql = repo.DbConn.Rebind(sql) _, err = tx.ExecContext(ctx, sql, args...) if err != nil { tx.Rollback() @@ -808,7 +807,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserDe // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) + sql = repo.DbConn.Rebind(sql) _, err = tx.ExecContext(ctx, sql, args...) if err != nil { tx.Rollback() @@ -827,7 +826,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserDe } // ResetPassword sends en email to the user to allow them to reset their password. -func ResetPassword(ctx context.Context, dbConn *sqlx.DB, resetUrl func(string) string, notify notify.Email, req UserResetPasswordRequest, secretKey string, now time.Time) (string, error) { +func (repo *Repository) ResetPassword(ctx context.Context, req UserResetPasswordRequest, now time.Time) (string, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.ResetPassword") defer span.Finish() @@ -845,7 +844,7 @@ func ResetPassword(ctx context.Context, dbConn *sqlx.DB, resetUrl func(string) s query := selectQuery() query.Where(query.Equal("email", req.Email)) - res, err := find(ctx, auth.Claims{}, dbConn, query, []interface{}{}, false) + res, err := repo.find(ctx, auth.Claims{}, query, []interface{}{}, false) if err != nil { return "", err } else if res == nil || len(res) == 0 { @@ -876,8 +875,8 @@ func ResetPassword(ctx context.Context, dbConn *sqlx.DB, resetUrl func(string) s // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "Update user %s failed.", u.ID) @@ -895,18 +894,18 @@ func ResetPassword(ctx context.Context, dbConn *sqlx.DB, resetUrl func(string) s requestIp = vals.RequestIP } - encrypted, err := NewResetHash(ctx, secretKey, resetId, requestIp, req.TTL, now) + encrypted, err := NewResetHash(ctx, repo.SecretKey, resetId, requestIp, req.TTL, now) if err != nil { return "", err } data := map[string]interface{}{ "Name": u.FirstName, - "Url": resetUrl(encrypted), + "Url": repo.ResetUrl(encrypted), "Minutes": req.TTL.Minutes(), } - err = notify.Send(ctx, u.Email, "Reset your Password", "user_reset_password", data) + err = repo.Notify.Send(ctx, u.Email, "Reset your Password", "user_reset_password", data) if err != nil { err = errors.WithMessagef(err, "Send password reset email to %s failed.", u.Email) return "", err @@ -916,7 +915,7 @@ func ResetPassword(ctx context.Context, dbConn *sqlx.DB, resetUrl func(string) s } // ResetConfirm updates the password for a user using the provided reset password ID. -func ResetConfirm(ctx context.Context, dbConn *sqlx.DB, req UserResetConfirmRequest, secretKey string, now time.Time) (*User, error) { +func (repo *Repository) ResetConfirm(ctx context.Context, req UserResetConfirmRequest, now time.Time) (*User, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.ResetConfirm") defer span.Finish() @@ -928,7 +927,7 @@ func ResetConfirm(ctx context.Context, dbConn *sqlx.DB, req UserResetConfirmRequ return nil, err } - hash, err := ParseResetHash(ctx, secretKey, req.ResetHash, now) + hash, err := ParseResetHash(ctx, repo.SecretKey, req.ResetHash, now) if err != nil { return nil, err } @@ -939,7 +938,7 @@ func ResetConfirm(ctx context.Context, dbConn *sqlx.DB, req UserResetConfirmRequ query := selectQuery() query.Where(query.Equal("password_reset", hash.ResetID)) - res, err := find(ctx, auth.Claims{}, dbConn, query, []interface{}{}, false) + res, err := repo.find(ctx, auth.Claims{}, query, []interface{}{}, false) if err != nil { return nil, err } else if res == nil || len(res) == 0 { @@ -979,8 +978,8 @@ func ResetConfirm(ctx context.Context, dbConn *sqlx.DB, req UserResetConfirmRequ // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "update password for user %s failed", u.ID) @@ -1000,6 +999,10 @@ type MockUserResponse struct { func MockUser(ctx context.Context, dbConn *sqlx.DB, now time.Time) (*MockUserResponse, error) { pass := uuid.NewRandom().String() + repo := &Repository{ + DbConn: dbConn, + } + req := UserCreateRequest{ FirstName: "Lee", LastName: "Brown", @@ -1007,7 +1010,7 @@ func MockUser(ctx context.Context, dbConn *sqlx.DB, now time.Time) (*MockUserRes Password: pass, PasswordConfirm: pass, } - u, err := Create(ctx, auth.Claims{}, dbConn, req, now) + u, err := repo.Create(ctx, auth.Claims{}, req, now) if err != nil { return nil, err } diff --git a/internal/user/user_test.go b/internal/user/user_test.go index 2dc15fd..edfaf49 100644 --- a/internal/user/user_test.go +++ b/internal/user/user_test.go @@ -18,7 +18,10 @@ import ( "github.com/pkg/errors" ) -var test *tests.Test +var ( + test *tests.Test + repo *Repository +) // TestMain is the entry point for testing. func TestMain(m *testing.M) { @@ -28,6 +31,16 @@ func TestMain(m *testing.M) { func testMain(m *testing.M) int { test = tests.New() defer test.TearDown() + + // Mock the methods needed to make a password reset. + resetUrl := func(string) string { + return "" + } + notify := ¬ify.MockEmail{} + secretKey := "6368616e676520746869732070617373" + + repo = NewRepository(test.MasterDB, resetUrl, notify, secretKey) + return m.Run() } @@ -219,7 +232,7 @@ func TestCreateValidation(t *testing.T) { { ctx := tests.Context() - res, err := Create(ctx, auth.Claims{}, test.MasterDB, tt.req, now) + res, err := repo.Create(ctx, auth.Claims{}, tt.req, now) if err != tt.error { // TODO: need a better way to handle validation errors as they are // of type interface validator.ValidationErrorsTranslations @@ -272,7 +285,7 @@ func TestCreateValidationEmailUnique(t *testing.T) { Password: "akTechFr0n!ier", PasswordConfirm: "akTechFr0n!ier", } - user1, err := Create(ctx, auth.Claims{}, test.MasterDB, req1, now) + user1, err := repo.Create(ctx, auth.Claims{}, req1, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tCreate failed.", tests.Failed) @@ -286,7 +299,7 @@ func TestCreateValidationEmailUnique(t *testing.T) { PasswordConfirm: "W0rkL1fe#", } expectedErr := errors.New("Key: 'UserCreateRequest.email' Error:Field validation for 'email' failed on the 'unique' tag") - _, err = Create(ctx, auth.Claims{}, test.MasterDB, req2, now) + _, err = repo.Create(ctx, auth.Claims{}, req2, now) if err == nil { t.Logf("\t\tWant: %+v", expectedErr) t.Fatalf("\t%s\tCreate failed.", tests.Failed) @@ -374,7 +387,7 @@ func TestCreateClaims(t *testing.T) { { ctx := tests.Context() - _, err := Create(ctx, tt.claims, test.MasterDB, tt.req, now) + _, err := repo.Create(ctx, tt.claims, tt.req, now) if errors.Cause(err) != tt.error { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.error) @@ -421,7 +434,7 @@ func TestUpdateValidation(t *testing.T) { { ctx := tests.Context() - err := Update(ctx, auth.Claims{}, test.MasterDB, tt.req, now) + err := repo.Update(ctx, auth.Claims{}, tt.req, now) if err != tt.error { // TODO: need a better way to handle validation errors as they are // of type interface validator.ValidationErrorsTranslations @@ -463,7 +476,7 @@ func TestUpdateValidationEmailUnique(t *testing.T) { Password: "akTechFr0n!ier", PasswordConfirm: "akTechFr0n!ier", } - user1, err := Create(ctx, auth.Claims{}, test.MasterDB, req1, now) + user1, err := repo.Create(ctx, auth.Claims{}, req1, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tCreate failed.", tests.Failed) @@ -476,7 +489,7 @@ func TestUpdateValidationEmailUnique(t *testing.T) { Password: "W0rkL1fe#", PasswordConfirm: "W0rkL1fe#", } - user2, err := Create(ctx, auth.Claims{}, test.MasterDB, req2, now) + user2, err := repo.Create(ctx, auth.Claims{}, req2, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tCreate failed.", tests.Failed) @@ -488,7 +501,7 @@ func TestUpdateValidationEmailUnique(t *testing.T) { Email: &user1.Email, } expectedErr := errors.New("Key: 'UserUpdateRequest.email' Error:Field validation for 'email' failed on the 'unique' tag") - err = Update(ctx, auth.Claims{}, test.MasterDB, updateReq, now) + err = repo.Update(ctx, auth.Claims{}, updateReq, now) if err == nil { t.Logf("\t\tWant: %+v", expectedErr) t.Fatalf("\t%s\tUpdate failed.", tests.Failed) @@ -518,7 +531,7 @@ func TestUpdatePassword(t *testing.T) { // Create a new user for testing. initPass := uuid.NewRandom().String() - user, err := Create(ctx, auth.Claims{}, test.MasterDB, UserCreateRequest{ + user, err := repo.Create(ctx, auth.Claims{}, UserCreateRequest{ FirstName: "Lee", LastName: "Brown", Email: uuid.NewRandom().String() + "@geeksinthewoods.com", @@ -549,7 +562,7 @@ func TestUpdatePassword(t *testing.T) { expectedErr := errors.New("Key: 'UserUpdatePasswordRequest.id' Error:Field validation for 'id' failed on the 'required' tag\n" + "Key: 'UserUpdatePasswordRequest.password' Error:Field validation for 'password' failed on the 'required' tag\n" + "Key: 'UserUpdatePasswordRequest.password_confirm' Error:Field validation for 'password_confirm' failed on the 'required' tag") - err = UpdatePassword(ctx, auth.Claims{}, test.MasterDB, UserUpdatePasswordRequest{}, now) + err = repo.UpdatePassword(ctx, auth.Claims{}, UserUpdatePasswordRequest{}, now) if err == nil { t.Logf("\t\tWant: %+v", expectedErr) t.Fatalf("\t%s\tUpdate failed.", tests.Failed) @@ -567,7 +580,7 @@ func TestUpdatePassword(t *testing.T) { // Update the users password. newPass := uuid.NewRandom().String() - err = UpdatePassword(ctx, auth.Claims{}, test.MasterDB, UserUpdatePasswordRequest{ + err = repo.UpdatePassword(ctx, auth.Claims{}, UserUpdatePasswordRequest{ ID: user.ID, Password: newPass, PasswordConfirm: newPass, @@ -800,7 +813,7 @@ func TestCrud(t *testing.T) { // Always create the new user with empty claims, testing claims for create user // will be handled separately. - user, err := Create(tests.Context(), auth.Claims{}, test.MasterDB, tt.create, now) + user, err := repo.Create(tests.Context(), auth.Claims{}, tt.create, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tCreate user failed.", tests.Failed) @@ -823,7 +836,7 @@ func TestCrud(t *testing.T) { // Update the user. updateReq := tt.update(user) - err = Update(ctx, tt.claims(user, accountId), test.MasterDB, updateReq, now) + err = repo.Update(ctx, tt.claims(user, accountId), updateReq, now) if err != nil && errors.Cause(err) != tt.updateErr { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.updateErr) @@ -832,7 +845,7 @@ func TestCrud(t *testing.T) { t.Logf("\t%s\tUpdate ok.", tests.Success) // Find the user and make sure the updates where made. - findRes, err := ReadByID(ctx, tt.claims(user, accountId), test.MasterDB, user.ID) + findRes, err := repo.ReadByID(ctx, tt.claims(user, accountId), user.ID) if err != nil && errors.Cause(err) != tt.findErr { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.findErr) @@ -846,14 +859,14 @@ func TestCrud(t *testing.T) { } // Archive (soft-delete) the user. - err = Archive(ctx, tt.claims(user, accountId), test.MasterDB, UserArchiveRequest{ID: user.ID, force: true}, now) + err = repo.Archive(ctx, tt.claims(user, accountId), UserArchiveRequest{ID: user.ID, force: true}, now) if err != nil && errors.Cause(err) != tt.updateErr { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.updateErr) t.Fatalf("\t%s\tArchive failed.", tests.Failed) } else if tt.updateErr == nil { // Trying to find the archived user with the includeArchived false should result in not found. - _, err = ReadByID(ctx, tt.claims(user, accountId), test.MasterDB, user.ID) + _, err = repo.ReadByID(ctx, tt.claims(user, accountId), user.ID) if err != nil && errors.Cause(err) != ErrNotFound { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", ErrNotFound) @@ -861,7 +874,7 @@ func TestCrud(t *testing.T) { } // Trying to find the archived user with the includeArchived true should result no error. - _, err = Read(ctx, tt.claims(user, accountId), test.MasterDB, + _, err = repo.Read(ctx, tt.claims(user, accountId), UserReadRequest{ID: user.ID, IncludeArchived: true}) if err != nil { t.Log("\t\tGot :", err) @@ -871,14 +884,14 @@ func TestCrud(t *testing.T) { t.Logf("\t%s\tArchive ok.", tests.Success) // Restore (un-delete) the user. - err = Restore(ctx, tt.claims(user, accountId), test.MasterDB, UserRestoreRequest{ID: user.ID}, now) + err = repo.Restore(ctx, tt.claims(user, accountId), UserRestoreRequest{ID: user.ID}, now) if err != nil && errors.Cause(err) != tt.updateErr { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.updateErr) t.Fatalf("\t%s\tUnarchive failed.", tests.Failed) } else if tt.updateErr == nil { // Trying to find the archived user with the includeArchived false should result no error. - _, err = ReadByID(ctx, tt.claims(user, accountId), test.MasterDB, user.ID) + _, err = repo.ReadByID(ctx, tt.claims(user, accountId), user.ID) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tUnarchive Read failed.", tests.Failed) @@ -887,14 +900,14 @@ func TestCrud(t *testing.T) { t.Logf("\t%s\tUnarchive ok.", tests.Success) // Delete (hard-delete) the user. - err = Delete(ctx, tt.claims(user, accountId), test.MasterDB, UserDeleteRequest{ID: user.ID, force: true}) + err = repo.Delete(ctx, tt.claims(user, accountId), UserDeleteRequest{ID: user.ID, force: true}) if err != nil && errors.Cause(err) != tt.updateErr { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.updateErr) t.Fatalf("\t%s\tUpdate failed.", tests.Failed) } else if tt.updateErr == nil { // Trying to find the deleted user with the includeArchived true should result in not found. - _, err = ReadByID(ctx, tt.claims(user, accountId), test.MasterDB, user.ID) + _, err = repo.ReadByID(ctx, tt.claims(user, accountId), user.ID) if errors.Cause(err) != ErrNotFound { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", ErrNotFound) @@ -917,7 +930,7 @@ func TestFind(t *testing.T) { var users []*User for i := 0; i <= 4; i++ { - user, err := Create(tests.Context(), auth.Claims{}, test.MasterDB, UserCreateRequest{ + user, err := repo.Create(tests.Context(), auth.Claims{}, UserCreateRequest{ FirstName: "Lee", LastName: "Brown", Email: uuid.NewRandom().String() + "@geeksinthewoods.com", @@ -1029,7 +1042,7 @@ func TestFind(t *testing.T) { { ctx := tests.Context() - res, err := Find(ctx, auth.Claims{}, test.MasterDB, tt.req) + res, err := repo.Find(ctx, auth.Claims{}, tt.req) if errors.Cause(err) != tt.error { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.error) @@ -1064,7 +1077,7 @@ func TestResetPassword(t *testing.T) { // Create a new user for testing. initPass := uuid.NewRandom().String() - user, err := Create(ctx, auth.Claims{}, test.MasterDB, UserCreateRequest{ + user, err := repo.Create(ctx, auth.Claims{}, UserCreateRequest{ FirstName: "Lee", LastName: "Brown", Email: uuid.NewRandom().String() + "@geeksinthewoods.com", @@ -1091,18 +1104,10 @@ func TestResetPassword(t *testing.T) { t.Fatalf("\t%s\tCreate user account failed.", tests.Failed) } - // Mock the methods needed to make a password reset. - resetUrl := func(string) string { - return "" - } - notify := ¬ify.MockEmail{} - - secretKey := "6368616e676520746869732070617373" - // Ensure validation is working by trying ResetPassword with an empty request. { expectedErr := errors.New("Key: 'UserResetPasswordRequest.email' Error:Field validation for 'email' failed on the 'required' tag") - _, err = ResetPassword(ctx, test.MasterDB, resetUrl, notify, UserResetPasswordRequest{}, secretKey, now) + _, err = repo.ResetPassword(ctx, UserResetPasswordRequest{}, now) if err == nil { t.Logf("\t\tWant: %+v", expectedErr) t.Fatalf("\t%s\tResetPassword failed.", tests.Failed) @@ -1122,10 +1127,10 @@ func TestResetPassword(t *testing.T) { ttl := time.Hour // Make the reset password request. - resetHash, err := ResetPassword(ctx, test.MasterDB, resetUrl, notify, UserResetPasswordRequest{ + resetHash, err := repo.ResetPassword(ctx, UserResetPasswordRequest{ Email: user.Email, TTL: ttl, - }, secretKey, now) + }, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tResetPassword failed.", tests.Failed) @@ -1133,7 +1138,7 @@ func TestResetPassword(t *testing.T) { t.Logf("\t%s\tResetPassword ok.", tests.Success) // Read the user to ensure the password_reset field was set. - user, err = ReadByID(ctx, auth.Claims{}, test.MasterDB, user.ID) + user, err = repo.ReadByID(ctx, auth.Claims{}, user.ID) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tRead failed.", tests.Failed) @@ -1146,7 +1151,7 @@ func TestResetPassword(t *testing.T) { expectedErr := errors.New("Key: 'UserResetConfirmRequest.reset_hash' Error:Field validation for 'reset_hash' failed on the 'required' tag\n" + "Key: 'UserResetConfirmRequest.password' Error:Field validation for 'password' failed on the 'required' tag\n" + "Key: 'UserResetConfirmRequest.password_confirm' Error:Field validation for 'password_confirm' failed on the 'required' tag") - _, err = ResetConfirm(ctx, test.MasterDB, UserResetConfirmRequest{}, secretKey, now) + _, err = repo.ResetConfirm(ctx, UserResetConfirmRequest{}, now) if err == nil { t.Logf("\t\tWant: %+v", expectedErr) t.Fatalf("\t%s\tResetConfirm failed.", tests.Failed) @@ -1166,11 +1171,11 @@ func TestResetPassword(t *testing.T) { // Ensure the TTL is enforced. { newPass := uuid.NewRandom().String() - _, err = ResetConfirm(ctx, test.MasterDB, UserResetConfirmRequest{ + _, err = repo.ResetConfirm(ctx, UserResetConfirmRequest{ ResetHash: resetHash, Password: newPass, PasswordConfirm: newPass, - }, secretKey, now.UTC().Add(ttl*2)) + }, now.UTC().Add(ttl*2)) if errors.Cause(err) != ErrResetExpired { t.Logf("\t\tGot : %+v", errors.Cause(err)) t.Logf("\t\tWant: %+v", ErrResetExpired) @@ -1181,11 +1186,11 @@ func TestResetPassword(t *testing.T) { // Assuming we have received the email and clicked the link, we now can ensure confirm works. newPass := uuid.NewRandom().String() - reset, err := ResetConfirm(ctx, test.MasterDB, UserResetConfirmRequest{ + reset, err := repo.ResetConfirm(ctx, UserResetConfirmRequest{ ResetHash: resetHash, Password: newPass, PasswordConfirm: newPass, - }, secretKey, now) + }, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tResetConfirm failed.", tests.Failed) @@ -1199,11 +1204,11 @@ func TestResetPassword(t *testing.T) { // Ensure the reset hash does not work after its used. { newPass := uuid.NewRandom().String() - _, err = ResetConfirm(ctx, test.MasterDB, UserResetConfirmRequest{ + _, err = repo.ResetConfirm(ctx, UserResetConfirmRequest{ ResetHash: resetHash, Password: newPass, PasswordConfirm: newPass, - }, secretKey, now) + }, now) if errors.Cause(err) != ErrNotFound { t.Logf("\t\tGot : %+v", errors.Cause(err)) t.Logf("\t\tWant: %+v", ErrNotFound) diff --git a/internal/user_account/models.go b/internal/user_account/models.go index 2b63ef0..e2131a4 100644 --- a/internal/user_account/models.go +++ b/internal/user_account/models.go @@ -3,6 +3,7 @@ package user_account import ( "context" "database/sql/driver" + "github.com/jmoiron/sqlx" "strings" "time" @@ -13,6 +14,18 @@ import ( "gopkg.in/go-playground/validator.v9" ) +// Repository defines the required dependencies for UserAccount. +type Repository struct { + DbConn *sqlx.DB +} + +// NewRepository creates a new Repository that defines dependencies for UserAccount. +func NewRepository(db *sqlx.DB) *Repository { + return &Repository{ + DbConn: db, + } +} + // UserAccount defines the one to many relationship of an user to an account. This // will enable a single user access to multiple accounts without having duplicate // users. Each association of a user to an account has a set of roles and a status diff --git a/internal/user_account/user.go b/internal/user_account/user.go index 9c63821..152bff2 100644 --- a/internal/user_account/user.go +++ b/internal/user_account/user.go @@ -6,13 +6,12 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "github.com/huandu/go-sqlbuilder" - "github.com/jmoiron/sqlx" "github.com/pkg/errors" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" ) // UserFindByAccount lists all the users for a given account ID. -func UserFindByAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindByAccountRequest) (Users, error) { +func (repo *Repository) UserFindByAccount(ctx context.Context, claims auth.Claims, req UserFindByAccountRequest) (Users, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.UserFindByAccount") defer span.Finish() @@ -113,12 +112,12 @@ func UserFindByAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, } queryStr, moreQueryArgs := query.Build() - queryStr = dbConn.Rebind(queryStr) + queryStr = repo.DbConn.Rebind(queryStr) queryArgs = append(queryArgs, moreQueryArgs...) // fetch all places from the db - rows, err := dbConn.QueryContext(ctx, queryStr, queryArgs...) + rows, err := repo.DbConn.QueryContext(ctx, queryStr, queryArgs...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessage(err, "find users failed") diff --git a/internal/user_account/user_account.go b/internal/user_account/user_account.go index 98de009..4fde70f 100644 --- a/internal/user_account/user_account.go +++ b/internal/user_account/user_account.go @@ -49,14 +49,14 @@ func mapRowsToUserAccount(rows *sql.Rows) (*UserAccount, error) { } // CanReadAccount determines if claims has the authority to access the specified user account by user ID. -func CanReadAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, accountID string) error { - err := account.CanReadAccount(ctx, claims, dbConn, accountID) +func (repo *Repository) CanReadAccount(ctx context.Context, claims auth.Claims, accountID string) error { + err := account.CanReadAccount(ctx, claims, accountID) return mapAccountError(err) } // CanModifyAccount determines if claims has the authority to modify the specified user ID. -func CanModifyAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, accountID string) error { - err := account.CanModifyAccount(ctx, claims, dbConn, accountID) +func (repo *Repository) CanModifyAccount(ctx context.Context, claims auth.Claims, accountID string) error { + err := account.CanModifyAccount(ctx, claims, accountID) return mapAccountError(err) } diff --git a/internal/user_account/user_account_test.go b/internal/user_account/user_account_test.go index 7ff38c7..2728273 100644 --- a/internal/user_account/user_account_test.go +++ b/internal/user_account/user_account_test.go @@ -17,7 +17,10 @@ import ( "github.com/pkg/errors" ) -var test *tests.Test +var ( + test *tests.Test + repo *Repository +) // TestMain is the entry point for testing. func TestMain(m *testing.M) { @@ -27,6 +30,9 @@ func TestMain(m *testing.M) { func testMain(m *testing.M) int { test = tests.New() defer test.TearDown() + + repo = NewRepository(test.MasterDB) + return m.Run() }