From 4be04544216500f67eb3108f8ef4541ab88a1b07 Mon Sep 17 00:00:00 2001 From: Lee Brown Date: Tue, 13 Aug 2019 22:26:25 -0800 Subject: [PATCH 01/13] update gitignore for .devops --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 6566d85..d1a5f45 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ aws.* local.env .DS_Store tmp +.devops.json From 3bc814a01ef6d85799224b7ecdc155972c788550 Mon Sep 17 00:00:00 2001 From: Lee Brown Date: Tue, 13 Aug 2019 23:41:06 -0800 Subject: [PATCH 02/13] WIP: not sure how to solve user_account calling account.CanModifyAccount --- cmd/web-api/handlers/routes.go | 44 +++++-- internal/user/models.go | 20 +++ internal/user/user.go | 137 +++++++++++---------- internal/user/user_test.go | 93 +++++++------- internal/user_account/models.go | 13 ++ internal/user_account/user.go | 7 +- internal/user_account/user_account.go | 8 +- internal/user_account/user_account_test.go | 8 +- 8 files changed, 198 insertions(+), 132 deletions(-) diff --git a/cmd/web-api/handlers/routes.go b/cmd/web-api/handlers/routes.go index c056074..b68460f 100644 --- a/cmd/web-api/handlers/routes.go +++ b/cmd/web-api/handlers/routes.go @@ -2,6 +2,7 @@ package handlers import ( "context" + "geeks-accelerator/oss/saas-starter-kit/internal/user" "log" "net/http" "os" @@ -20,33 +21,52 @@ import ( "gopkg.in/DataDog/dd-trace-go.v1/contrib/go-redis/redis" ) + +type AppContext struct { + Log *log.Logger + Env webcontext.Env + Repo *user.Repository + MasterDB *sqlx.DB + Redis *redis.Client + Authenticator *auth.Authenticator + PreAppMiddleware []web.Middleware + PostAppMiddleware []web.Middleware +} + + // API returns a handler for a set of routes. -func API(shutdown chan os.Signal, log *log.Logger, env webcontext.Env, masterDB *sqlx.DB, redis *redis.Client, authenticator *auth.Authenticator, globalMids ...web.Middleware) http.Handler { +func API(shutdown chan os.Signal, appContext *AppContext ) http.Handler { - // Define base middlewares applied to all requests. - middlewares := []web.Middleware{ - mid.Trace(), mid.Logger(log), mid.Errors(log, nil), mid.Metrics(), mid.Panics(), - } + // Include the pre middlewares first. + middlewares := appContext.PreAppMiddleware - // Append any global middlewares if they were included. - if len(globalMids) > 0 { - middlewares = append(middlewares, globalMids...) + // Define app middlewares applied to all requests. + middlewares = append(middlewares, + mid.Trace(), + mid.Logger(appContext.Log), + mid.Errors(appContext.Log, nil), + mid.Metrics(), + mid.Panics()) + + // Append any global middlewares that should be included after the app middlewares. + if len(appContext.PostAppMiddleware) > 0 { + middlewares = append(middlewares, appContext.PostAppMiddleware...) } // Construct the web.App which holds all routes as well as common Middleware. - app := web.NewApp(shutdown, log, env, middlewares...) + app := web.NewApp(shutdown, appContext.Log, appContext.Env, middlewares...) // Register health check endpoint. This route is not authenticated. check := Check{ - MasterDB: masterDB, - Redis: redis, + MasterDB: appContext.MasterDB, + Redis: appContext.Redis, } app.Handle("GET", "/v1/health", check.Health) app.Handle("GET", "/ping", check.Ping) // Register user management and authentication endpoints. u := User{ - MasterDB: masterDB, + MasterDB: appContext.MasterDB, TokenGenerator: authenticator, } app.Handle("GET", "/v1/users", u.Find, mid.AuthenticateHeader(authenticator)) diff --git a/internal/user/models.go b/internal/user/models.go index 2aca963..b4c65c5 100644 --- a/internal/user/models.go +++ b/internal/user/models.go @@ -4,8 +4,10 @@ import ( "context" "database/sql" "encoding/json" + "geeks-accelerator/oss/saas-starter-kit/internal/platform/notify" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" + "github.com/jmoiron/sqlx" "github.com/pkg/errors" "github.com/sudo-suhas/symcrypto" "strconv" @@ -15,6 +17,24 @@ import ( "github.com/lib/pq" ) +// Repository defines the required dependencies for User. +type Repository struct { + DbConn *sqlx.DB + ResetUrl func(string) string + Notify notify.Email + SecretKey string +} + +// NewRepository creates a new Repository that defines dependencies for User. +func NewRepository(db *sqlx.DB, resetUrl func(string) string, notify notify.Email, secretKey string) *Repository { + return &Repository{ + DbConn: db, + ResetUrl: resetUrl, + Notify: notify, + SecretKey: secretKey, + } +} + // User represents someone with access to our system. type User struct { ID string `json:"id" validate:"required,uuid" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"` diff --git a/internal/user/user.go b/internal/user/user.go index 6354c19..81d036e 100644 --- a/internal/user/user.go +++ b/internal/user/user.go @@ -6,7 +6,6 @@ import ( "time" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" - "geeks-accelerator/oss/saas-starter-kit/internal/platform/notify" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "github.com/huandu/go-sqlbuilder" "github.com/jmoiron/sqlx" @@ -55,7 +54,7 @@ func mapRowsToUser(rows *sql.Rows) (*User, error) { } // CanReadUser determines if claims has the authority to access the specified user ID. -func CanReadUser(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string) error { +func (repo *Repository) CanReadUser(ctx context.Context, claims auth.Claims, userID string) error { // If the request has claims from a specific user, ensure that the user // has the correct access to the user. if claims.Subject != "" && claims.Subject != userID { @@ -68,10 +67,10 @@ func CanReadUser(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userI query.Equal("user_id", userID), )) queryStr, args := query.Build() - queryStr = dbConn.Rebind(queryStr) + queryStr = repo.DbConn.Rebind(queryStr) var userAccountId string - err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&userAccountId) + err := repo.DbConn.QueryRowContext(ctx, queryStr, args...).Scan(&userAccountId) if err != nil && err != sql.ErrNoRows { err = errors.Wrapf(err, "query - %s", query.String()) return err @@ -88,7 +87,7 @@ func CanReadUser(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userI } // CanModifyUser determines if claims has the authority to modify the specified user ID. -func CanModifyUser(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string) error { +func (repo *Repository) CanModifyUser(ctx context.Context, claims auth.Claims, userID string) error { // If the request has claims from a specific user, ensure that the user // has the correct role for creating a new user. if claims.Subject != "" && claims.Subject != userID { @@ -99,7 +98,7 @@ func CanModifyUser(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, use } } - if err := CanReadUser(ctx, claims, dbConn, userID); err != nil { + if err := repo.CanReadUser(ctx, claims, userID); err != nil { return err } @@ -118,10 +117,10 @@ func CanModifyUser(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, use "'"+auth.RoleAdmin+"' = ANY (roles)", )) queryStr, args := query.Build() - queryStr = dbConn.Rebind(queryStr) + queryStr = repo.DbConn.Rebind(queryStr) var userAccountId string - err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&userAccountId) + err := repo.DbConn.QueryRowContext(ctx, queryStr, args...).Scan(&userAccountId) if err != nil && err != sql.ErrNoRows { err = errors.Wrapf(err, "query - %s", query.String()) return err @@ -199,13 +198,13 @@ func findRequestQuery(req UserFindRequest) (*sqlbuilder.SelectBuilder, []interfa } // Find gets all the users from the database based on the request params. -func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindRequest) (Users, error) { +func (repo *Repository) Find(ctx context.Context, claims auth.Claims, req UserFindRequest) (Users, error) { query, args := findRequestQuery(req) - return find(ctx, claims, dbConn, query, args, req.IncludeArchived) + return repo.find(ctx, claims, query, args, req.IncludeArchived) } // find internal method for getting all the users from the database using a select query. -func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (Users, error) { +func (repo *Repository) find(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (Users, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Find") defer span.Finish() @@ -222,11 +221,11 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu return nil, err } queryStr, queryArgs := query.Build() - queryStr = dbConn.Rebind(queryStr) + queryStr = repo.DbConn.Rebind(queryStr) args = append(args, queryArgs...) // fetch all places from the db - rows, err := dbConn.QueryContext(ctx, queryStr, args...) + rows, err := repo.DbConn.QueryContext(ctx, queryStr, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessage(err, "find users failed") @@ -248,17 +247,17 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu } // Validation an email address is unique excluding the current user ID. -func UniqueEmail(ctx context.Context, dbConn *sqlx.DB, email, userId string) (bool, error) { +func (repo *Repository) UniqueEmail(ctx context.Context, email, userId string) (bool, error) { query := sqlbuilder.NewSelectBuilder().Select("id").From(userTableName) query.Where(query.And( query.Equal("email", email), query.NotEqual("id", userId), )) queryStr, args := query.Build() - queryStr = dbConn.Rebind(queryStr) + queryStr = repo.DbConn.Rebind(queryStr) var existingId string - err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&existingId) + err := repo.DbConn.QueryRowContext(ctx, queryStr, args...).Scan(&existingId) if err != nil && err != sql.ErrNoRows { err = errors.Wrapf(err, "query - %s", query.String()) return false, err @@ -273,7 +272,7 @@ func UniqueEmail(ctx context.Context, dbConn *sqlx.DB, email, userId string) (bo } // Create inserts a new user into the database. -func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserCreateRequest, now time.Time) (*User, error) { +func (repo *Repository) Create(ctx context.Context, claims auth.Claims, req UserCreateRequest, now time.Time) (*User, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Create") defer span.Finish() @@ -284,7 +283,7 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserCr v := webcontext.Validator() // Validation email address is unique in the database. - uniq, err := UniqueEmail(ctx, dbConn, req.Email, "") + uniq, err := repo.UniqueEmail(ctx, req.Email, "") if err != nil { return nil, err } @@ -346,8 +345,8 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserCr // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessage(err, "create user failed") @@ -358,14 +357,14 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserCr } // Create invite inserts a new user into the database. -func CreateInvite(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserCreateInviteRequest, now time.Time) (*User, error) { +func (repo *Repository) CreateInvite(ctx context.Context, claims auth.Claims, req UserCreateInviteRequest, now time.Time) (*User, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.CreateInvite") defer span.Finish() v := webcontext.Validator() // Validation email address is unique in the database. - uniq, err := UniqueEmail(ctx, dbConn, req.Email, "") + uniq, err := repo.UniqueEmail(ctx, req.Email, "") if err != nil { return nil, err } @@ -414,8 +413,8 @@ func CreateInvite(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessage(err, "create user failed") @@ -426,15 +425,15 @@ func CreateInvite(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req } // ReadByID gets the specified user by ID from the database. -func ReadByID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string) (*User, error) { - return Read(ctx, claims, dbConn, UserReadRequest{ +func (repo *Repository) ReadByID(ctx context.Context, claims auth.Claims, id string) (*User, error) { + return repo.Read(ctx, claims, UserReadRequest{ ID: id, IncludeArchived: false, }) } // Read gets the specified user from the database. -func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserReadRequest) (*User, error) { +func (repo *Repository) Read(ctx context.Context, claims auth.Claims, req UserReadRequest) (*User, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Read") defer span.Finish() @@ -449,7 +448,7 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserRead query := selectQuery() query.Where(query.Equal("id", req.ID)) - res, err := find(ctx, claims, dbConn, query, []interface{}{}, req.IncludeArchived) + res, err := repo.find(ctx, claims, query, []interface{}{}, req.IncludeArchived) if err != nil { return nil, err } else if res == nil || len(res) == 0 { @@ -462,7 +461,7 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserRead } // ReadByEmail gets the specified user from the database. -func ReadByEmail(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, email string, includedArchived bool) (*User, error) { +func (repo *Repository) ReadByEmail(ctx context.Context, claims auth.Claims, email string, includedArchived bool) (*User, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.ReadByEmail") defer span.Finish() @@ -470,7 +469,7 @@ func ReadByEmail(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, email query := selectQuery() query.Where(query.Equal("email", email)) - res, err := find(ctx, claims, dbConn, query, []interface{}{}, includedArchived) + res, err := repo.find(ctx, claims, query, []interface{}{}, includedArchived) if err != nil { return nil, err } else if res == nil || len(res) == 0 { @@ -483,14 +482,14 @@ func ReadByEmail(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, email } // Update replaces a user in the database. -func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserUpdateRequest, now time.Time) error { +func (repo *Repository) Update(ctx context.Context, claims auth.Claims, req UserUpdateRequest, now time.Time) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Update") defer span.Finish() // Validation email address is unique in the database. if req.Email != nil { // Validation email address is unique in the database. - uniq, err := UniqueEmail(ctx, dbConn, *req.Email, req.ID) + uniq, err := repo.UniqueEmail(ctx, *req.Email, req.ID) if err != nil { return err } @@ -507,7 +506,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserUp } // Ensure the claims can modify the user specified in the request. - err = CanModifyUser(ctx, claims, dbConn, req.ID) + err = repo.CanModifyUser(ctx, claims, req.ID) if err != nil { return err } @@ -555,8 +554,8 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserUp // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "update user %s failed", req.ID) @@ -567,7 +566,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserUp } // Update changes the password for a user in the database. -func UpdatePassword(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserUpdatePasswordRequest, now time.Time) error { +func (repo *Repository) UpdatePassword(ctx context.Context, claims auth.Claims, req UserUpdatePasswordRequest, now time.Time) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.UpdatePassword") defer span.Finish() @@ -579,7 +578,7 @@ func UpdatePassword(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, re } // Ensure the claims can modify the user specified in the request. - err = CanModifyUser(ctx, claims, dbConn, req.ID) + err = repo.CanModifyUser(ctx, claims, req.ID) if err != nil { return err } @@ -616,8 +615,8 @@ func UpdatePassword(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, re // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "update password for user %s failed", req.ID) @@ -628,7 +627,7 @@ func UpdatePassword(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, re } // Archive soft deleted the user from the database. -func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserArchiveRequest, now time.Time) error { +func (repo *Repository) Archive(ctx context.Context, claims auth.Claims, req UserArchiveRequest, now time.Time) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Archive") defer span.Finish() @@ -640,7 +639,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserA } // Ensure the claims can modify the user specified in the request. - err = CanModifyUser(ctx, claims, dbConn, req.ID) + err = repo.CanModifyUser(ctx, claims, req.ID) if err != nil { return err } else if claims.Subject != "" && claims.Subject == req.ID && !req.force { @@ -669,8 +668,8 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserA // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "archive user %s failed", req.ID) @@ -689,8 +688,8 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserA // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "archive accounts for user %s failed", req.ID) @@ -702,7 +701,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserA } // Restore undeletes the user from the database. -func Restore(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserRestoreRequest, now time.Time) error { +func (repo *Repository) Restore(ctx context.Context, claims auth.Claims, req UserRestoreRequest, now time.Time) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Restore") defer span.Finish() @@ -714,7 +713,7 @@ func Restore(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserR } // Ensure the claims can modify the user specified in the request. - err = CanModifyUser(ctx, claims, dbConn, req.ID) + err = repo.CanModifyUser(ctx, claims, req.ID) if err != nil { return err } @@ -741,8 +740,8 @@ func Restore(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserR // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "unarchive user %s failed", req.ID) @@ -753,7 +752,7 @@ func Restore(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserR } // Delete removes a user from the database. -func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserDeleteRequest) error { +func (repo *Repository) Delete(ctx context.Context, claims auth.Claims, req UserDeleteRequest) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Delete") defer span.Finish() @@ -765,7 +764,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserDe } // Ensure the claims can modify the user specified in the request. - err = CanModifyUser(ctx, claims, dbConn, req.ID) + err = repo.CanModifyUser(ctx, claims, req.ID) if err != nil { return err } else if claims.Subject != "" && claims.Subject == req.ID && !req.force { @@ -773,7 +772,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserDe } // Start a new transaction to handle rollbacks on error. - tx, err := dbConn.Begin() + tx, err := repo.DbConn.Begin() if err != nil { return errors.WithStack(err) } @@ -790,7 +789,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserDe // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) + sql = repo.DbConn.Rebind(sql) _, err = tx.ExecContext(ctx, sql, args...) if err != nil { tx.Rollback() @@ -808,7 +807,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserDe // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) + sql = repo.DbConn.Rebind(sql) _, err = tx.ExecContext(ctx, sql, args...) if err != nil { tx.Rollback() @@ -827,7 +826,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserDe } // ResetPassword sends en email to the user to allow them to reset their password. -func ResetPassword(ctx context.Context, dbConn *sqlx.DB, resetUrl func(string) string, notify notify.Email, req UserResetPasswordRequest, secretKey string, now time.Time) (string, error) { +func (repo *Repository) ResetPassword(ctx context.Context, req UserResetPasswordRequest, now time.Time) (string, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.ResetPassword") defer span.Finish() @@ -845,7 +844,7 @@ func ResetPassword(ctx context.Context, dbConn *sqlx.DB, resetUrl func(string) s query := selectQuery() query.Where(query.Equal("email", req.Email)) - res, err := find(ctx, auth.Claims{}, dbConn, query, []interface{}{}, false) + res, err := repo.find(ctx, auth.Claims{}, query, []interface{}{}, false) if err != nil { return "", err } else if res == nil || len(res) == 0 { @@ -876,8 +875,8 @@ func ResetPassword(ctx context.Context, dbConn *sqlx.DB, resetUrl func(string) s // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "Update user %s failed.", u.ID) @@ -895,18 +894,18 @@ func ResetPassword(ctx context.Context, dbConn *sqlx.DB, resetUrl func(string) s requestIp = vals.RequestIP } - encrypted, err := NewResetHash(ctx, secretKey, resetId, requestIp, req.TTL, now) + encrypted, err := NewResetHash(ctx, repo.SecretKey, resetId, requestIp, req.TTL, now) if err != nil { return "", err } data := map[string]interface{}{ "Name": u.FirstName, - "Url": resetUrl(encrypted), + "Url": repo.ResetUrl(encrypted), "Minutes": req.TTL.Minutes(), } - err = notify.Send(ctx, u.Email, "Reset your Password", "user_reset_password", data) + err = repo.Notify.Send(ctx, u.Email, "Reset your Password", "user_reset_password", data) if err != nil { err = errors.WithMessagef(err, "Send password reset email to %s failed.", u.Email) return "", err @@ -916,7 +915,7 @@ func ResetPassword(ctx context.Context, dbConn *sqlx.DB, resetUrl func(string) s } // ResetConfirm updates the password for a user using the provided reset password ID. -func ResetConfirm(ctx context.Context, dbConn *sqlx.DB, req UserResetConfirmRequest, secretKey string, now time.Time) (*User, error) { +func (repo *Repository) ResetConfirm(ctx context.Context, req UserResetConfirmRequest, now time.Time) (*User, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.ResetConfirm") defer span.Finish() @@ -928,7 +927,7 @@ func ResetConfirm(ctx context.Context, dbConn *sqlx.DB, req UserResetConfirmRequ return nil, err } - hash, err := ParseResetHash(ctx, secretKey, req.ResetHash, now) + hash, err := ParseResetHash(ctx, repo.SecretKey, req.ResetHash, now) if err != nil { return nil, err } @@ -939,7 +938,7 @@ func ResetConfirm(ctx context.Context, dbConn *sqlx.DB, req UserResetConfirmRequ query := selectQuery() query.Where(query.Equal("password_reset", hash.ResetID)) - res, err := find(ctx, auth.Claims{}, dbConn, query, []interface{}{}, false) + res, err := repo.find(ctx, auth.Claims{}, query, []interface{}{}, false) if err != nil { return nil, err } else if res == nil || len(res) == 0 { @@ -979,8 +978,8 @@ func ResetConfirm(ctx context.Context, dbConn *sqlx.DB, req UserResetConfirmRequ // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "update password for user %s failed", u.ID) @@ -1000,6 +999,10 @@ type MockUserResponse struct { func MockUser(ctx context.Context, dbConn *sqlx.DB, now time.Time) (*MockUserResponse, error) { pass := uuid.NewRandom().String() + repo := &Repository{ + DbConn: dbConn, + } + req := UserCreateRequest{ FirstName: "Lee", LastName: "Brown", @@ -1007,7 +1010,7 @@ func MockUser(ctx context.Context, dbConn *sqlx.DB, now time.Time) (*MockUserRes Password: pass, PasswordConfirm: pass, } - u, err := Create(ctx, auth.Claims{}, dbConn, req, now) + u, err := repo.Create(ctx, auth.Claims{}, req, now) if err != nil { return nil, err } diff --git a/internal/user/user_test.go b/internal/user/user_test.go index 2dc15fd..edfaf49 100644 --- a/internal/user/user_test.go +++ b/internal/user/user_test.go @@ -18,7 +18,10 @@ import ( "github.com/pkg/errors" ) -var test *tests.Test +var ( + test *tests.Test + repo *Repository +) // TestMain is the entry point for testing. func TestMain(m *testing.M) { @@ -28,6 +31,16 @@ func TestMain(m *testing.M) { func testMain(m *testing.M) int { test = tests.New() defer test.TearDown() + + // Mock the methods needed to make a password reset. + resetUrl := func(string) string { + return "" + } + notify := ¬ify.MockEmail{} + secretKey := "6368616e676520746869732070617373" + + repo = NewRepository(test.MasterDB, resetUrl, notify, secretKey) + return m.Run() } @@ -219,7 +232,7 @@ func TestCreateValidation(t *testing.T) { { ctx := tests.Context() - res, err := Create(ctx, auth.Claims{}, test.MasterDB, tt.req, now) + res, err := repo.Create(ctx, auth.Claims{}, tt.req, now) if err != tt.error { // TODO: need a better way to handle validation errors as they are // of type interface validator.ValidationErrorsTranslations @@ -272,7 +285,7 @@ func TestCreateValidationEmailUnique(t *testing.T) { Password: "akTechFr0n!ier", PasswordConfirm: "akTechFr0n!ier", } - user1, err := Create(ctx, auth.Claims{}, test.MasterDB, req1, now) + user1, err := repo.Create(ctx, auth.Claims{}, req1, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tCreate failed.", tests.Failed) @@ -286,7 +299,7 @@ func TestCreateValidationEmailUnique(t *testing.T) { PasswordConfirm: "W0rkL1fe#", } expectedErr := errors.New("Key: 'UserCreateRequest.email' Error:Field validation for 'email' failed on the 'unique' tag") - _, err = Create(ctx, auth.Claims{}, test.MasterDB, req2, now) + _, err = repo.Create(ctx, auth.Claims{}, req2, now) if err == nil { t.Logf("\t\tWant: %+v", expectedErr) t.Fatalf("\t%s\tCreate failed.", tests.Failed) @@ -374,7 +387,7 @@ func TestCreateClaims(t *testing.T) { { ctx := tests.Context() - _, err := Create(ctx, tt.claims, test.MasterDB, tt.req, now) + _, err := repo.Create(ctx, tt.claims, tt.req, now) if errors.Cause(err) != tt.error { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.error) @@ -421,7 +434,7 @@ func TestUpdateValidation(t *testing.T) { { ctx := tests.Context() - err := Update(ctx, auth.Claims{}, test.MasterDB, tt.req, now) + err := repo.Update(ctx, auth.Claims{}, tt.req, now) if err != tt.error { // TODO: need a better way to handle validation errors as they are // of type interface validator.ValidationErrorsTranslations @@ -463,7 +476,7 @@ func TestUpdateValidationEmailUnique(t *testing.T) { Password: "akTechFr0n!ier", PasswordConfirm: "akTechFr0n!ier", } - user1, err := Create(ctx, auth.Claims{}, test.MasterDB, req1, now) + user1, err := repo.Create(ctx, auth.Claims{}, req1, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tCreate failed.", tests.Failed) @@ -476,7 +489,7 @@ func TestUpdateValidationEmailUnique(t *testing.T) { Password: "W0rkL1fe#", PasswordConfirm: "W0rkL1fe#", } - user2, err := Create(ctx, auth.Claims{}, test.MasterDB, req2, now) + user2, err := repo.Create(ctx, auth.Claims{}, req2, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tCreate failed.", tests.Failed) @@ -488,7 +501,7 @@ func TestUpdateValidationEmailUnique(t *testing.T) { Email: &user1.Email, } expectedErr := errors.New("Key: 'UserUpdateRequest.email' Error:Field validation for 'email' failed on the 'unique' tag") - err = Update(ctx, auth.Claims{}, test.MasterDB, updateReq, now) + err = repo.Update(ctx, auth.Claims{}, updateReq, now) if err == nil { t.Logf("\t\tWant: %+v", expectedErr) t.Fatalf("\t%s\tUpdate failed.", tests.Failed) @@ -518,7 +531,7 @@ func TestUpdatePassword(t *testing.T) { // Create a new user for testing. initPass := uuid.NewRandom().String() - user, err := Create(ctx, auth.Claims{}, test.MasterDB, UserCreateRequest{ + user, err := repo.Create(ctx, auth.Claims{}, UserCreateRequest{ FirstName: "Lee", LastName: "Brown", Email: uuid.NewRandom().String() + "@geeksinthewoods.com", @@ -549,7 +562,7 @@ func TestUpdatePassword(t *testing.T) { expectedErr := errors.New("Key: 'UserUpdatePasswordRequest.id' Error:Field validation for 'id' failed on the 'required' tag\n" + "Key: 'UserUpdatePasswordRequest.password' Error:Field validation for 'password' failed on the 'required' tag\n" + "Key: 'UserUpdatePasswordRequest.password_confirm' Error:Field validation for 'password_confirm' failed on the 'required' tag") - err = UpdatePassword(ctx, auth.Claims{}, test.MasterDB, UserUpdatePasswordRequest{}, now) + err = repo.UpdatePassword(ctx, auth.Claims{}, UserUpdatePasswordRequest{}, now) if err == nil { t.Logf("\t\tWant: %+v", expectedErr) t.Fatalf("\t%s\tUpdate failed.", tests.Failed) @@ -567,7 +580,7 @@ func TestUpdatePassword(t *testing.T) { // Update the users password. newPass := uuid.NewRandom().String() - err = UpdatePassword(ctx, auth.Claims{}, test.MasterDB, UserUpdatePasswordRequest{ + err = repo.UpdatePassword(ctx, auth.Claims{}, UserUpdatePasswordRequest{ ID: user.ID, Password: newPass, PasswordConfirm: newPass, @@ -800,7 +813,7 @@ func TestCrud(t *testing.T) { // Always create the new user with empty claims, testing claims for create user // will be handled separately. - user, err := Create(tests.Context(), auth.Claims{}, test.MasterDB, tt.create, now) + user, err := repo.Create(tests.Context(), auth.Claims{}, tt.create, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tCreate user failed.", tests.Failed) @@ -823,7 +836,7 @@ func TestCrud(t *testing.T) { // Update the user. updateReq := tt.update(user) - err = Update(ctx, tt.claims(user, accountId), test.MasterDB, updateReq, now) + err = repo.Update(ctx, tt.claims(user, accountId), updateReq, now) if err != nil && errors.Cause(err) != tt.updateErr { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.updateErr) @@ -832,7 +845,7 @@ func TestCrud(t *testing.T) { t.Logf("\t%s\tUpdate ok.", tests.Success) // Find the user and make sure the updates where made. - findRes, err := ReadByID(ctx, tt.claims(user, accountId), test.MasterDB, user.ID) + findRes, err := repo.ReadByID(ctx, tt.claims(user, accountId), user.ID) if err != nil && errors.Cause(err) != tt.findErr { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.findErr) @@ -846,14 +859,14 @@ func TestCrud(t *testing.T) { } // Archive (soft-delete) the user. - err = Archive(ctx, tt.claims(user, accountId), test.MasterDB, UserArchiveRequest{ID: user.ID, force: true}, now) + err = repo.Archive(ctx, tt.claims(user, accountId), UserArchiveRequest{ID: user.ID, force: true}, now) if err != nil && errors.Cause(err) != tt.updateErr { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.updateErr) t.Fatalf("\t%s\tArchive failed.", tests.Failed) } else if tt.updateErr == nil { // Trying to find the archived user with the includeArchived false should result in not found. - _, err = ReadByID(ctx, tt.claims(user, accountId), test.MasterDB, user.ID) + _, err = repo.ReadByID(ctx, tt.claims(user, accountId), user.ID) if err != nil && errors.Cause(err) != ErrNotFound { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", ErrNotFound) @@ -861,7 +874,7 @@ func TestCrud(t *testing.T) { } // Trying to find the archived user with the includeArchived true should result no error. - _, err = Read(ctx, tt.claims(user, accountId), test.MasterDB, + _, err = repo.Read(ctx, tt.claims(user, accountId), UserReadRequest{ID: user.ID, IncludeArchived: true}) if err != nil { t.Log("\t\tGot :", err) @@ -871,14 +884,14 @@ func TestCrud(t *testing.T) { t.Logf("\t%s\tArchive ok.", tests.Success) // Restore (un-delete) the user. - err = Restore(ctx, tt.claims(user, accountId), test.MasterDB, UserRestoreRequest{ID: user.ID}, now) + err = repo.Restore(ctx, tt.claims(user, accountId), UserRestoreRequest{ID: user.ID}, now) if err != nil && errors.Cause(err) != tt.updateErr { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.updateErr) t.Fatalf("\t%s\tUnarchive failed.", tests.Failed) } else if tt.updateErr == nil { // Trying to find the archived user with the includeArchived false should result no error. - _, err = ReadByID(ctx, tt.claims(user, accountId), test.MasterDB, user.ID) + _, err = repo.ReadByID(ctx, tt.claims(user, accountId), user.ID) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tUnarchive Read failed.", tests.Failed) @@ -887,14 +900,14 @@ func TestCrud(t *testing.T) { t.Logf("\t%s\tUnarchive ok.", tests.Success) // Delete (hard-delete) the user. - err = Delete(ctx, tt.claims(user, accountId), test.MasterDB, UserDeleteRequest{ID: user.ID, force: true}) + err = repo.Delete(ctx, tt.claims(user, accountId), UserDeleteRequest{ID: user.ID, force: true}) if err != nil && errors.Cause(err) != tt.updateErr { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.updateErr) t.Fatalf("\t%s\tUpdate failed.", tests.Failed) } else if tt.updateErr == nil { // Trying to find the deleted user with the includeArchived true should result in not found. - _, err = ReadByID(ctx, tt.claims(user, accountId), test.MasterDB, user.ID) + _, err = repo.ReadByID(ctx, tt.claims(user, accountId), user.ID) if errors.Cause(err) != ErrNotFound { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", ErrNotFound) @@ -917,7 +930,7 @@ func TestFind(t *testing.T) { var users []*User for i := 0; i <= 4; i++ { - user, err := Create(tests.Context(), auth.Claims{}, test.MasterDB, UserCreateRequest{ + user, err := repo.Create(tests.Context(), auth.Claims{}, UserCreateRequest{ FirstName: "Lee", LastName: "Brown", Email: uuid.NewRandom().String() + "@geeksinthewoods.com", @@ -1029,7 +1042,7 @@ func TestFind(t *testing.T) { { ctx := tests.Context() - res, err := Find(ctx, auth.Claims{}, test.MasterDB, tt.req) + res, err := repo.Find(ctx, auth.Claims{}, tt.req) if errors.Cause(err) != tt.error { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.error) @@ -1064,7 +1077,7 @@ func TestResetPassword(t *testing.T) { // Create a new user for testing. initPass := uuid.NewRandom().String() - user, err := Create(ctx, auth.Claims{}, test.MasterDB, UserCreateRequest{ + user, err := repo.Create(ctx, auth.Claims{}, UserCreateRequest{ FirstName: "Lee", LastName: "Brown", Email: uuid.NewRandom().String() + "@geeksinthewoods.com", @@ -1091,18 +1104,10 @@ func TestResetPassword(t *testing.T) { t.Fatalf("\t%s\tCreate user account failed.", tests.Failed) } - // Mock the methods needed to make a password reset. - resetUrl := func(string) string { - return "" - } - notify := ¬ify.MockEmail{} - - secretKey := "6368616e676520746869732070617373" - // Ensure validation is working by trying ResetPassword with an empty request. { expectedErr := errors.New("Key: 'UserResetPasswordRequest.email' Error:Field validation for 'email' failed on the 'required' tag") - _, err = ResetPassword(ctx, test.MasterDB, resetUrl, notify, UserResetPasswordRequest{}, secretKey, now) + _, err = repo.ResetPassword(ctx, UserResetPasswordRequest{}, now) if err == nil { t.Logf("\t\tWant: %+v", expectedErr) t.Fatalf("\t%s\tResetPassword failed.", tests.Failed) @@ -1122,10 +1127,10 @@ func TestResetPassword(t *testing.T) { ttl := time.Hour // Make the reset password request. - resetHash, err := ResetPassword(ctx, test.MasterDB, resetUrl, notify, UserResetPasswordRequest{ + resetHash, err := repo.ResetPassword(ctx, UserResetPasswordRequest{ Email: user.Email, TTL: ttl, - }, secretKey, now) + }, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tResetPassword failed.", tests.Failed) @@ -1133,7 +1138,7 @@ func TestResetPassword(t *testing.T) { t.Logf("\t%s\tResetPassword ok.", tests.Success) // Read the user to ensure the password_reset field was set. - user, err = ReadByID(ctx, auth.Claims{}, test.MasterDB, user.ID) + user, err = repo.ReadByID(ctx, auth.Claims{}, user.ID) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tRead failed.", tests.Failed) @@ -1146,7 +1151,7 @@ func TestResetPassword(t *testing.T) { expectedErr := errors.New("Key: 'UserResetConfirmRequest.reset_hash' Error:Field validation for 'reset_hash' failed on the 'required' tag\n" + "Key: 'UserResetConfirmRequest.password' Error:Field validation for 'password' failed on the 'required' tag\n" + "Key: 'UserResetConfirmRequest.password_confirm' Error:Field validation for 'password_confirm' failed on the 'required' tag") - _, err = ResetConfirm(ctx, test.MasterDB, UserResetConfirmRequest{}, secretKey, now) + _, err = repo.ResetConfirm(ctx, UserResetConfirmRequest{}, now) if err == nil { t.Logf("\t\tWant: %+v", expectedErr) t.Fatalf("\t%s\tResetConfirm failed.", tests.Failed) @@ -1166,11 +1171,11 @@ func TestResetPassword(t *testing.T) { // Ensure the TTL is enforced. { newPass := uuid.NewRandom().String() - _, err = ResetConfirm(ctx, test.MasterDB, UserResetConfirmRequest{ + _, err = repo.ResetConfirm(ctx, UserResetConfirmRequest{ ResetHash: resetHash, Password: newPass, PasswordConfirm: newPass, - }, secretKey, now.UTC().Add(ttl*2)) + }, now.UTC().Add(ttl*2)) if errors.Cause(err) != ErrResetExpired { t.Logf("\t\tGot : %+v", errors.Cause(err)) t.Logf("\t\tWant: %+v", ErrResetExpired) @@ -1181,11 +1186,11 @@ func TestResetPassword(t *testing.T) { // Assuming we have received the email and clicked the link, we now can ensure confirm works. newPass := uuid.NewRandom().String() - reset, err := ResetConfirm(ctx, test.MasterDB, UserResetConfirmRequest{ + reset, err := repo.ResetConfirm(ctx, UserResetConfirmRequest{ ResetHash: resetHash, Password: newPass, PasswordConfirm: newPass, - }, secretKey, now) + }, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tResetConfirm failed.", tests.Failed) @@ -1199,11 +1204,11 @@ func TestResetPassword(t *testing.T) { // Ensure the reset hash does not work after its used. { newPass := uuid.NewRandom().String() - _, err = ResetConfirm(ctx, test.MasterDB, UserResetConfirmRequest{ + _, err = repo.ResetConfirm(ctx, UserResetConfirmRequest{ ResetHash: resetHash, Password: newPass, PasswordConfirm: newPass, - }, secretKey, now) + }, now) if errors.Cause(err) != ErrNotFound { t.Logf("\t\tGot : %+v", errors.Cause(err)) t.Logf("\t\tWant: %+v", ErrNotFound) diff --git a/internal/user_account/models.go b/internal/user_account/models.go index 2b63ef0..e2131a4 100644 --- a/internal/user_account/models.go +++ b/internal/user_account/models.go @@ -3,6 +3,7 @@ package user_account import ( "context" "database/sql/driver" + "github.com/jmoiron/sqlx" "strings" "time" @@ -13,6 +14,18 @@ import ( "gopkg.in/go-playground/validator.v9" ) +// Repository defines the required dependencies for UserAccount. +type Repository struct { + DbConn *sqlx.DB +} + +// NewRepository creates a new Repository that defines dependencies for UserAccount. +func NewRepository(db *sqlx.DB) *Repository { + return &Repository{ + DbConn: db, + } +} + // UserAccount defines the one to many relationship of an user to an account. This // will enable a single user access to multiple accounts without having duplicate // users. Each association of a user to an account has a set of roles and a status diff --git a/internal/user_account/user.go b/internal/user_account/user.go index 9c63821..152bff2 100644 --- a/internal/user_account/user.go +++ b/internal/user_account/user.go @@ -6,13 +6,12 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "github.com/huandu/go-sqlbuilder" - "github.com/jmoiron/sqlx" "github.com/pkg/errors" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" ) // UserFindByAccount lists all the users for a given account ID. -func UserFindByAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindByAccountRequest) (Users, error) { +func (repo *Repository) UserFindByAccount(ctx context.Context, claims auth.Claims, req UserFindByAccountRequest) (Users, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.UserFindByAccount") defer span.Finish() @@ -113,12 +112,12 @@ func UserFindByAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, } queryStr, moreQueryArgs := query.Build() - queryStr = dbConn.Rebind(queryStr) + queryStr = repo.DbConn.Rebind(queryStr) queryArgs = append(queryArgs, moreQueryArgs...) // fetch all places from the db - rows, err := dbConn.QueryContext(ctx, queryStr, queryArgs...) + rows, err := repo.DbConn.QueryContext(ctx, queryStr, queryArgs...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessage(err, "find users failed") diff --git a/internal/user_account/user_account.go b/internal/user_account/user_account.go index 98de009..4fde70f 100644 --- a/internal/user_account/user_account.go +++ b/internal/user_account/user_account.go @@ -49,14 +49,14 @@ func mapRowsToUserAccount(rows *sql.Rows) (*UserAccount, error) { } // CanReadAccount determines if claims has the authority to access the specified user account by user ID. -func CanReadAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, accountID string) error { - err := account.CanReadAccount(ctx, claims, dbConn, accountID) +func (repo *Repository) CanReadAccount(ctx context.Context, claims auth.Claims, accountID string) error { + err := account.CanReadAccount(ctx, claims, accountID) return mapAccountError(err) } // CanModifyAccount determines if claims has the authority to modify the specified user ID. -func CanModifyAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, accountID string) error { - err := account.CanModifyAccount(ctx, claims, dbConn, accountID) +func (repo *Repository) CanModifyAccount(ctx context.Context, claims auth.Claims, accountID string) error { + err := account.CanModifyAccount(ctx, claims, accountID) return mapAccountError(err) } diff --git a/internal/user_account/user_account_test.go b/internal/user_account/user_account_test.go index 7ff38c7..2728273 100644 --- a/internal/user_account/user_account_test.go +++ b/internal/user_account/user_account_test.go @@ -17,7 +17,10 @@ import ( "github.com/pkg/errors" ) -var test *tests.Test +var ( + test *tests.Test + repo *Repository +) // TestMain is the entry point for testing. func TestMain(m *testing.M) { @@ -27,6 +30,9 @@ func TestMain(m *testing.M) { func testMain(m *testing.M) int { test = tests.New() defer test.TearDown() + + repo = NewRepository(test.MasterDB) + return m.Run() } From e45dd56149cc70f9e02c6da33c88ec8979c0f102 Mon Sep 17 00:00:00 2001 From: Lee Brown Date: Wed, 14 Aug 2019 11:40:26 -0800 Subject: [PATCH 03/13] Completed updating biz logic packages to use repository pattern --- internal/account/account.go | 70 ++++++----- .../account_preference/account_preference.go | 43 ++++--- .../account_preference_test.go | 34 ++--- internal/account/account_preference/models.go | 15 ++- internal/account/account_test.go | 44 ++++--- internal/account/models.go | 15 ++- internal/project/models.go | 16 ++- internal/project/project.go | 51 ++++---- internal/project/project_test.go | 13 +- .../project_routes.go | 20 +-- internal/signup/models.go | 21 ++++ internal/signup/signup.go | 13 +- internal/signup/signup_test.go | 27 +++- internal/user/models.go | 28 ++--- internal/user/user.go | 44 ++++--- internal/user/user_test.go | 14 +-- internal/user_account/invite/invite.go | 52 ++++---- internal/user_account/invite/invite_test.go | 55 ++++---- internal/user_account/invite/models.go | 29 +++++ internal/user_account/models.go | 4 +- internal/user_account/user_account.go | 60 ++++----- internal/user_account/user_account_test.go | 40 +++--- internal/user_auth/auth.go | 33 +++-- internal/user_auth/auth_test.go | 118 +++++++++--------- internal/user_auth/models.go | 24 ++++ 25 files changed, 530 insertions(+), 353 deletions(-) rename internal/{project-routes => project_route}/project_routes.go (64%) diff --git a/internal/account/account.go b/internal/account/account.go index 5a8dc7a..19575ee 100644 --- a/internal/account/account.go +++ b/internal/account/account.go @@ -64,6 +64,11 @@ func CanReadAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, ac return nil } +// CanReadAccount determines if claims has the authority to access the specified account ID. +func (repo *Repository) CanReadAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, accountID string) error { + return repo.CanReadAccount(ctx, claims, repo.DbConn, accountID) +} + // CanModifyAccount determines if claims has the authority to modify the specified account ID. func CanModifyAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, accountID string) error { // If the request has claims from a specific account, ensure that the claims @@ -105,6 +110,11 @@ func CanModifyAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, return nil } +// CanModifyAccount determines if claims has the authority to modify the specified account ID. +func (repo *Repository) CanModifyAccount(ctx context.Context, claims auth.Claims, accountID string) error { + return CanModifyAccount(ctx, claims, repo.DbConn, accountID) +} + // applyClaimsSelect applies a sub-query to the provided query to enforce ACL based on // the claims provided. // 1. All role types can access their user ID @@ -150,7 +160,7 @@ func selectQuery() *sqlbuilder.SelectBuilder { // Find gets all the accounts from the database based on the request params. // TODO: Need to figure out why can't parse the args when appending the where // to the query. -func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountFindRequest) (Accounts, error) { +func (repo *Repository) Find(ctx context.Context, claims auth.Claims, req AccountFindRequest) (Accounts, error) { query := selectQuery() if req.Where != "" { @@ -166,7 +176,7 @@ func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountF query.Offset(int(*req.Offset)) } - return find(ctx, claims, dbConn, query, req.Args, req.IncludeArchived) + return find(ctx, claims, repo.DbConn, query, req.Args, req.IncludeArchived) } // find internal method for getting all the accounts from the database using a select query. @@ -242,14 +252,14 @@ func UniqueName(ctx context.Context, dbConn *sqlx.DB, name, accountId string) (b } // Create inserts a new account into the database. -func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountCreateRequest, now time.Time) (*Account, error) { +func (repo *Repository) Create(ctx context.Context, claims auth.Claims, req AccountCreateRequest, now time.Time) (*Account, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Create") defer span.Finish() v := webcontext.Validator() // Validation account name is unique in the database. - uniq, err := UniqueName(ctx, dbConn, req.Name, "") + uniq, err := UniqueName(ctx, repo.DbConn, req.Name, "") if err != nil { return nil, err } @@ -310,8 +320,8 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessage(err, "create account failed") @@ -322,15 +332,15 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun } // ReadByID gets the specified user by ID from the database. -func ReadByID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string) (*Account, error) { - return Read(ctx, claims, dbConn, AccountReadRequest{ +func (repo *Repository) ReadByID(ctx context.Context, claims auth.Claims, id string) (*Account, error) { + return repo.Read(ctx, claims, AccountReadRequest{ ID: id, IncludeArchived: false, }) } // Read gets the specified account from the database. -func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountReadRequest) (*Account, error) { +func (repo *Repository) Read(ctx context.Context, claims auth.Claims, req AccountReadRequest) (*Account, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Read") defer span.Finish() @@ -345,7 +355,7 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountR query := sqlbuilder.NewSelectBuilder() query.Where(query.Equal("id", req.ID)) - res, err := find(ctx, claims, dbConn, query, []interface{}{}, req.IncludeArchived) + res, err := find(ctx, claims, repo.DbConn, query, []interface{}{}, req.IncludeArchived) if err != nil { return nil, err } else if res == nil || len(res) == 0 { @@ -358,7 +368,7 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountR } // Update replaces an account in the database. -func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountUpdateRequest, now time.Time) error { +func (repo *Repository) Update(ctx context.Context, claims auth.Claims, req AccountUpdateRequest, now time.Time) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Update") defer span.Finish() @@ -366,7 +376,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun if req.Name != nil { // Validation account name is unique in the database. - uniq, err := UniqueName(ctx, dbConn, *req.Name, req.ID) + uniq, err := UniqueName(ctx, repo.DbConn, *req.Name, req.ID) if err != nil { return err } @@ -382,7 +392,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun } // Ensure the claims can modify the account specified in the request. - err = CanModifyAccount(ctx, claims, dbConn, req.ID) + err = CanModifyAccount(ctx, claims, repo.DbConn, req.ID) if err != nil { return err } @@ -460,8 +470,8 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "update account %s failed", req.ID) @@ -472,7 +482,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun } // Archive soft deleted the account from the database. -func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountArchiveRequest, now time.Time) error { +func (repo *Repository) Archive(ctx context.Context, claims auth.Claims, req AccountArchiveRequest, now time.Time) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Archive") defer span.Finish() @@ -484,7 +494,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accou } // Ensure the claims can modify the account specified in the request. - err = CanModifyAccount(ctx, claims, dbConn, req.ID) + err = CanModifyAccount(ctx, claims, repo.DbConn, req.ID) if err != nil { return err } @@ -511,8 +521,8 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accou // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "archive account %s failed", req.ID) @@ -531,8 +541,8 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accou // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "archive users for account %s failed", req.ID) @@ -544,7 +554,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accou } // Delete removes an account from the database. -func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountDeleteRequest) error { +func (repo *Repository) Delete(ctx context.Context, claims auth.Claims, req AccountDeleteRequest) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Delete") defer span.Finish() @@ -556,13 +566,13 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun } // Ensure the claims can modify the account specified in the request. - err = CanModifyAccount(ctx, claims, dbConn, req.ID) + err = CanModifyAccount(ctx, claims, repo.DbConn, req.ID) if err != nil { return err } // Start a new transaction to handle rollbacks on error. - tx, err := dbConn.Begin() + tx, err := repo.DbConn.Begin() if err != nil { return errors.WithStack(err) } @@ -579,7 +589,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) + sql = repo.DbConn.Rebind(sql) _, err = tx.ExecContext(ctx, sql, args...) if err != nil { tx.Rollback() @@ -602,7 +612,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) + sql = repo.DbConn.Rebind(sql) _, err = tx.ExecContext(ctx, sql, args...) if err != nil { tx.Rollback() @@ -620,7 +630,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) + sql = repo.DbConn.Rebind(sql) _, err = tx.ExecContext(ctx, sql, args...) if err != nil { tx.Rollback() @@ -642,6 +652,10 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun func MockAccount(ctx context.Context, dbConn *sqlx.DB, now time.Time) (*Account, error) { s := AccountStatus_Active + repo := &Repository{ + DbConn: dbConn, + } + req := AccountCreateRequest{ Name: uuid.NewRandom().String(), Address1: "103 East Main St", @@ -652,5 +666,5 @@ func MockAccount(ctx context.Context, dbConn *sqlx.DB, now time.Time) (*Account, Zipcode: "99686", Status: &s, } - return Create(ctx, auth.Claims{}, dbConn, req, now) + return repo.Create(ctx, auth.Claims{}, req, now) } diff --git a/internal/account/account_preference/account_preference.go b/internal/account/account_preference/account_preference.go index d995f75..3dd0e41 100644 --- a/internal/account/account_preference/account_preference.go +++ b/internal/account/account_preference/account_preference.go @@ -63,7 +63,7 @@ func applyClaimsSelect(ctx context.Context, claims auth.Claims, query *sqlbuilde // Find gets all the account preferences from the database based on the request params. // TODO: Need to figure out why can't parse the args when appending the where to the query. -func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPreferenceFindRequest) ([]*AccountPreference, error) { +func (repo *Repository) Find(ctx context.Context, claims auth.Claims, req AccountPreferenceFindRequest) ([]*AccountPreference, error) { query := sqlbuilder.NewSelectBuilder() if req.Where != "" { query.Where(query.And(req.Where)) @@ -78,11 +78,11 @@ func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountP query.Offset(int(*req.Offset)) } - return find(ctx, claims, dbConn, query, req.Args, req.IncludeArchived) + return find(ctx, claims, repo.DbConn, query, req.Args, req.IncludeArchived) } // FindByAccountID gets the specified account preferences for an account from the database. -func FindByAccountID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPreferenceFindByAccountIDRequest) ([]*AccountPreference, error) { +func (repo *Repository) FindByAccountID(ctx context.Context, claims auth.Claims, req AccountPreferenceFindByAccountIDRequest) ([]*AccountPreference, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.account_preference.FindByAccountID") defer span.Finish() @@ -106,7 +106,7 @@ func FindByAccountID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r query.Offset(int(*req.Offset)) } - return find(ctx, claims, dbConn, query, []interface{}{}, req.IncludeArchived) + return find(ctx, claims, repo.DbConn, query, []interface{}{}, req.IncludeArchived) } // find internal method for getting all the account preferences from the database using a select query. @@ -157,7 +157,7 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu } // Read gets the specified account preference from the database. -func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPreferenceReadRequest) (*AccountPreference, error) { +func (repo *Repository) Read(ctx context.Context, claims auth.Claims, req AccountPreferenceReadRequest) (*AccountPreference, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.account_preference.Read") defer span.Finish() @@ -173,7 +173,7 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountP query.Equal("account_id", req.AccountID)), query.Equal("name", req.Name)) - res, err := find(ctx, claims, dbConn, query, []interface{}{}, req.IncludeArchived) + res, err := find(ctx, claims, repo.DbConn, query, []interface{}{}, req.IncludeArchived) if err != nil { return nil, err } else if res == nil || len(res) == 0 { @@ -263,7 +263,7 @@ func Validator() *validator.Validate { } // Set inserts a new account preference or updates an existing on. -func Set(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPreferenceSetRequest, now time.Time) error { +func (repo *Repository) Set(ctx context.Context, claims auth.Claims, req AccountPreferenceSetRequest, now time.Time) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.account_preference.Set") defer span.Finish() @@ -276,7 +276,7 @@ func Set(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPr } // Ensure the claims can modify the account specified in the request. - err = account.CanModifyAccount(ctx, claims, dbConn, req.AccountID) + err = account.CanModifyAccount(ctx, claims, repo.DbConn, req.AccountID) if err != nil { return err } @@ -301,11 +301,11 @@ func Set(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPr // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) + sql = repo.DbConn.Rebind(sql) sql = sql + " ON CONFLICT ON CONSTRAINT account_preferences_pkey DO UPDATE set value = EXCLUDED.value " - _, err = dbConn.ExecContext(ctx, sql, args...) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessage(err, "set account preference failed") @@ -316,7 +316,7 @@ func Set(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPr } // Archive soft deleted the account preference from the database. -func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPreferenceArchiveRequest, now time.Time) error { +func (repo *Repository) Archive(ctx context.Context, claims auth.Claims, req AccountPreferenceArchiveRequest, now time.Time) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.account_preference.Archive") defer span.Finish() @@ -328,7 +328,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accou } // Ensure the claims can modify the account specified in the request. - err = account.CanModifyAccount(ctx, claims, dbConn, req.AccountID) + err = account.CanModifyAccount(ctx, claims, repo.DbConn, req.AccountID) if err != nil { return err } @@ -355,8 +355,8 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accou // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "archive account preference %s for account %s failed", req.Name, req.AccountID) @@ -367,7 +367,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accou } // Delete removes an account preference from the database. -func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountPreferenceDeleteRequest) error { +func (repo *Repository) Delete(ctx context.Context, claims auth.Claims, req AccountPreferenceDeleteRequest) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.account_preference.Delete") defer span.Finish() @@ -379,13 +379,13 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun } // Ensure the claims can modify the account specified in the request. - err = account.CanModifyAccount(ctx, claims, dbConn, req.AccountID) + err = account.CanModifyAccount(ctx, claims, repo.DbConn, req.AccountID) if err != nil { return err } // Start a new transaction to handle rollbacks on error. - tx, err := dbConn.Begin() + tx, err := repo.DbConn.Begin() if err != nil { return errors.WithStack(err) } @@ -397,7 +397,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) + sql = repo.DbConn.Rebind(sql) _, err = tx.ExecContext(ctx, sql, args...) if err != nil { tx.Rollback() @@ -417,10 +417,15 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Accoun // MockAccountPreference returns a fake AccountPreference for testing. func MockAccountPreference(ctx context.Context, dbConn *sqlx.DB, now time.Time) error { + + repo := &Repository{ + DbConn: dbConn, + } + req := AccountPreferenceSetRequest{ AccountID: uuid.NewRandom().String(), Name: AccountPreference_Datetime_Format, Value: AccountPreference_Datetime_Format_Default, } - return Set(ctx, auth.Claims{}, dbConn, req, now) + return repo.Set(ctx, auth.Claims{}, req, now) } diff --git a/internal/account/account_preference/account_preference_test.go b/internal/account/account_preference/account_preference_test.go index e8886f9..2a9ebdf 100644 --- a/internal/account/account_preference/account_preference_test.go +++ b/internal/account/account_preference/account_preference_test.go @@ -1,13 +1,13 @@ package account_preference import ( - "geeks-accelerator/oss/saas-starter-kit/internal/account" "math/rand" "os" "strings" "testing" "time" + "geeks-accelerator/oss/saas-starter-kit/internal/account" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/tests" "geeks-accelerator/oss/saas-starter-kit/internal/user_account" @@ -17,7 +17,10 @@ import ( "github.com/pkg/errors" ) -var test *tests.Test +var ( + test *tests.Test + repo *Repository +) // TestMain is the entry point for testing. func TestMain(m *testing.M) { @@ -27,6 +30,9 @@ func TestMain(m *testing.M) { func testMain(m *testing.M) int { test = tests.New() defer test.TearDown() + + repo = NewRepository(test.MasterDB) + return m.Run() } @@ -66,7 +72,7 @@ func TestSetValidation(t *testing.T) { { ctx := tests.Context() - err := Set(ctx, auth.Claims{}, test.MasterDB, tt.req, now) + err := repo.Set(ctx, auth.Claims{}, tt.req, now) if err != tt.error { // TODO: need a better way to handle validation errors as they are // of type interface validator.ValidationErrorsTranslations @@ -225,7 +231,7 @@ func TestCrud(t *testing.T) { { ctx := tests.Context() - err := Set(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), test.MasterDB, tt.set, now) + err := repo.Set(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), tt.set, now) if err != nil && errors.Cause(err) != tt.writeErr { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.writeErr) @@ -234,7 +240,7 @@ func TestCrud(t *testing.T) { // If user doesn't have access to set, create one anyways to test the other endpoints. if tt.writeErr != nil { - err := Set(ctx, auth.Claims{}, test.MasterDB, tt.set, now) + err := repo.Set(ctx, auth.Claims{}, tt.set, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tCreate failed.", tests.Failed) @@ -242,7 +248,7 @@ func TestCrud(t *testing.T) { } // Find the account and make sure the set where made. - readRes, err := Read(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), test.MasterDB, AccountPreferenceReadRequest{ + readRes, err := repo.Read(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), AccountPreferenceReadRequest{ AccountID: tt.set.AccountID, Name: tt.set.Name, }) @@ -266,7 +272,7 @@ func TestCrud(t *testing.T) { } // Archive (soft-delete) the account. - err = Archive(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), test.MasterDB, AccountPreferenceArchiveRequest{ + err = repo.Archive(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), AccountPreferenceArchiveRequest{ AccountID: tt.set.AccountID, Name: tt.set.Name, }, now) @@ -276,7 +282,7 @@ func TestCrud(t *testing.T) { t.Fatalf("\t%s\tArchive failed.", tests.Failed) } else if tt.findErr == nil { // Trying to find the archived account with the includeArchived false should result in not found. - _, err = Read(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), test.MasterDB, AccountPreferenceReadRequest{ + _, err = repo.Read(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), AccountPreferenceReadRequest{ AccountID: tt.set.AccountID, Name: tt.set.Name, }) @@ -287,7 +293,7 @@ func TestCrud(t *testing.T) { } // Trying to find the archived account with the includeArchived true should result no error. - _, err = Read(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), test.MasterDB, AccountPreferenceReadRequest{ + _, err = repo.Read(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), AccountPreferenceReadRequest{ AccountID: tt.set.AccountID, Name: tt.set.Name, IncludeArchived: true, @@ -300,7 +306,7 @@ func TestCrud(t *testing.T) { t.Logf("\t%s\tArchive ok.", tests.Success) // Delete (hard-delete) the account. - err = Delete(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), test.MasterDB, AccountPreferenceDeleteRequest{ + err = repo.Delete(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), AccountPreferenceDeleteRequest{ AccountID: tt.set.AccountID, Name: tt.set.Name, }) @@ -310,7 +316,7 @@ func TestCrud(t *testing.T) { t.Fatalf("\t%s\tDelete failed.", tests.Failed) } else if tt.writeErr == nil { // Trying to find the deleted account with the includeArchived true should result in not found. - _, err = Read(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), test.MasterDB, AccountPreferenceReadRequest{ + _, err = repo.Read(ctx, tt.claims(usrAcc.AccountID, usrAcc.UserID), AccountPreferenceReadRequest{ AccountID: tt.set.AccountID, Name: tt.set.Name, IncludeArchived: true, @@ -362,14 +368,14 @@ func TestFind(t *testing.T) { var prefs []*AccountPreference for idx, req := range reqs { - err = Set(tests.Context(), auth.Claims{}, test.MasterDB, req, now.Add(time.Second*time.Duration(idx))) + err = repo.Set(tests.Context(), auth.Claims{}, req, now.Add(time.Second*time.Duration(idx))) if err != nil { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tRequest : %+v", req) t.Fatalf("\t%s\tSet failed.", tests.Failed) } - pref, err := Read(tests.Context(), auth.Claims{}, test.MasterDB, AccountPreferenceReadRequest{ + pref, err := repo.Read(tests.Context(), auth.Claims{}, AccountPreferenceReadRequest{ AccountID: req.AccountID, Name: req.Name, }) @@ -479,7 +485,7 @@ func TestFind(t *testing.T) { { ctx := tests.Context() - res, err := Find(ctx, auth.Claims{}, test.MasterDB, tt.req) + res, err := repo.Find(ctx, auth.Claims{}, tt.req) if errors.Cause(err) != tt.error { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.error) diff --git a/internal/account/account_preference/models.go b/internal/account/account_preference/models.go index 869da27..3d383b3 100644 --- a/internal/account/account_preference/models.go +++ b/internal/account/account_preference/models.go @@ -2,15 +2,28 @@ package account_preference import ( "context" - "github.com/pkg/errors" "time" "database/sql/driver" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" + "github.com/jmoiron/sqlx" "github.com/lib/pq" + "github.com/pkg/errors" "gopkg.in/go-playground/validator.v9" ) +// Repository defines the required dependencies for AccountPreference. +type Repository struct { + DbConn *sqlx.DB +} + +// NewRepository creates a new Repository that defines dependencies for AccountPreference. +func NewRepository(db *sqlx.DB) *Repository { + return &Repository{ + DbConn: db, + } +} + // AccountPreference represents an account setting. type AccountPreference struct { AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"` diff --git a/internal/account/account_test.go b/internal/account/account_test.go index 628fc54..8f14b4f 100644 --- a/internal/account/account_test.go +++ b/internal/account/account_test.go @@ -17,7 +17,10 @@ import ( "github.com/pkg/errors" ) -var test *tests.Test +var ( + test *tests.Test + repo *Repository +) // TestMain is the entry point for testing. func TestMain(m *testing.M) { @@ -27,6 +30,9 @@ func TestMain(m *testing.M) { func testMain(m *testing.M) int { test = tests.New() defer test.TearDown() + + repo = NewRepository(test.MasterDB) + return m.Run() } @@ -184,7 +190,7 @@ func TestCreateValidation(t *testing.T) { { ctx := tests.Context() - res, err := Create(ctx, auth.Claims{}, test.MasterDB, tt.req, now) + res, err := repo.Create(ctx, auth.Claims{}, tt.req, now) if err != tt.error { // TODO: need a better way to handle validation errors as they are // of type interface validator.ValidationErrorsTranslations @@ -239,7 +245,7 @@ func TestCreateValidationNameUnique(t *testing.T) { Country: "USA", Zipcode: "99686", } - account1, err := Create(ctx, auth.Claims{}, test.MasterDB, req1, now) + account1, err := repo.Create(ctx, auth.Claims{}, req1, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tCreate failed.", tests.Failed) @@ -255,7 +261,7 @@ func TestCreateValidationNameUnique(t *testing.T) { Zipcode: "99686", } expectedErr := errors.New("Key: 'AccountCreateRequest.name' Error:Field validation for 'name' failed on the 'unique' tag") - _, err = Create(ctx, auth.Claims{}, test.MasterDB, req2, now) + _, err = repo.Create(ctx, auth.Claims{}, req2, now) if err == nil { t.Logf("\t\tWant: %+v", expectedErr) t.Fatalf("\t%s\tCreate failed.", tests.Failed) @@ -349,7 +355,7 @@ func TestCreateClaims(t *testing.T) { { ctx := tests.Context() - _, err := Create(ctx, auth.Claims{}, test.MasterDB, tt.req, now) + _, err := repo.Create(ctx, auth.Claims{}, tt.req, now) if errors.Cause(err) != tt.error { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.error) @@ -396,7 +402,7 @@ func TestUpdateValidation(t *testing.T) { { ctx := tests.Context() - err := Update(ctx, auth.Claims{}, test.MasterDB, tt.req, now) + err := repo.Update(ctx, auth.Claims{}, tt.req, now) if err != tt.error { // TODO: need a better way to handle validation errors as they are // of type interface validator.ValidationErrorsTranslations @@ -440,7 +446,7 @@ func TestUpdateValidationNameUnique(t *testing.T) { Country: "USA", Zipcode: "99686", } - account1, err := Create(ctx, auth.Claims{}, test.MasterDB, req1, now) + account1, err := repo.Create(ctx, auth.Claims{}, req1, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tCreate failed.", tests.Failed) @@ -455,7 +461,7 @@ func TestUpdateValidationNameUnique(t *testing.T) { Country: "USA", Zipcode: "99686", } - account2, err := Create(ctx, auth.Claims{}, test.MasterDB, req2, now) + account2, err := repo.Create(ctx, auth.Claims{}, req2, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tCreate failed.", tests.Failed) @@ -467,7 +473,7 @@ func TestUpdateValidationNameUnique(t *testing.T) { Name: &account1.Name, } expectedErr := errors.New("Key: 'AccountUpdateRequest.name' Error:Field validation for 'name' failed on the 'unique' tag") - err = Update(ctx, auth.Claims{}, test.MasterDB, updateReq, now) + err = repo.Update(ctx, auth.Claims{}, updateReq, now) if err == nil { t.Logf("\t\tWant: %+v", expectedErr) t.Fatalf("\t%s\tUpdate failed.", tests.Failed) @@ -728,7 +734,7 @@ func TestCrud(t *testing.T) { // Always create the new account with empty claims, testing claims for create account // will be handled separately. - account, err := Create(ctx, auth.Claims{}, test.MasterDB, tt.create, now) + account, err := repo.Create(ctx, auth.Claims{}, tt.create, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tCreate failed.", tests.Failed) @@ -744,7 +750,7 @@ func TestCrud(t *testing.T) { // Update the account. updateReq := tt.update(account) - err = Update(ctx, tt.claims(account, userId), test.MasterDB, updateReq, now) + err = repo.Update(ctx, tt.claims(account, userId), updateReq, now) if err != nil && errors.Cause(err) != tt.updateErr { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.updateErr) @@ -753,7 +759,7 @@ func TestCrud(t *testing.T) { t.Logf("\t%s\tUpdate ok.", tests.Success) // Find the account and make sure the updates where made. - findRes, err := ReadByID(ctx, tt.claims(account, userId), test.MasterDB, account.ID) + findRes, err := repo.ReadByID(ctx, tt.claims(account, userId), account.ID) if err != nil && errors.Cause(err) != tt.findErr { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.findErr) @@ -767,14 +773,14 @@ func TestCrud(t *testing.T) { } // Archive (soft-delete) the account. - err = Archive(ctx, tt.claims(account, userId), test.MasterDB, AccountArchiveRequest{ID: account.ID}, now) + err = repo.Archive(ctx, tt.claims(account, userId), AccountArchiveRequest{ID: account.ID}, now) if err != nil && errors.Cause(err) != tt.updateErr { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.updateErr) t.Fatalf("\t%s\tArchive failed.", tests.Failed) } else if tt.updateErr == nil { // Trying to find the archived account with the includeArchived false should result in not found. - _, err = ReadByID(ctx, tt.claims(account, userId), test.MasterDB, account.ID) + _, err = repo.ReadByID(ctx, tt.claims(account, userId), account.ID) if err != nil && errors.Cause(err) != ErrNotFound { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", ErrNotFound) @@ -782,7 +788,7 @@ func TestCrud(t *testing.T) { } // Trying to find the archived account with the includeArchived true should result no error. - _, err = Read(ctx, tt.claims(account, userId), test.MasterDB, + _, err = repo.Read(ctx, tt.claims(account, userId), AccountReadRequest{ID: account.ID, IncludeArchived: true}) if err != nil { t.Log("\t\tGot :", err) @@ -792,14 +798,14 @@ func TestCrud(t *testing.T) { t.Logf("\t%s\tArchive ok.", tests.Success) // Delete (hard-delete) the account. - err = Delete(ctx, tt.claims(account, userId), test.MasterDB, AccountDeleteRequest{ID: account.ID}) + err = repo.Delete(ctx, tt.claims(account, userId), AccountDeleteRequest{ID: account.ID}) if err != nil && errors.Cause(err) != tt.updateErr { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.updateErr) t.Fatalf("\t%s\tUpdate failed.", tests.Failed) } else if tt.updateErr == nil { // Trying to find the deleted account with the includeArchived true should result in not found. - _, err = ReadByID(ctx, tt.claims(account, userId), test.MasterDB, account.ID) + _, err = repo.ReadByID(ctx, tt.claims(account, userId), account.ID) if errors.Cause(err) != ErrNotFound { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", ErrNotFound) @@ -822,7 +828,7 @@ func TestFind(t *testing.T) { var accounts []*Account for i := 0; i <= 4; i++ { - account, err := Create(tests.Context(), auth.Claims{}, test.MasterDB, AccountCreateRequest{ + account, err := repo.Create(tests.Context(), auth.Claims{}, AccountCreateRequest{ Name: uuid.NewRandom().String(), Address1: "103 East Main St", Address2: "Unit 546", @@ -935,7 +941,7 @@ func TestFind(t *testing.T) { { ctx := tests.Context() - res, err := Find(ctx, auth.Claims{}, test.MasterDB, tt.req) + res, err := repo.Find(ctx, auth.Claims{}, tt.req) if errors.Cause(err) != tt.error { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.error) diff --git a/internal/account/models.go b/internal/account/models.go index 0907b36..843ce7c 100644 --- a/internal/account/models.go +++ b/internal/account/models.go @@ -5,14 +5,27 @@ import ( "database/sql" "database/sql/driver" "encoding/json" - "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "time" + "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" + "github.com/jmoiron/sqlx" "github.com/lib/pq" "github.com/pkg/errors" "gopkg.in/go-playground/validator.v9" ) +// Repository defines the required dependencies for Account. +type Repository struct { + DbConn *sqlx.DB +} + +// NewRepository creates a new Repository that defines dependencies for Account. +func NewRepository(db *sqlx.DB) *Repository { + return &Repository{ + DbConn: db, + } +} + // Account represents someone with access to our system. type Account struct { ID string `json:"id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"` diff --git a/internal/project/models.go b/internal/project/models.go index ba52589..eff7cfb 100644 --- a/internal/project/models.go +++ b/internal/project/models.go @@ -2,14 +2,28 @@ package project import ( "context" + "time" + "database/sql/driver" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" + "github.com/jmoiron/sqlx" "github.com/lib/pq" "github.com/pkg/errors" "gopkg.in/go-playground/validator.v9" - "time" ) +// Repository defines the required dependencies for Project. +type Repository struct { + DbConn *sqlx.DB +} + +// NewRepository creates a new Repository that defines dependencies for Project. +func NewRepository(db *sqlx.DB) *Repository { + return &Repository{ + DbConn: db, + } +} + // Project represents a workflow. type Project struct { ID string `json:"id" validate:"required,uuid" example:"985f1746-1d9f-459f-a2d9-fc53ece5ae86"` diff --git a/internal/project/project.go b/internal/project/project.go index 990fdd8..b4fe3e2 100644 --- a/internal/project/project.go +++ b/internal/project/project.go @@ -3,6 +3,8 @@ package project import ( "context" "database/sql" + "time" + "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "github.com/huandu/go-sqlbuilder" @@ -10,7 +12,6 @@ import ( "github.com/pborman/uuid" "github.com/pkg/errors" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" - "time" ) const ( @@ -27,7 +28,7 @@ var ( ) // CanReadProject determines if claims has the authority to access the specified project by id. -func CanReadProject(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string) error { +func (repo *Repository) CanReadProject(ctx context.Context, claims auth.Claims, id string) error { // If the request has claims from a specific project, ensure that the claims // has the correct access to the project. @@ -40,9 +41,9 @@ func CanReadProject(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id )) queryStr, args := query.Build() - queryStr = dbConn.Rebind(queryStr) + queryStr = repo.DbConn.Rebind(queryStr) var id string - err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&id) + err := repo.DbConn.QueryRowContext(ctx, queryStr, args...).Scan(&id) if err != nil && err != sql.ErrNoRows { err = errors.Wrapf(err, "query - %s", query.String()) return err @@ -60,8 +61,8 @@ func CanReadProject(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id } // CanModifyProject determines if claims has the authority to modify the specified project by id. -func CanModifyProject(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string) error { - err := CanReadProject(ctx, claims, dbConn, id) +func (repo *Repository) CanModifyProject(ctx context.Context, claims auth.Claims, id string) error { + err := repo.CanReadProject(ctx, claims, id) if err != nil { return err } @@ -124,9 +125,9 @@ func findRequestQuery(req ProjectFindRequest) (*sqlbuilder.SelectBuilder, []inte } // Find gets all the projects from the database based on the request params. -func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectFindRequest) (Projects, error) { +func (repo *Repository) Find(ctx context.Context, claims auth.Claims, req ProjectFindRequest) (Projects, error) { query, args := findRequestQuery(req) - return find(ctx, claims, dbConn, query, args, req.IncludeArchived) + return find(ctx, claims, repo.DbConn, query, args, req.IncludeArchived) } // find internal method for getting all the projects from the database using a select query. @@ -177,15 +178,15 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu } // ReadByID gets the specified project by ID from the database. -func ReadByID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string) (*Project, error) { - return Read(ctx, claims, dbConn, ProjectReadRequest{ +func (repo *Repository) ReadByID(ctx context.Context, claims auth.Claims, id string) (*Project, error) { + return repo.Read(ctx, claims, ProjectReadRequest{ ID: id, IncludeArchived: false, }) } // Read gets the specified project from the database. -func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectReadRequest) (*Project, error) { +func (repo *Repository) Read(ctx context.Context, claims auth.Claims, req ProjectReadRequest) (*Project, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Read") defer span.Finish() @@ -200,7 +201,7 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectR query := sqlbuilder.NewSelectBuilder() query.Where(query.Equal("id", req.ID)) - res, err := find(ctx, claims, dbConn, query, []interface{}{}, req.IncludeArchived) + res, err := find(ctx, claims, repo.DbConn, query, []interface{}{}, req.IncludeArchived) if err != nil { return nil, err } else if res == nil || len(res) == 0 { @@ -213,7 +214,7 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectR } // Create inserts a new project into the database. -func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectCreateRequest, now time.Time) (*Project, error) { +func (repo *Repository) Create(ctx context.Context, claims auth.Claims, req ProjectCreateRequest, now time.Time) (*Project, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Create") defer span.Finish() if claims.Audience != "" { @@ -290,8 +291,8 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Projec // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessage(err, "create project failed") @@ -302,7 +303,7 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Projec } // Update replaces an project in the database. -func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectUpdateRequest, now time.Time) error { +func (repo *Repository) Update(ctx context.Context, claims auth.Claims, req ProjectUpdateRequest, now time.Time) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Update") defer span.Finish() @@ -314,7 +315,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Projec } // Ensure the claims can modify the project specified in the request. - err = CanModifyProject(ctx, claims, dbConn, req.ID) + err = repo.CanModifyProject(ctx, claims, req.ID) if err != nil { return err } @@ -352,8 +353,8 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Projec query.Where(query.Equal("id", req.ID)) // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "update project %s failed", req.ID) @@ -364,7 +365,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Projec } // Archive soft deleted the project from the database. -func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectArchiveRequest, now time.Time) error { +func (repo *Repository) Archive(ctx context.Context, claims auth.Claims, req ProjectArchiveRequest, now time.Time) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Archive") defer span.Finish() @@ -376,7 +377,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Proje } // Ensure the claims can modify the project specified in the request. - err = CanModifyProject(ctx, claims, dbConn, req.ID) + err = repo.CanModifyProject(ctx, claims, req.ID) if err != nil { return err } @@ -401,8 +402,8 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Proje query.Where(query.Equal("id", req.ID)) // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "archive project %s failed", req.ID) @@ -413,7 +414,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Proje } // Delete removes an project from the database. -func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectDeleteRequest) error { +func (repo *Repository) Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectDeleteRequest) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Delete") defer span.Finish() @@ -425,7 +426,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Projec } // Ensure the claims can modify the project specified in the request. - err = CanModifyProject(ctx, claims, dbConn, req.ID) + err = repo.CanModifyProject(ctx, claims, req.ID) if err != nil { return err } diff --git a/internal/project/project_test.go b/internal/project/project_test.go index f097c9e..8e702aa 100644 --- a/internal/project/project_test.go +++ b/internal/project/project_test.go @@ -1,15 +1,19 @@ package project import ( + "os" + "testing" + "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/tests" "github.com/google/go-cmp/cmp" "github.com/huandu/go-sqlbuilder" - "os" - "testing" ) -var test *tests.Test +var ( + test *tests.Test + repo *Repository +) // TestMain is the entry point for testing. func TestMain(m *testing.M) { @@ -19,6 +23,9 @@ func TestMain(m *testing.M) { func testMain(m *testing.M) int { test = tests.New() defer test.TearDown() + + repo = NewRepository(test.MasterDB) + return m.Run() } diff --git a/internal/project-routes/project_routes.go b/internal/project_route/project_routes.go similarity index 64% rename from internal/project-routes/project_routes.go rename to internal/project_route/project_routes.go index 14ab093..2c98003 100644 --- a/internal/project-routes/project_routes.go +++ b/internal/project_route/project_routes.go @@ -1,17 +1,17 @@ -package project_routes +package project_route import ( "github.com/pkg/errors" "net/url" ) -type ProjectRoutes struct { +type ProjectRoute struct { webAppUrl url.URL webApiUrl url.URL } -func New(apiBaseUrl, appBaseUrl string) (ProjectRoutes, error) { - var r ProjectRoutes +func New(apiBaseUrl, appBaseUrl string) (ProjectRoute, error) { + var r ProjectRoute apiUrl, err := url.Parse(apiBaseUrl) if err != nil { @@ -28,37 +28,37 @@ func New(apiBaseUrl, appBaseUrl string) (ProjectRoutes, error) { return r, nil } -func (r ProjectRoutes) WebAppUrl(urlPath string) string { +func (r ProjectRoute) WebAppUrl(urlPath string) string { u := r.webAppUrl u.Path = urlPath return u.String() } -func (r ProjectRoutes) WebApiUrl(urlPath string) string { +func (r ProjectRoute) WebApiUrl(urlPath string) string { u := r.webApiUrl u.Path = urlPath return u.String() } -func (r ProjectRoutes) UserResetPassword(resetHash string) string { +func (r ProjectRoute) UserResetPassword(resetHash string) string { u := r.webAppUrl u.Path = "/user/reset-password/" + resetHash return u.String() } -func (r ProjectRoutes) UserInviteAccept(inviteHash string) string { +func (r ProjectRoute) UserInviteAccept(inviteHash string) string { u := r.webAppUrl u.Path = "/users/invite/" + inviteHash return u.String() } -func (r ProjectRoutes) ApiDocs() string { +func (r ProjectRoute) ApiDocs() string { u := r.webApiUrl u.Path = "/docs" return u.String() } -func (r ProjectRoutes) ApiDocsJson() string { +func (r ProjectRoute) ApiDocsJson() string { u := r.webApiUrl u.Path = "/docs/doc.json" return u.String() diff --git a/internal/signup/models.go b/internal/signup/models.go index 8db28ca..a84b272 100644 --- a/internal/signup/models.go +++ b/internal/signup/models.go @@ -2,10 +2,31 @@ package signup import ( "context" + "geeks-accelerator/oss/saas-starter-kit/internal/account" "geeks-accelerator/oss/saas-starter-kit/internal/user" + "geeks-accelerator/oss/saas-starter-kit/internal/user_account" + "github.com/jmoiron/sqlx" ) +// Repository defines the required dependencies for Signup. +type Repository struct { + DbConn *sqlx.DB + User *user.Repository + UserAccount *user_account.Repository + Account *account.Repository +} + +// NewRepository creates a new Repository that defines dependencies for Signup. +func NewRepository(db *sqlx.DB, user *user.Repository, userAccount *user_account.Repository, account *account.Repository) *Repository { + return &Repository{ + DbConn: db, + User: user, + UserAccount: userAccount, + Account: account, + } +} + // SignupRequest contains information needed perform signup. type SignupRequest struct { Account SignupAccount `json:"account" validate:"required"` // Account details. diff --git a/internal/signup/signup.go b/internal/signup/signup.go index ce5e51c..4c049eb 100644 --- a/internal/signup/signup.go +++ b/internal/signup/signup.go @@ -9,25 +9,24 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/user" "geeks-accelerator/oss/saas-starter-kit/internal/user_account" - "github.com/jmoiron/sqlx" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" ) // Signup performs the steps needed to create a new account, new user and then associate // both records with a new user_account entry. -func Signup(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req SignupRequest, now time.Time) (*SignupResult, error) { +func (repo *Repository) Signup(ctx context.Context, claims auth.Claims, req SignupRequest, now time.Time) (*SignupResult, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.signup.Signup") defer span.Finish() // Validate the user email address is unique in the database. - uniqEmail, err := user.UniqueEmail(ctx, dbConn, req.User.Email, "") + uniqEmail, err := user.UniqueEmail(ctx, repo.DbConn, req.User.Email, "") if err != nil { return nil, err } ctx = webcontext.ContextAddUniqueValue(ctx, req.User, "Email", uniqEmail) // Validate the account name is unique in the database. - uniqName, err := account.UniqueName(ctx, dbConn, req.Account.Name, "") + uniqName, err := account.UniqueName(ctx, repo.DbConn, req.Account.Name, "") if err != nil { return nil, err } @@ -52,7 +51,7 @@ func Signup(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Signup } // Execute user creation. - resp.User, err = user.Create(ctx, claims, dbConn, userReq, now) + resp.User, err = repo.User.Create(ctx, claims, userReq, now) if err != nil { return nil, err } @@ -73,7 +72,7 @@ func Signup(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Signup } // Execute account creation. - resp.Account, err = account.Create(ctx, claims, dbConn, accountReq, now) + resp.Account, err = repo.Account.Create(ctx, claims, accountReq, now) if err != nil { return nil, err } @@ -87,7 +86,7 @@ func Signup(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Signup //Status: Use default value } - _, err = user_account.Create(ctx, claims, dbConn, ua, now) + _, err = repo.UserAccount.Create(ctx, claims, ua, now) if err != nil { return nil, err } diff --git a/internal/signup/signup_test.go b/internal/signup/signup_test.go index a8d7e95..369f114 100644 --- a/internal/signup/signup_test.go +++ b/internal/signup/signup_test.go @@ -1,19 +1,26 @@ package signup import ( - "geeks-accelerator/oss/saas-starter-kit/internal/user_auth" "os" "testing" "time" + "geeks-accelerator/oss/saas-starter-kit/internal/account" + "geeks-accelerator/oss/saas-starter-kit/internal/account/account_preference" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/tests" + "geeks-accelerator/oss/saas-starter-kit/internal/user" + "geeks-accelerator/oss/saas-starter-kit/internal/user_account" + "geeks-accelerator/oss/saas-starter-kit/internal/user_auth" "github.com/google/go-cmp/cmp" "github.com/pborman/uuid" "github.com/pkg/errors" ) -var test *tests.Test +var ( + test *tests.Test + repo *Repository +) // TestMain is the entry point for testing. func TestMain(m *testing.M) { @@ -23,6 +30,13 @@ func TestMain(m *testing.M) { func testMain(m *testing.M) int { test = tests.New() defer test.TearDown() + + userRepo := user.MockRepository(test.MasterDB) + userAccRepo := user_account.NewRepository(test.MasterDB) + accRepo := account.NewRepository(test.MasterDB) + + repo = NewRepository(test.MasterDB, userRepo, userAccRepo, accRepo) + return m.Run() } @@ -63,7 +77,7 @@ func TestSignupValidation(t *testing.T) { { ctx := tests.Context() - res, err := Signup(ctx, auth.Claims{}, test.MasterDB, tt.req, now) + res, err := repo.Signup(ctx, auth.Claims{}, tt.req, now) if err != tt.error { // TODO: need a better way to handle validation errors as they are // of type interface validator.ValidationErrorsTranslations @@ -127,9 +141,12 @@ func TestSignupFull(t *testing.T) { tknGen := &auth.MockTokenGenerator{} + accPrefRepo := account_preference.NewRepository(test.MasterDB) + authRepo := user_auth.NewRepository(test.MasterDB, tknGen, repo.User, repo.UserAccount, accPrefRepo) + t.Log("Given the need to ensure signup works.") { - res, err := Signup(ctx, auth.Claims{}, test.MasterDB, req, now) + res, err := repo.Signup(ctx, auth.Claims{}, req, now) if err != nil { t.Logf("\t\tGot error : %+v", err) t.Fatalf("\t%s\tSignup failed.", tests.Failed) @@ -162,7 +179,7 @@ func TestSignupFull(t *testing.T) { t.Logf("\t%s\tSignup ok.", tests.Success) // Verify that the user can be authenticated with the updated password. - _, err = user_auth.Authenticate(ctx, test.MasterDB, tknGen, user_auth.AuthenticateRequest{ + _, err = authRepo.Authenticate(ctx, user_auth.AuthenticateRequest{ Email: res.User.Email, Password: req.User.Password, }, time.Hour, now) diff --git a/internal/user/models.go b/internal/user/models.go index b4c65c5..7860dd3 100644 --- a/internal/user/models.go +++ b/internal/user/models.go @@ -4,34 +4,34 @@ import ( "context" "database/sql" "encoding/json" - "geeks-accelerator/oss/saas-starter-kit/internal/platform/notify" - "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" - "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" - "github.com/jmoiron/sqlx" - "github.com/pkg/errors" - "github.com/sudo-suhas/symcrypto" "strconv" "strings" "time" + "geeks-accelerator/oss/saas-starter-kit/internal/platform/notify" + "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" + "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" + "github.com/jmoiron/sqlx" "github.com/lib/pq" + "github.com/pkg/errors" + "github.com/sudo-suhas/symcrypto" ) // Repository defines the required dependencies for User. type Repository struct { - DbConn *sqlx.DB - ResetUrl func(string) string - Notify notify.Email - SecretKey string + DbConn *sqlx.DB + ResetUrl func(string) string + Notify notify.Email + secretKey string } // NewRepository creates a new Repository that defines dependencies for User. func NewRepository(db *sqlx.DB, resetUrl func(string) string, notify notify.Email, secretKey string) *Repository { return &Repository{ - DbConn: db, - ResetUrl: resetUrl, - Notify: notify, - SecretKey: secretKey, + DbConn: db, + ResetUrl: resetUrl, + Notify: notify, + secretKey: secretKey, } } diff --git a/internal/user/user.go b/internal/user/user.go index 81d036e..58d005c 100644 --- a/internal/user/user.go +++ b/internal/user/user.go @@ -6,6 +6,7 @@ import ( "time" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" + "geeks-accelerator/oss/saas-starter-kit/internal/platform/notify" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "github.com/huandu/go-sqlbuilder" "github.com/jmoiron/sqlx" @@ -200,11 +201,11 @@ func findRequestQuery(req UserFindRequest) (*sqlbuilder.SelectBuilder, []interfa // Find gets all the users from the database based on the request params. func (repo *Repository) Find(ctx context.Context, claims auth.Claims, req UserFindRequest) (Users, error) { query, args := findRequestQuery(req) - return repo.find(ctx, claims, query, args, req.IncludeArchived) + return find(ctx, claims, repo.DbConn, query, args, req.IncludeArchived) } // find internal method for getting all the users from the database using a select query. -func (repo *Repository) find(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (Users, error) { +func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (Users, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Find") defer span.Finish() @@ -221,11 +222,11 @@ func (repo *Repository) find(ctx context.Context, claims auth.Claims, query *sql return nil, err } queryStr, queryArgs := query.Build() - queryStr = repo.DbConn.Rebind(queryStr) + queryStr = dbConn.Rebind(queryStr) args = append(args, queryArgs...) // fetch all places from the db - rows, err := repo.DbConn.QueryContext(ctx, queryStr, args...) + rows, err := dbConn.QueryContext(ctx, queryStr, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessage(err, "find users failed") @@ -247,17 +248,17 @@ func (repo *Repository) find(ctx context.Context, claims auth.Claims, query *sql } // Validation an email address is unique excluding the current user ID. -func (repo *Repository) UniqueEmail(ctx context.Context, email, userId string) (bool, error) { +func UniqueEmail(ctx context.Context, dbConn *sqlx.DB, email, userId string) (bool, error) { query := sqlbuilder.NewSelectBuilder().Select("id").From(userTableName) query.Where(query.And( query.Equal("email", email), query.NotEqual("id", userId), )) queryStr, args := query.Build() - queryStr = repo.DbConn.Rebind(queryStr) + queryStr = dbConn.Rebind(queryStr) var existingId string - err := repo.DbConn.QueryRowContext(ctx, queryStr, args...).Scan(&existingId) + err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&existingId) if err != nil && err != sql.ErrNoRows { err = errors.Wrapf(err, "query - %s", query.String()) return false, err @@ -283,7 +284,7 @@ func (repo *Repository) Create(ctx context.Context, claims auth.Claims, req User v := webcontext.Validator() // Validation email address is unique in the database. - uniq, err := repo.UniqueEmail(ctx, req.Email, "") + uniq, err := UniqueEmail(ctx, repo.DbConn, req.Email, "") if err != nil { return nil, err } @@ -364,7 +365,7 @@ func (repo *Repository) CreateInvite(ctx context.Context, claims auth.Claims, re v := webcontext.Validator() // Validation email address is unique in the database. - uniq, err := repo.UniqueEmail(ctx, req.Email, "") + uniq, err := UniqueEmail(ctx, repo.DbConn, req.Email, "") if err != nil { return nil, err } @@ -448,7 +449,7 @@ func (repo *Repository) Read(ctx context.Context, claims auth.Claims, req UserRe query := selectQuery() query.Where(query.Equal("id", req.ID)) - res, err := repo.find(ctx, claims, query, []interface{}{}, req.IncludeArchived) + res, err := find(ctx, claims, repo.DbConn, query, []interface{}{}, req.IncludeArchived) if err != nil { return nil, err } else if res == nil || len(res) == 0 { @@ -469,7 +470,7 @@ func (repo *Repository) ReadByEmail(ctx context.Context, claims auth.Claims, ema query := selectQuery() query.Where(query.Equal("email", email)) - res, err := repo.find(ctx, claims, query, []interface{}{}, includedArchived) + res, err := find(ctx, claims, repo.DbConn, query, []interface{}{}, includedArchived) if err != nil { return nil, err } else if res == nil || len(res) == 0 { @@ -489,7 +490,7 @@ func (repo *Repository) Update(ctx context.Context, claims auth.Claims, req User // Validation email address is unique in the database. if req.Email != nil { // Validation email address is unique in the database. - uniq, err := repo.UniqueEmail(ctx, *req.Email, req.ID) + uniq, err := UniqueEmail(ctx, repo.DbConn, *req.Email, req.ID) if err != nil { return err } @@ -844,7 +845,7 @@ func (repo *Repository) ResetPassword(ctx context.Context, req UserResetPassword query := selectQuery() query.Where(query.Equal("email", req.Email)) - res, err := repo.find(ctx, auth.Claims{}, query, []interface{}{}, false) + res, err := find(ctx, auth.Claims{}, repo.DbConn, query, []interface{}{}, false) if err != nil { return "", err } else if res == nil || len(res) == 0 { @@ -894,7 +895,7 @@ func (repo *Repository) ResetPassword(ctx context.Context, req UserResetPassword requestIp = vals.RequestIP } - encrypted, err := NewResetHash(ctx, repo.SecretKey, resetId, requestIp, req.TTL, now) + encrypted, err := NewResetHash(ctx, repo.secretKey, resetId, requestIp, req.TTL, now) if err != nil { return "", err } @@ -927,7 +928,7 @@ func (repo *Repository) ResetConfirm(ctx context.Context, req UserResetConfirmRe return nil, err } - hash, err := ParseResetHash(ctx, repo.SecretKey, req.ResetHash, now) + hash, err := ParseResetHash(ctx, repo.secretKey, req.ResetHash, now) if err != nil { return nil, err } @@ -938,7 +939,7 @@ func (repo *Repository) ResetConfirm(ctx context.Context, req UserResetConfirmRe query := selectQuery() query.Where(query.Equal("password_reset", hash.ResetID)) - res, err := repo.find(ctx, auth.Claims{}, query, []interface{}{}, false) + res, err := find(ctx, auth.Claims{}, repo.DbConn, query, []interface{}{}, false) if err != nil { return nil, err } else if res == nil || len(res) == 0 { @@ -1020,3 +1021,14 @@ func MockUser(ctx context.Context, dbConn *sqlx.DB, now time.Time) (*MockUserRes Password: pass, }, nil } + +func MockRepository(dbConn *sqlx.DB) *Repository { + // Mock the methods needed to make a password reset. + resetUrl := func(string) string { + return "" + } + notify := ¬ify.MockEmail{} + secretKey := "6368616e676520746869732070617373" + + return NewRepository(dbConn, resetUrl, notify, secretKey) +} diff --git a/internal/user/user_test.go b/internal/user/user_test.go index edfaf49..fd40aa6 100644 --- a/internal/user/user_test.go +++ b/internal/user/user_test.go @@ -8,7 +8,6 @@ import ( "time" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" - "geeks-accelerator/oss/saas-starter-kit/internal/platform/notify" "geeks-accelerator/oss/saas-starter-kit/internal/platform/tests" "github.com/dgrijalva/jwt-go" "github.com/google/go-cmp/cmp" @@ -32,14 +31,7 @@ func testMain(m *testing.M) int { test = tests.New() defer test.TearDown() - // Mock the methods needed to make a password reset. - resetUrl := func(string) string { - return "" - } - notify := ¬ify.MockEmail{} - secretKey := "6368616e676520746869732070617373" - - repo = NewRepository(test.MasterDB, resetUrl, notify, secretKey) + repo = MockRepository(test.MasterDB) return m.Run() } @@ -930,7 +922,7 @@ func TestFind(t *testing.T) { var users []*User for i := 0; i <= 4; i++ { - user, err := repo.Create(tests.Context(), auth.Claims{}, UserCreateRequest{ + user, err := repo.Create(tests.Context(), auth.Claims{}, UserCreateRequest{ FirstName: "Lee", LastName: "Brown", Email: uuid.NewRandom().String() + "@geeksinthewoods.com", @@ -1042,7 +1034,7 @@ func TestFind(t *testing.T) { { ctx := tests.Context() - res, err := repo.Find(ctx, auth.Claims{}, tt.req) + res, err := repo.Find(ctx, auth.Claims{}, tt.req) if errors.Cause(err) != tt.error { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.error) diff --git a/internal/user_account/invite/invite.go b/internal/user_account/invite/invite.go index a1fe4b1..0ee8262 100644 --- a/internal/user_account/invite/invite.go +++ b/internal/user_account/invite/invite.go @@ -8,11 +8,9 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/account" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" - "geeks-accelerator/oss/saas-starter-kit/internal/platform/notify" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/user" "geeks-accelerator/oss/saas-starter-kit/internal/user_account" - "github.com/jmoiron/sqlx" "github.com/pkg/errors" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" ) @@ -29,7 +27,7 @@ var ( ) // SendUserInvites sends emails to the users inviting them to join an account. -func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, resetUrl func(string) string, notify notify.Email, req SendUserInvitesRequest, secretKey string, now time.Time) ([]string, error) { +func (repo *Repository) SendUserInvites(ctx context.Context, claims auth.Claims, req SendUserInvitesRequest, now time.Time) ([]string, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.invite.SendUserInvites") defer span.Finish() @@ -42,7 +40,7 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r } // Ensure the claims can modify the account specified in the request. - err = user_account.CanModifyAccount(ctx, claims, dbConn, req.AccountID) + err = account.CanModifyAccount(ctx, claims, repo.DbConn, req.AccountID) if err != nil { return nil, err } @@ -51,7 +49,7 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r emailUserIDs := make(map[string]string) { // Find all users without passing in claims to search all users. - users, err := user.Find(ctx, auth.Claims{}, dbConn, user.UserFindRequest{ + users, err := repo.User.Find(ctx, auth.Claims{}, user.UserFindRequest{ Where: fmt.Sprintf("email in ('%s')", strings.Join(req.Emails, "','")), }) @@ -72,7 +70,7 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r args = append(args, userID) } - userAccs, err := user_account.Find(ctx, claims, dbConn, user_account.UserAccountFindRequest{ + userAccs, err := repo.UserAccount.Find(ctx, claims, user_account.UserAccountFindRequest{ Where: fmt.Sprintf("user_id in ('%s') and status = '%s'", strings.Join(args, "','"), user_account.UserAccountStatus_Active.String()), @@ -99,7 +97,7 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r continue } - u, err := user.CreateInvite(ctx, claims, dbConn, user.UserCreateInviteRequest{ + u, err := repo.User.CreateInvite(ctx, claims, user.UserCreateInviteRequest{ Email: email, }, now) if err != nil { @@ -118,7 +116,7 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r } status := user_account.UserAccountStatus_Invited - _, err = user_account.Create(ctx, claims, dbConn, user_account.UserAccountCreateRequest{ + _, err = repo.UserAccount.Create(ctx, claims, user_account.UserAccountCreateRequest{ UserID: userID, AccountID: req.AccountID, Roles: req.Roles, @@ -133,12 +131,12 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r req.TTL = time.Minute * 90 } - fromUser, err := user.ReadByID(ctx, claims, dbConn, req.UserID) + fromUser, err := repo.User.ReadByID(ctx, claims, req.UserID) if err != nil { return nil, err } - account, err := account.ReadByID(ctx, claims, dbConn, req.AccountID) + account, err := repo.Account.ReadByID(ctx, claims, req.AccountID) if err != nil { return nil, err } @@ -151,7 +149,7 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r var inviteHashes []string for email, userID := range emailUserIDs { - hash, err := NewInviteHash(ctx, secretKey, userID, req.AccountID, requestIp, req.TTL, now) + hash, err := NewInviteHash(ctx, repo.secretKey, userID, req.AccountID, requestIp, req.TTL, now) if err != nil { return nil, err } @@ -159,13 +157,13 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r data := map[string]interface{}{ "FromUser": fromUser.Response(ctx), "Account": account.Response(ctx), - "Url": resetUrl(hash), + "Url": repo.ResetUrl(hash), "Minutes": req.TTL.Minutes(), } subject := fmt.Sprintf("%s %s has invited you to %s", fromUser.FirstName, fromUser.LastName, account.Name) - err = notify.Send(ctx, email, subject, "user_invite", data) + err = repo.Notify.Send(ctx, email, subject, "user_invite", data) if err != nil { err = errors.WithMessagef(err, "Send invite to %s failed.", email) return nil, err @@ -178,7 +176,7 @@ func SendUserInvites(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, r } // AcceptInvite updates the user using the provided invite hash. -func AcceptInvite(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteRequest, secretKey string, now time.Time) (*user_account.UserAccount, error) { +func (repo *Repository) AcceptInvite(ctx context.Context, req AcceptInviteRequest, now time.Time) (*user_account.UserAccount, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.invite.AcceptInvite") defer span.Finish() @@ -190,25 +188,25 @@ func AcceptInvite(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteRequest, return nil, err } - hash, err := ParseInviteHash(ctx, req.InviteHash, secretKey, now) + hash, err := ParseInviteHash(ctx, req.InviteHash, repo.secretKey, now) if err != nil { return nil, err } - u, err := user.Read(ctx, auth.Claims{}, dbConn, + u, err := repo.User.Read(ctx, auth.Claims{}, user.UserReadRequest{ID: hash.UserID, IncludeArchived: true}) if err != nil { return nil, err } if u.ArchivedAt != nil && !u.ArchivedAt.Time.IsZero() { - err = user.Restore(ctx, auth.Claims{}, dbConn, user.UserRestoreRequest{ID: hash.UserID}, now) + err = repo.User.Restore(ctx, auth.Claims{}, user.UserRestoreRequest{ID: hash.UserID}, now) if err != nil { return nil, err } } - usrAcc, err := user_account.Read(ctx, auth.Claims{}, dbConn, user_account.UserAccountReadRequest{ + usrAcc, err := repo.UserAccount.Read(ctx, auth.Claims{}, user_account.UserAccountReadRequest{ UserID: hash.UserID, AccountID: hash.AccountID, }) @@ -230,7 +228,7 @@ func AcceptInvite(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteRequest, if len(u.PasswordHash) > 0 { usrAcc.Status = user_account.UserAccountStatus_Active - err = user_account.Update(ctx, auth.Claims{}, dbConn, user_account.UserAccountUpdateRequest{ + err = repo.UserAccount.Update(ctx, auth.Claims{}, user_account.UserAccountUpdateRequest{ UserID: usrAcc.UserID, AccountID: usrAcc.AccountID, Status: &usrAcc.Status, @@ -244,7 +242,7 @@ func AcceptInvite(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteRequest, } // AcceptInviteUser updates the user using the provided invite hash. -func AcceptInviteUser(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteUserRequest, secretKey string, now time.Time) (*user_account.UserAccount, error) { +func (repo *Repository) AcceptInviteUser(ctx context.Context, req AcceptInviteUserRequest, now time.Time) (*user_account.UserAccount, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.invite.AcceptInviteUser") defer span.Finish() @@ -256,25 +254,25 @@ func AcceptInviteUser(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteUser return nil, err } - hash, err := ParseInviteHash(ctx, req.InviteHash, secretKey, now) + hash, err := ParseInviteHash(ctx, req.InviteHash, repo.secretKey, now) if err != nil { return nil, err } - u, err := user.Read(ctx, auth.Claims{}, dbConn, + u, err := repo.User.Read(ctx, auth.Claims{}, user.UserReadRequest{ID: hash.UserID, IncludeArchived: true}) if err != nil { return nil, err } if u.ArchivedAt != nil && !u.ArchivedAt.Time.IsZero() { - err = user.Restore(ctx, auth.Claims{}, dbConn, user.UserRestoreRequest{ID: hash.UserID}, now) + err = repo.User.Restore(ctx, auth.Claims{}, user.UserRestoreRequest{ID: hash.UserID}, now) if err != nil { return nil, err } } - usrAcc, err := user_account.Read(ctx, auth.Claims{}, dbConn, user_account.UserAccountReadRequest{ + usrAcc, err := repo.UserAccount.Read(ctx, auth.Claims{}, user_account.UserAccountReadRequest{ UserID: hash.UserID, AccountID: hash.AccountID, }) @@ -293,7 +291,7 @@ func AcceptInviteUser(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteUser // These three calls, user.Update, user.UpdatePassword, and user_account.Update // should probably be in a transaction! - err = user.Update(ctx, auth.Claims{}, dbConn, user.UserUpdateRequest{ + err = repo.User.Update(ctx, auth.Claims{}, user.UserUpdateRequest{ ID: hash.UserID, Email: &req.Email, FirstName: &req.FirstName, @@ -304,7 +302,7 @@ func AcceptInviteUser(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteUser return nil, err } - err = user.UpdatePassword(ctx, auth.Claims{}, dbConn, user.UserUpdatePasswordRequest{ + err = repo.User.UpdatePassword(ctx, auth.Claims{}, user.UserUpdatePasswordRequest{ ID: hash.UserID, Password: req.Password, PasswordConfirm: req.PasswordConfirm, @@ -314,7 +312,7 @@ func AcceptInviteUser(ctx context.Context, dbConn *sqlx.DB, req AcceptInviteUser } usrAcc.Status = user_account.UserAccountStatus_Active - err = user_account.Update(ctx, auth.Claims{}, dbConn, user_account.UserAccountUpdateRequest{ + err = repo.UserAccount.Update(ctx, auth.Claims{}, user_account.UserAccountUpdateRequest{ UserID: usrAcc.UserID, AccountID: usrAcc.AccountID, Status: &usrAcc.Status, diff --git a/internal/user_account/invite/invite_test.go b/internal/user_account/invite/invite_test.go index 032c6e7..3018733 100644 --- a/internal/user_account/invite/invite_test.go +++ b/internal/user_account/invite/invite_test.go @@ -1,7 +1,6 @@ package invite import ( - "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "os" "strings" "testing" @@ -11,6 +10,7 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/notify" "geeks-accelerator/oss/saas-starter-kit/internal/platform/tests" + "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/user" "geeks-accelerator/oss/saas-starter-kit/internal/user_account" "github.com/dgrijalva/jwt-go" @@ -18,7 +18,10 @@ import ( "github.com/pkg/errors" ) -var test *tests.Test +var ( + test *tests.Test + repo *Repository +) // TestMain is the entry point for testing. func TestMain(m *testing.M) { @@ -28,6 +31,20 @@ func TestMain(m *testing.M) { func testMain(m *testing.M) int { test = tests.New() defer test.TearDown() + + userRepo := user.MockRepository(test.MasterDB) + userAccRepo := user_account.NewRepository(test.MasterDB) + accRepo := account.NewRepository(test.MasterDB) + + // Mock the methods needed to make an invite. + resetUrl := func(string) string { + return "" + } + notify := ¬ify.MockEmail{} + secretKey := "6368616e676520746869732070613434" + + repo = NewRepository(test.MasterDB, userRepo, userAccRepo, accRepo, resetUrl, notify, secretKey) + return m.Run() } @@ -42,7 +59,7 @@ func TestSendUserInvites(t *testing.T) { // Create a new user for testing. initPass := uuid.NewRandom().String() - u, err := user.Create(ctx, auth.Claims{}, test.MasterDB, user.UserCreateRequest{ + u, err := repo.User.Create(ctx, auth.Claims{}, user.UserCreateRequest{ FirstName: "Lee", LastName: "Brown", Email: uuid.NewRandom().String() + "@geeksinthewoods.com", @@ -54,7 +71,7 @@ func TestSendUserInvites(t *testing.T) { t.Fatalf("\t%s\tCreate user failed.", tests.Failed) } - a, err := account.Create(ctx, auth.Claims{}, test.MasterDB, account.AccountCreateRequest{ + a, err := repo.Account.Create(ctx, auth.Claims{}, account.AccountCreateRequest{ Name: uuid.NewRandom().String(), Address1: "101 E Main", City: "Valdez", @@ -68,7 +85,7 @@ func TestSendUserInvites(t *testing.T) { } uRoles := []user_account.UserAccountRole{user_account.UserAccountRole_Admin} - _, err = user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + _, err = repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{ UserID: u.ID, AccountID: a.ID, Roles: uRoles, @@ -91,21 +108,13 @@ func TestSendUserInvites(t *testing.T) { claims.Roles = append(claims.Roles, r.String()) } - // Mock the methods needed to make a password reset. - resetUrl := func(string) string { - return "" - } - notify := ¬ify.MockEmail{} - - secretKey := "6368616e676520746869732070617373" - // Ensure validation is working by trying ResetPassword with an empty request. { expectedErr := errors.New("Key: 'SendUserInvitesRequest.account_id' Error:Field validation for 'account_id' failed on the 'required' tag\n" + "Key: 'SendUserInvitesRequest.user_id' Error:Field validation for 'user_id' failed on the 'required' tag\n" + "Key: 'SendUserInvitesRequest.emails' Error:Field validation for 'emails' failed on the 'required' tag\n" + "Key: 'SendUserInvitesRequest.roles' Error:Field validation for 'roles' failed on the 'required' tag") - _, err = SendUserInvites(ctx, claims, test.MasterDB, resetUrl, notify, SendUserInvitesRequest{}, secretKey, now) + _, err = repo.SendUserInvites(ctx, claims, SendUserInvitesRequest{}, now) if err == nil { t.Logf("\t\tWant: %+v", expectedErr) t.Fatalf("\t%s\tInviteUsers failed.", tests.Failed) @@ -129,13 +138,13 @@ func TestSendUserInvites(t *testing.T) { } // Make the reset password request. - inviteHashes, err := SendUserInvites(ctx, claims, test.MasterDB, resetUrl, notify, SendUserInvitesRequest{ + inviteHashes, err := repo.SendUserInvites(ctx, claims, SendUserInvitesRequest{ UserID: u.ID, AccountID: a.ID, Emails: inviteEmails, Roles: []user_account.UserAccountRole{user_account.UserAccountRole_User}, TTL: ttl, - }, secretKey, now) + }, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tInviteUsers failed.", tests.Failed) @@ -154,7 +163,7 @@ func TestSendUserInvites(t *testing.T) { "Key: 'AcceptInviteUserRequest.last_name' Error:Field validation for 'last_name' failed on the 'required' tag\n" + "Key: 'AcceptInviteUserRequest.password' Error:Field validation for 'password' failed on the 'required' tag\n" + "Key: 'AcceptInviteUserRequest.password_confirm' Error:Field validation for 'password_confirm' failed on the 'required' tag") - _, err = AcceptInviteUser(ctx, test.MasterDB, AcceptInviteUserRequest{}, secretKey, now) + _, err = repo.AcceptInviteUser(ctx, AcceptInviteUserRequest{}, now) if err == nil { t.Logf("\t\tWant: %+v", expectedErr) t.Fatalf("\t%s\tResetConfirm failed.", tests.Failed) @@ -174,14 +183,14 @@ func TestSendUserInvites(t *testing.T) { // Ensure the TTL is enforced. { newPass := uuid.NewRandom().String() - _, err = AcceptInviteUser(ctx, test.MasterDB, AcceptInviteUserRequest{ + _, err = repo.AcceptInviteUser(ctx, AcceptInviteUserRequest{ InviteHash: inviteHashes[0], Email: inviteEmails[0], FirstName: "Foo", LastName: "Bar", Password: newPass, PasswordConfirm: newPass, - }, secretKey, now.UTC().Add(ttl*2)) + }, now.UTC().Add(ttl*2)) if errors.Cause(err) != ErrInviteExpired { t.Logf("\t\tGot : %+v", errors.Cause(err)) t.Logf("\t\tWant: %+v", ErrInviteExpired) @@ -194,14 +203,14 @@ func TestSendUserInvites(t *testing.T) { for idx, inviteHash := range inviteHashes { newPass := uuid.NewRandom().String() - hash, err := AcceptInviteUser(ctx, test.MasterDB, AcceptInviteUserRequest{ + hash, err := repo.AcceptInviteUser(ctx, AcceptInviteUserRequest{ InviteHash: inviteHash, Email: inviteEmails[idx], FirstName: "Foo", LastName: "Bar", Password: newPass, PasswordConfirm: newPass, - }, secretKey, now) + }, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tInviteAccept failed.", tests.Failed) @@ -227,14 +236,14 @@ func TestSendUserInvites(t *testing.T) { // Ensure the reset hash does not work after its used. { newPass := uuid.NewRandom().String() - _, err = AcceptInviteUser(ctx, test.MasterDB, AcceptInviteUserRequest{ + _, err = repo.AcceptInviteUser(ctx, AcceptInviteUserRequest{ InviteHash: inviteHashes[0], Email: inviteEmails[0], FirstName: "Foo", LastName: "Bar", Password: newPass, PasswordConfirm: newPass, - }, secretKey, now) + }, now) if errors.Cause(err) != ErrUserAccountActive { t.Logf("\t\tGot : %+v", errors.Cause(err)) t.Logf("\t\tWant: %+v", ErrUserAccountActive) diff --git a/internal/user_account/invite/models.go b/internal/user_account/invite/models.go index ca87007..5c7231c 100644 --- a/internal/user_account/invite/models.go +++ b/internal/user_account/invite/models.go @@ -6,12 +6,41 @@ import ( "strings" "time" + "geeks-accelerator/oss/saas-starter-kit/internal/account" + "geeks-accelerator/oss/saas-starter-kit/internal/platform/notify" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" + "geeks-accelerator/oss/saas-starter-kit/internal/user" "geeks-accelerator/oss/saas-starter-kit/internal/user_account" + "github.com/jmoiron/sqlx" "github.com/pkg/errors" "github.com/sudo-suhas/symcrypto" ) +// Repository defines the required dependencies for User Invite. +type Repository struct { + DbConn *sqlx.DB + User *user.Repository + UserAccount *user_account.Repository + Account *account.Repository + ResetUrl func(string) string + Notify notify.Email + secretKey string +} + +// NewRepository creates a new Repository that defines dependencies for User Invite. +func NewRepository(db *sqlx.DB, user *user.Repository, userAccount *user_account.Repository, account *account.Repository, + resetUrl func(string) string, notify notify.Email, secretKey string) *Repository { + return &Repository{ + DbConn: db, + User: user, + UserAccount: userAccount, + Account: account, + ResetUrl: resetUrl, + Notify: notify, + secretKey: secretKey, + } +} + // SendUserInvitesRequest defines the data needed to make an invite request. type SendUserInvitesRequest struct { AccountID string `json:"account_id" validate:"required,uuid" example:"c4653bf9-5978-48b7-89c5-95704aebb7e2"` diff --git a/internal/user_account/models.go b/internal/user_account/models.go index e2131a4..df16106 100644 --- a/internal/user_account/models.go +++ b/internal/user_account/models.go @@ -2,13 +2,13 @@ package user_account import ( "context" - "database/sql/driver" - "github.com/jmoiron/sqlx" "strings" "time" + "database/sql/driver" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" + "github.com/jmoiron/sqlx" "github.com/lib/pq" "github.com/pkg/errors" "gopkg.in/go-playground/validator.v9" diff --git a/internal/user_account/user_account.go b/internal/user_account/user_account.go index 4fde70f..e436f6e 100644 --- a/internal/user_account/user_account.go +++ b/internal/user_account/user_account.go @@ -3,12 +3,12 @@ package user_account import ( "context" "database/sql" - "geeks-accelerator/oss/saas-starter-kit/internal/user" "time" "geeks-accelerator/oss/saas-starter-kit/internal/account" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" + "geeks-accelerator/oss/saas-starter-kit/internal/user" "github.com/huandu/go-sqlbuilder" "github.com/jmoiron/sqlx" "github.com/pborman/uuid" @@ -50,13 +50,13 @@ func mapRowsToUserAccount(rows *sql.Rows) (*UserAccount, error) { // CanReadAccount determines if claims has the authority to access the specified user account by user ID. func (repo *Repository) CanReadAccount(ctx context.Context, claims auth.Claims, accountID string) error { - err := account.CanReadAccount(ctx, claims, accountID) + err := account.CanReadAccount(ctx, claims, repo.DbConn, accountID) return mapAccountError(err) } // CanModifyAccount determines if claims has the authority to modify the specified user ID. func (repo *Repository) CanModifyAccount(ctx context.Context, claims auth.Claims, accountID string) error { - err := account.CanModifyAccount(ctx, claims, accountID) + err := account.CanModifyAccount(ctx, claims, repo.DbConn, accountID) return mapAccountError(err) } @@ -131,9 +131,9 @@ func findRequestQuery(req UserAccountFindRequest) (*sqlbuilder.SelectBuilder, [] } // Find gets all the user accounts from the database based on the request params. -func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountFindRequest) (UserAccounts, error) { +func (repo *Repository) Find(ctx context.Context, claims auth.Claims, req UserAccountFindRequest) (UserAccounts, error) { query, args := findRequestQuery(req) - return find(ctx, claims, dbConn, query, args, req.IncludeArchived) + return find(ctx, claims, repo.DbConn, query, args, req.IncludeArchived) } // Find gets all the user accounts from the database based on the select query @@ -180,7 +180,7 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu } // Retrieve gets the specified user from the database. -func FindByUserID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string, includedArchived bool) (UserAccounts, error) { +func (repo *Repository) FindByUserID(ctx context.Context, claims auth.Claims, userID string, includedArchived bool) (UserAccounts, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.FindByUserID") defer span.Finish() @@ -190,7 +190,7 @@ func FindByUserID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, user query.OrderBy("created_at") // Execute the find accounts method. - res, err := find(ctx, claims, dbConn, query, []interface{}{}, includedArchived) + res, err := find(ctx, claims, repo.DbConn, query, []interface{}{}, includedArchived) if err != nil { return nil, err } else if res == nil || len(res) == 0 { @@ -202,7 +202,7 @@ func FindByUserID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, user } // Create a user account for a given user with specified roles. -func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountCreateRequest, now time.Time) (*UserAccount, error) { +func (repo *Repository) Create(ctx context.Context, claims auth.Claims, req UserAccountCreateRequest, now time.Time) (*UserAccount, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Create") defer span.Finish() @@ -214,7 +214,7 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc } // Ensure the claims can modify the account specified in the request. - err = CanModifyAccount(ctx, claims, dbConn, req.AccountID) + err = repo.CanModifyAccount(ctx, claims, req.AccountID) if err != nil { return nil, err } @@ -237,7 +237,7 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc existQuery.Equal("account_id", req.AccountID), existQuery.Equal("user_id", req.UserID), )) - existing, err := find(ctx, claims, dbConn, existQuery, []interface{}{}, true) + existing, err := find(ctx, claims, repo.DbConn, existQuery, []interface{}{}, true) if err != nil { return nil, err } @@ -251,7 +251,7 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc Roles: &req.Roles, unArchive: true, } - err = Update(ctx, claims, dbConn, upReq, now) + err = repo.Update(ctx, claims, upReq, now) if err != nil { return nil, err } @@ -285,8 +285,8 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "add account %s to user %s failed", req.AccountID, req.UserID) @@ -298,7 +298,7 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc } // Read gets the specified user account from the database. -func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountReadRequest) (*UserAccount, error) { +func (repo *Repository) Read(ctx context.Context, claims auth.Claims, req UserAccountReadRequest) (*UserAccount, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Read") defer span.Finish() @@ -315,7 +315,7 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAcco query.Equal("user_id", req.UserID), query.Equal("account_id", req.AccountID))) - res, err := find(ctx, claims, dbConn, query, []interface{}{}, req.IncludeArchived) + res, err := find(ctx, claims, repo.DbConn, query, []interface{}{}, req.IncludeArchived) if err != nil { return nil, err } else if res == nil || len(res) == 0 { @@ -328,7 +328,7 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAcco } // Update replaces a user account in the database. -func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountUpdateRequest, now time.Time) error { +func (repo *Repository) Update(ctx context.Context, claims auth.Claims, req UserAccountUpdateRequest, now time.Time) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Update") defer span.Finish() @@ -340,7 +340,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc } // Ensure the claims can modify the user specified in the request. - err = CanModifyAccount(ctx, claims, dbConn, req.AccountID) + err = repo.CanModifyAccount(ctx, claims, req.AccountID) if err != nil { return err } @@ -389,8 +389,8 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "update account %s for user %s failed", req.AccountID, req.UserID) @@ -401,7 +401,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc } // Archive soft deleted the user account from the database. -func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountArchiveRequest, now time.Time) error { +func (repo *Repository) Archive(ctx context.Context, claims auth.Claims, req UserAccountArchiveRequest, now time.Time) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Archive") defer span.Finish() @@ -413,7 +413,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserA } // Ensure the claims can modify the user specified in the request. - err = CanModifyAccount(ctx, claims, dbConn, req.AccountID) + err = repo.CanModifyAccount(ctx, claims, req.AccountID) if err != nil { return err } @@ -441,8 +441,8 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserA // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "archive account %s from user %s failed", req.AccountID, req.UserID) @@ -453,7 +453,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserA } // Delete removes a user account from the database. -func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountDeleteRequest) error { +func (repo *Repository) Delete(ctx context.Context, claims auth.Claims, req UserAccountDeleteRequest) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Delete") defer span.Finish() @@ -465,7 +465,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc } // Ensure the claims can modify the user specified in the request. - err = CanModifyAccount(ctx, claims, dbConn, req.AccountID) + err = repo.CanModifyAccount(ctx, claims, req.AccountID) if err != nil { return err } @@ -480,8 +480,8 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAc // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "delete account %s for user %s failed", req.AccountID, req.UserID) @@ -509,6 +509,10 @@ func MockUserAccount(ctx context.Context, dbConn *sqlx.DB, now time.Time, roles return nil, err } + repo := &Repository{ + DbConn: dbConn, + } + status := UserAccountStatus_Active req := UserAccountCreateRequest{ @@ -517,7 +521,7 @@ func MockUserAccount(ctx context.Context, dbConn *sqlx.DB, now time.Time, roles Status: &status, Roles: roles, } - ua, err := Create(ctx, auth.Claims{}, dbConn, req, now) + ua, err := repo.Create(ctx, auth.Claims{}, req, now) if err != nil { return nil, err } diff --git a/internal/user_account/user_account_test.go b/internal/user_account/user_account_test.go index 2728273..eb88466 100644 --- a/internal/user_account/user_account_test.go +++ b/internal/user_account/user_account_test.go @@ -1,7 +1,6 @@ package user_account import ( - "github.com/lib/pq" "math/rand" "os" "strings" @@ -13,6 +12,7 @@ import ( "github.com/dgrijalva/jwt-go" "github.com/google/go-cmp/cmp" "github.com/huandu/go-sqlbuilder" + "github.com/lib/pq" "github.com/pborman/uuid" "github.com/pkg/errors" ) @@ -232,7 +232,7 @@ func TestCreateValidation(t *testing.T) { t.Fatalf("\t%s\tMock account failed.", tests.Failed) } - res, err := Create(ctx, auth.Claims{}, test.MasterDB, tt.req, now) + res, err := repo.Create(ctx, auth.Claims{}, tt.req, now) if err != tt.error { // TODO: need a better way to handle validation errors as they are // of type interface validator.ValidationErrorsTranslations @@ -300,7 +300,7 @@ func TestCreateExistingEntry(t *testing.T) { AccountID: accountID, Roles: []UserAccountRole{UserAccountRole_User}, } - ua1, err := Create(ctx, auth.Claims{}, test.MasterDB, req1, now) + ua1, err := repo.Create(ctx, auth.Claims{}, req1, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tCreate user account failed.", tests.Failed) @@ -313,7 +313,7 @@ func TestCreateExistingEntry(t *testing.T) { AccountID: req1.AccountID, Roles: []UserAccountRole{UserAccountRole_Admin}, } - ua2, err := Create(ctx, auth.Claims{}, test.MasterDB, req2, now) + ua2, err := repo.Create(ctx, auth.Claims{}, req2, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tCreate user account failed.", tests.Failed) @@ -322,7 +322,7 @@ func TestCreateExistingEntry(t *testing.T) { } // Now archive the user account to test trying to create a new entry for an archived entry - err = Archive(tests.Context(), auth.Claims{}, test.MasterDB, UserAccountArchiveRequest{ + err = repo.Archive(tests.Context(), auth.Claims{}, UserAccountArchiveRequest{ UserID: req1.UserID, AccountID: req1.AccountID, }, now) @@ -332,7 +332,7 @@ func TestCreateExistingEntry(t *testing.T) { } // Find the archived user account - arcRes, err := Read(tests.Context(), auth.Claims{}, test.MasterDB, + arcRes, err := repo.Read(tests.Context(), auth.Claims{}, UserAccountReadRequest{UserID: req1.UserID, AccountID: req1.AccountID, IncludeArchived: true}) if err != nil || arcRes == nil { t.Log("\t\tGot :", err) @@ -347,7 +347,7 @@ func TestCreateExistingEntry(t *testing.T) { AccountID: req1.AccountID, Roles: []UserAccountRole{UserAccountRole_User}, } - ua3, err := Create(ctx, auth.Claims{}, test.MasterDB, req3, now) + ua3, err := repo.Create(ctx, auth.Claims{}, req3, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tCreate user account failed.", tests.Failed) @@ -356,7 +356,7 @@ func TestCreateExistingEntry(t *testing.T) { } // Ensure the user account has archived_at empty - findRes, err := Read(tests.Context(), auth.Claims{}, test.MasterDB, + findRes, err := repo.Read(tests.Context(), auth.Claims{}, UserAccountReadRequest{UserID: req1.UserID, AccountID: req1.AccountID}) if err != nil || arcRes == nil { t.Log("\t\tGot :", err) @@ -414,7 +414,7 @@ func TestUpdateValidation(t *testing.T) { { ctx := tests.Context() - err := Update(ctx, auth.Claims{}, test.MasterDB, tt.req, now) + err := repo.Update(ctx, auth.Claims{}, tt.req, now) if err != tt.error { // TODO: need a better way to handle validation errors as they are // of type interface validator.ValidationErrorsTranslations @@ -564,7 +564,7 @@ func TestCrud(t *testing.T) { AccountID: accountID, Roles: []UserAccountRole{UserAccountRole_User}, } - ua, err := Create(tests.Context(), tt.claims(userID, accountID), test.MasterDB, createReq, now) + ua, err := repo.Create(tests.Context(), tt.claims(userID, accountID), createReq, now) if err != nil && errors.Cause(err) != tt.createErr { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.createErr) @@ -577,7 +577,7 @@ func TestCrud(t *testing.T) { } if tt.createErr == ErrForbidden { - ua, err = Create(tests.Context(), auth.Claims{}, test.MasterDB, createReq, now) + ua, err = repo.Create(tests.Context(), auth.Claims{}, createReq, now) if err != nil && errors.Cause(err) != tt.createErr { t.Logf("\t\tGot : %+v", err) t.Fatalf("\t%s\tCreate user account failed.", tests.Failed) @@ -590,7 +590,7 @@ func TestCrud(t *testing.T) { AccountID: accountID, Roles: &UserAccountRoles{UserAccountRole_Admin}, } - err = Update(tests.Context(), tt.claims(userID, accountID), test.MasterDB, updateReq, now) + err = repo.Update(tests.Context(), tt.claims(userID, accountID), updateReq, now) if err != nil { if errors.Cause(err) != tt.updateErr { t.Logf("\t\tGot : %+v", err) @@ -604,7 +604,7 @@ func TestCrud(t *testing.T) { // Find the account for the user to verify the updates where made. There should only // be one account associated with the user for this test. - findRes, err := Find(tests.Context(), tt.claims(userID, accountID), test.MasterDB, UserAccountFindRequest{ + findRes, err := repo.Find(tests.Context(), tt.claims(userID, accountID), UserAccountFindRequest{ Where: "user_id = ? or account_id = ?", Args: []interface{}{userID, accountID}, Order: []string{"created_at"}, @@ -632,7 +632,7 @@ func TestCrud(t *testing.T) { } // Archive (soft-delete) the user account. - err = Archive(tests.Context(), tt.claims(userID, accountID), test.MasterDB, UserAccountArchiveRequest{ + err = repo.Archive(tests.Context(), tt.claims(userID, accountID), UserAccountArchiveRequest{ UserID: userID, AccountID: accountID, }, now) @@ -642,7 +642,7 @@ func TestCrud(t *testing.T) { t.Fatalf("\t%s\tArchive user account failed.", tests.Failed) } else if tt.updateErr == nil { // Trying to find the archived user with the includeArchived false should result in not found. - _, err = FindByUserID(tests.Context(), tt.claims(userID, accountID), test.MasterDB, userID, false) + _, err = repo.FindByUserID(tests.Context(), tt.claims(userID, accountID), userID, false) if errors.Cause(err) != ErrNotFound { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", ErrNotFound) @@ -650,7 +650,7 @@ func TestCrud(t *testing.T) { } // Trying to find the archived user with the includeArchived true should result no error. - findRes, err = FindByUserID(tests.Context(), tt.claims(userID, accountID), test.MasterDB, userID, true) + findRes, err = repo.FindByUserID(tests.Context(), tt.claims(userID, accountID), userID, true) if err != nil { t.Logf("\t\tGot : %+v", err) t.Fatalf("\t%s\tVerify archive user account failed when including archived.", tests.Failed) @@ -675,7 +675,7 @@ func TestCrud(t *testing.T) { t.Logf("\t%s\tArchive user account ok.", tests.Success) // Delete (hard-delete) the user account. - err = Delete(tests.Context(), tt.claims(userID, accountID), test.MasterDB, UserAccountDeleteRequest{ + err = repo.Delete(tests.Context(), tt.claims(userID, accountID), UserAccountDeleteRequest{ UserID: userID, AccountID: accountID, }) @@ -685,7 +685,7 @@ func TestCrud(t *testing.T) { t.Fatalf("\t%s\tDelete user account failed.", tests.Failed) } else if tt.updateErr == nil { // Trying to find the deleted user with the includeArchived true should result in not found. - _, err = FindByUserID(tests.Context(), tt.claims(userID, accountID), test.MasterDB, userID, true) + _, err = repo.FindByUserID(tests.Context(), tt.claims(userID, accountID), userID, true) if errors.Cause(err) != ErrNotFound { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", ErrNotFound) @@ -725,7 +725,7 @@ func TestFind(t *testing.T) { } // Execute Create that will associate the user with the account. - ua, err := Create(tests.Context(), auth.Claims{}, test.MasterDB, UserAccountCreateRequest{ + ua, err := repo.Create(tests.Context(), auth.Claims{}, UserAccountCreateRequest{ UserID: userID, AccountID: accountID, Roles: []UserAccountRole{UserAccountRole_User}, @@ -836,7 +836,7 @@ func TestFind(t *testing.T) { { ctx := tests.Context() - res, err := Find(ctx, auth.Claims{}, test.MasterDB, tt.req) + res, err := repo.Find(ctx, auth.Claims{}, tt.req) if errors.Cause(err) != tt.error { t.Logf("\t\tGot : %+v", err) t.Logf("\t\tWant: %+v", tt.error) diff --git a/internal/user_auth/auth.go b/internal/user_auth/auth.go index 34bf1ca..4e90b10 100644 --- a/internal/user_auth/auth.go +++ b/internal/user_auth/auth.go @@ -3,7 +3,6 @@ package user_auth import ( "context" "database/sql" - "geeks-accelerator/oss/saas-starter-kit/internal/user_account" "strings" "time" @@ -11,8 +10,8 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/user" + "geeks-accelerator/oss/saas-starter-kit/internal/user_account" "github.com/huandu/go-sqlbuilder" - "github.com/jmoiron/sqlx" "github.com/lib/pq" "github.com/pkg/errors" "golang.org/x/crypto/bcrypt" @@ -40,7 +39,7 @@ const ( // Authenticate finds a user by their email and verifies their password. On success // it returns a Token that can be used to authenticate access to the application in // the future. -func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, req AuthenticateRequest, expires time.Duration, now time.Time, scopes ...string) (Token, error) { +func (repo *Repository) Authenticate(ctx context.Context, req AuthenticateRequest, expires time.Duration, now time.Time, scopes ...string) (Token, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.Authenticate") defer span.Finish() @@ -51,7 +50,7 @@ func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, r return Token{}, err } - u, err := user.ReadByEmail(ctx, auth.Claims{}, dbConn, req.Email, false) + u, err := repo.User.ReadByEmail(ctx, auth.Claims{}, req.Email, false) if err != nil { if errors.Cause(err) == user.ErrNotFound { err = errors.WithStack(ErrAuthenticationFailure) @@ -73,11 +72,11 @@ func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, r } // The user is successfully authenticated with the supplied email and password. - return generateToken(ctx, dbConn, tknGen, auth.Claims{}, u.ID, req.AccountID, expires, now, scopes...) + return repo.generateToken(ctx, auth.Claims{}, u.ID, req.AccountID, expires, now, scopes...) } // SwitchAccount allows users to switch between multiple accounts, this changes the claim audience. -func SwitchAccount(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, req SwitchAccountRequest, expires time.Duration, now time.Time, scopes ...string) (Token, error) { +func (repo *Repository) SwitchAccount(ctx context.Context, claims auth.Claims, req SwitchAccountRequest, expires time.Duration, now time.Time, scopes ...string) (Token, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.SwitchAccount") defer span.Finish() @@ -97,11 +96,11 @@ func SwitchAccount(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, // Generate a token for the user ID in supplied in claims as the Subject. Pass // in the supplied claims as well to enforce ACLs when finding the current // list of accounts for the user. - return generateToken(ctx, dbConn, tknGen, claims, claims.Subject, req.AccountID, expires, now, scopes...) + return repo.generateToken(ctx, claims, claims.Subject, req.AccountID, expires, now, scopes...) } // VirtualLogin allows users to mock being logged in as other users. -func VirtualLogin(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, req VirtualLoginRequest, expires time.Duration, now time.Time, scopes ...string) (Token, error) { +func (repo *Repository) VirtualLogin(ctx context.Context, claims auth.Claims, req VirtualLoginRequest, expires time.Duration, now time.Time, scopes ...string) (Token, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.VirtualLogin") defer span.Finish() @@ -113,7 +112,7 @@ func VirtualLogin(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, c } // Find all the accounts that the current user has access to. - usrAccs, err := user_account.FindByUserID(ctx, claims, dbConn, claims.Subject, false) + usrAccs, err := repo.UserAccount.FindByUserID(ctx, claims, claims.Subject, false) if err != nil { return Token{}, err } @@ -142,23 +141,23 @@ func VirtualLogin(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, c // Generate a token for the user ID in supplied in claims as the Subject. Pass // in the supplied claims as well to enforce ACLs when finding the current // list of accounts for the user. - return generateToken(ctx, dbConn, tknGen, claims, req.UserID, req.AccountID, expires, now, scopes...) + return repo.generateToken(ctx, claims, req.UserID, req.AccountID, expires, now, scopes...) } // VirtualLogout allows switch back to their root user/account. -func VirtualLogout(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, expires time.Duration, now time.Time, scopes ...string) (Token, error) { +func (repo *Repository) VirtualLogout(ctx context.Context, claims auth.Claims, expires time.Duration, now time.Time, scopes ...string) (Token, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.VirtualLogout") defer span.Finish() // Generate a token for the user ID in supplied in claims as the Subject. Pass // in the supplied claims as well to enforce ACLs when finding the current // list of accounts for the user. - return generateToken(ctx, dbConn, tknGen, claims, claims.RootUserID, claims.RootAccountID, expires, now, scopes...) + return repo.generateToken(ctx, claims, claims.RootUserID, claims.RootAccountID, expires, now, scopes...) } // generateToken generates claims for the supplied user ID and account ID and then // returns the token for the generated claims used for authentication. -func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, userID, accountID string, expires time.Duration, now time.Time, scopes ...string) (Token, error) { +func (repo *Repository) generateToken(ctx context.Context, claims auth.Claims, userID, accountID string, expires time.Duration, now time.Time, scopes ...string) (Token, error) { type userAccount struct { AccountID string @@ -184,8 +183,8 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, // fetch all places from the db queryStr, queryArgs := query.Build() - queryStr = dbConn.Rebind(queryStr) - rows, err := dbConn.QueryContext(ctx, queryStr, queryArgs...) + queryStr = repo.DbConn.Rebind(queryStr) + rows, err := repo.DbConn.QueryContext(ctx, queryStr, queryArgs...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) return nil, err @@ -339,7 +338,7 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, tz, _ = time.LoadLocation(account.AccountTimezone.String) } - prefs, err := account_preference.FindByAccountID(ctx, auth.Claims{}, dbConn, account_preference.AccountPreferenceFindByAccountIDRequest{ + prefs, err := repo.AccountPreference.FindByAccountID(ctx, auth.Claims{}, account_preference.AccountPreferenceFindByAccountIDRequest{ AccountID: accountID, }) if err != nil { @@ -393,7 +392,7 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, newClaims.RootUserID = claims.RootUserID // Generate a token for the user with the defined claims. - tknStr, err := tknGen.GenerateToken(newClaims) + tknStr, err := repo.TknGen.GenerateToken(newClaims) if err != nil { return Token{}, errors.Wrap(err, "generating token") } diff --git a/internal/user_auth/auth_test.go b/internal/user_auth/auth_test.go index db837b6..e7fb61a 100644 --- a/internal/user_auth/auth_test.go +++ b/internal/user_auth/auth_test.go @@ -8,8 +8,8 @@ import ( "time" "geeks-accelerator/oss/saas-starter-kit/internal/account" + "geeks-accelerator/oss/saas-starter-kit/internal/account/account_preference" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" - "geeks-accelerator/oss/saas-starter-kit/internal/platform/notify" "geeks-accelerator/oss/saas-starter-kit/internal/platform/tests" "geeks-accelerator/oss/saas-starter-kit/internal/user" "geeks-accelerator/oss/saas-starter-kit/internal/user_account" @@ -18,7 +18,10 @@ import ( "github.com/pkg/errors" ) -var test *tests.Test +var ( + test *tests.Test + repo *Repository +) // TestMain is the entry point for testing. func TestMain(m *testing.M) { @@ -28,6 +31,15 @@ func TestMain(m *testing.M) { func testMain(m *testing.M) int { test = tests.New() defer test.TearDown() + + tknGen := &auth.MockTokenGenerator{} + + userRepo := user.MockRepository(test.MasterDB) + userAccRepo := user_account.NewRepository(test.MasterDB) + accPrefRepo := account_preference.NewRepository(test.MasterDB) + + repo = NewRepository(test.MasterDB, tknGen, userRepo, userAccRepo, accPrefRepo) + return m.Run() } @@ -41,14 +53,12 @@ func TestAuthenticate(t *testing.T) { { ctx := tests.Context() - tknGen := &auth.MockTokenGenerator{} - // Auth tokens are valid for an our and is verified against current time. // Issue the token one hour ago. now := time.Now().Add(time.Hour * -1) // Try to authenticate an invalid user. - _, err := Authenticate(ctx, test.MasterDB, tknGen, + _, err := repo.Authenticate(ctx, AuthenticateRequest{ Email: "doesnotexist@gmail.com", Password: "xy7", @@ -82,7 +92,7 @@ func TestAuthenticate(t *testing.T) { // is always greater than the first user_account entry created so it will // be returned consistently back in the same order, last. account2Role := auth.RoleUser - _, err = user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + _, err = repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{ UserID: usrAcc.UserID, AccountID: acc2.ID, Roles: []user_account.UserAccountRole{user_account.UserAccountRole(account2Role)}, @@ -92,7 +102,7 @@ func TestAuthenticate(t *testing.T) { now = now.Add(time.Minute * 5) // Try to authenticate valid user with invalid password. - _, err = Authenticate(ctx, test.MasterDB, tknGen, + _, err = repo.Authenticate(ctx, AuthenticateRequest{ Email: usrAcc.User.Email, Password: "xy7", @@ -106,7 +116,7 @@ func TestAuthenticate(t *testing.T) { t.Logf("\t%s\tAuthenticate user w/invalid password ok.", tests.Success) // Verify that the user can be authenticated with the created user. - tkn1, err := Authenticate(ctx, test.MasterDB, tknGen, + tkn1, err := repo.Authenticate(ctx, AuthenticateRequest{ Email: usrAcc.User.Email, Password: usrAcc.User.Password, @@ -118,7 +128,7 @@ func TestAuthenticate(t *testing.T) { t.Logf("\t%s\tAuthenticate user ok.", tests.Success) // Ensure the token string was correctly generated. - claims1, err := tknGen.ParseClaims(tkn1.AccessToken) + claims1, err := repo.TknGen.ParseClaims(tkn1.AccessToken) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) @@ -135,7 +145,7 @@ func TestAuthenticate(t *testing.T) { t.Logf("\t%s\tAuthenticate parse claims from token ok.", tests.Success) // Try switching to a second account using the first set of claims. - tkn2, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims1, + tkn2, err := repo.SwitchAccount(ctx, claims1, SwitchAccountRequest{AccountID: acc2.ID}, time.Hour, now) if err != nil { t.Log("\t\tGot :", err) @@ -144,7 +154,7 @@ func TestAuthenticate(t *testing.T) { t.Logf("\t%s\tSwitchAccount user ok.", tests.Success) // Ensure the token string was correctly generated. - claims2, err := tknGen.ParseClaims(tkn2.AccessToken) + claims2, err := repo.TknGen.ParseClaims(tkn2.AccessToken) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) @@ -172,8 +182,6 @@ func TestUserUpdatePassword(t *testing.T) { now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC) - tknGen := &auth.MockTokenGenerator{} - // Create a new user for testing. usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_User) if err != nil { @@ -183,7 +191,7 @@ func TestUserUpdatePassword(t *testing.T) { t.Logf("\t%s\tCreate user account ok.", tests.Success) // Verify that the user can be authenticated with the created user. - _, err = Authenticate(ctx, test.MasterDB, tknGen, + _, err = repo.Authenticate(ctx, AuthenticateRequest{ Email: usrAcc.User.Email, Password: usrAcc.User.Password, @@ -195,7 +203,7 @@ func TestUserUpdatePassword(t *testing.T) { // Update the users password. newPass := uuid.NewRandom().String() - err = user.UpdatePassword(ctx, auth.Claims{}, test.MasterDB, user.UserUpdatePasswordRequest{ + err = repo.User.UpdatePassword(ctx, auth.Claims{}, user.UserUpdatePasswordRequest{ ID: usrAcc.UserID, Password: newPass, PasswordConfirm: newPass, @@ -207,7 +215,7 @@ func TestUserUpdatePassword(t *testing.T) { t.Logf("\t%s\tUpdatePassword ok.", tests.Success) // Verify that the user can be authenticated with the updated password. - _, err = Authenticate(ctx, test.MasterDB, tknGen, + _, err = repo.Authenticate(ctx, AuthenticateRequest{ Email: usrAcc.User.Email, Password: newPass, @@ -229,8 +237,6 @@ func TestUserResetPassword(t *testing.T) { now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC) - tknGen := &auth.MockTokenGenerator{} - // Create a new user for testing. usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_User) if err != nil { @@ -239,21 +245,13 @@ func TestUserResetPassword(t *testing.T) { } t.Logf("\t%s\tCreate user account ok.", tests.Success) - // Mock the methods needed to make a password reset. - resetUrl := func(string) string { - return "" - } - notify := ¬ify.MockEmail{} - - secretKey := "6368616e676520746869732070617373" - ttl := time.Hour // Make the reset password request. - resetHash, err := user.ResetPassword(ctx, test.MasterDB, resetUrl, notify, user.UserResetPasswordRequest{ + resetHash, err := repo.User.ResetPassword(ctx, user.UserResetPasswordRequest{ Email: usrAcc.User.Email, TTL: ttl, - }, secretKey, now) + }, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tResetPassword failed.", tests.Failed) @@ -262,11 +260,11 @@ func TestUserResetPassword(t *testing.T) { // Assuming we have received the email and clicked the link, we now can ensure confirm works. newPass := uuid.NewRandom().String() - reset, err := user.ResetConfirm(ctx, test.MasterDB, user.UserResetConfirmRequest{ + reset, err := repo.User.ResetConfirm(ctx, user.UserResetConfirmRequest{ ResetHash: resetHash, Password: newPass, PasswordConfirm: newPass, - }, secretKey, now) + }, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tResetConfirm failed.", tests.Failed) @@ -278,7 +276,7 @@ func TestUserResetPassword(t *testing.T) { t.Logf("\t%s\tResetConfirm ok.", tests.Success) // Verify that the user can be authenticated with the updated password. - _, err = Authenticate(ctx, test.MasterDB, tknGen, + _, err = repo.Authenticate(ctx, AuthenticateRequest{ Email: usrAcc.User.Email, Password: newPass, @@ -340,7 +338,7 @@ func TestSwitchAccount(t *testing.T) { } // Associate the second account with root user. - usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + usrAcc2, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{ UserID: usrAcc.UserID, AccountID: acc2.ID, Roles: []user_account.UserAccountRole{user_account.UserAccountRole(roles[1])}, @@ -359,7 +357,7 @@ func TestSwitchAccount(t *testing.T) { } // Associate the third account with root user. - usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + usrAcc3, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{ UserID: usrAcc.UserID, AccountID: acc3.ID, Roles: []user_account.UserAccountRole{user_account.UserAccountRole(roles[2])}, @@ -426,7 +424,7 @@ func TestSwitchAccount(t *testing.T) { } // Associate the second account with root user. - usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + usrAcc2, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{ UserID: usrAcc.UserID, AccountID: acc2.ID, Roles: []user_account.UserAccountRole{user_account.UserAccountRole_Admin}, @@ -445,7 +443,7 @@ func TestSwitchAccount(t *testing.T) { } // Associate the third account with root user. - usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + usrAcc3, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{ UserID: usrAcc.UserID, AccountID: acc3.ID, Roles: []user_account.UserAccountRole{user_account.UserAccountRole_User}, @@ -472,8 +470,6 @@ func TestSwitchAccount(t *testing.T) { // Add 30 minutes to now to simulate time passing. now = now.Add(time.Minute * 5) - tknGen := &auth.MockTokenGenerator{} - t.Log("Given the need to switch accounts.") { for i, authTest := range authTests { @@ -481,7 +477,7 @@ func TestSwitchAccount(t *testing.T) { { // Verify that the user can be authenticated with the created user. var claims1 auth.Claims - tkn1, err := Authenticate(ctx, test.MasterDB, tknGen, + tkn1, err := repo.Authenticate(ctx, AuthenticateRequest{ Email: authTest.root.User.Email, Password: authTest.root.User.Password, @@ -491,7 +487,7 @@ func TestSwitchAccount(t *testing.T) { t.Fatalf("\t%s\tAuthenticate user failed.", tests.Failed) } else { // Ensure the token string was correctly generated. - claims1, err = tknGen.ParseClaims(tkn1.AccessToken) + claims1, err = repo.TknGen.ParseClaims(tkn1.AccessToken) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) @@ -511,7 +507,7 @@ func TestSwitchAccount(t *testing.T) { // Try to switch to account 2. var claims2 auth.Claims - tkn2, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims1, authTest.switch1Req, time.Hour, now, authTest.switch1Scopes...) + tkn2, err := repo.SwitchAccount(ctx, claims1, authTest.switch1Req, time.Hour, now, authTest.switch1Scopes...) if err != authTest.switch1Err { if errors.Cause(err) != authTest.switch1Err { t.Log("\t\tExpected :", authTest.switch1Err) @@ -520,7 +516,7 @@ func TestSwitchAccount(t *testing.T) { } } else { // Ensure the token string was correctly generated. - claims2, err = tknGen.ParseClaims(tkn2.AccessToken) + claims2, err = repo.TknGen.ParseClaims(tkn2.AccessToken) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) @@ -549,7 +545,7 @@ func TestSwitchAccount(t *testing.T) { } // Try to switch to account 3. - tkn3, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims2, authTest.switch2Req, time.Hour, now, authTest.switch2Scopes...) + tkn3, err := repo.SwitchAccount(ctx, claims2, authTest.switch2Req, time.Hour, now, authTest.switch2Scopes...) if err != authTest.switch2Err { if errors.Cause(err) != authTest.switch2Err { t.Log("\t\tExpected :", authTest.switch2Err) @@ -558,7 +554,7 @@ func TestSwitchAccount(t *testing.T) { } } else { // Ensure the token string was correctly generated. - claims3, err := tknGen.ParseClaims(tkn3.AccessToken) + claims3, err := repo.TknGen.ParseClaims(tkn3.AccessToken) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) @@ -610,7 +606,7 @@ func TestVirtualLogin(t *testing.T) { var authTests []authTest // Root admin -> role admin -> role admin - if true { + { // Create a new user for testing. usrAcc, err := user_account.MockUserAccount(ctx, test.MasterDB, now, user_account.UserAccountRole_Admin) if err != nil { @@ -625,7 +621,7 @@ func TestVirtualLogin(t *testing.T) { } // Associate second user with basic role associated with the same account. - usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + usrAcc2, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{ UserID: usr2.ID, AccountID: usrAcc.AccountID, Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)}, @@ -642,7 +638,7 @@ func TestVirtualLogin(t *testing.T) { } // Associate second user with basic role associated with the same account. - usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + usrAcc3, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{ UserID: usr3.ID, AccountID: usrAcc.AccountID, Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)}, @@ -687,7 +683,7 @@ func TestVirtualLogin(t *testing.T) { } // Associate second user with basic role associated with the same account. - usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + usrAcc2, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{ UserID: usr2.ID, AccountID: usrAcc.AccountID, Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)}, @@ -704,7 +700,7 @@ func TestVirtualLogin(t *testing.T) { } // Associate second user with basic role associated with the same account. - usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + usrAcc3, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{ UserID: usr3.ID, AccountID: usrAcc.AccountID, Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_User)}, @@ -749,7 +745,7 @@ func TestVirtualLogin(t *testing.T) { } // Associate second user with basic role associated with the same account. - usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + usrAcc2, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{ UserID: usr2.ID, AccountID: usrAcc.AccountID, Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_User)}, @@ -766,7 +762,7 @@ func TestVirtualLogin(t *testing.T) { } // Associate second user with basic role associated with the same account. - usrAcc3, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + usrAcc3, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{ UserID: usr3.ID, AccountID: usrAcc.AccountID, Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)}, @@ -811,7 +807,7 @@ func TestVirtualLogin(t *testing.T) { } // Associate second user with basic role associated with the same account. - usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + usrAcc2, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{ UserID: usr2.ID, AccountID: usrAcc.AccountID, Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_Admin)}, @@ -850,7 +846,7 @@ func TestVirtualLogin(t *testing.T) { } // Associate second user with basic role associated with the same account. - usrAcc2, err := user_account.Create(ctx, auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + usrAcc2, err := repo.UserAccount.Create(ctx, auth.Claims{}, user_account.UserAccountCreateRequest{ UserID: usr2.ID, AccountID: usrAcc.AccountID, Roles: []user_account.UserAccountRole{user_account.UserAccountRole(user_account.UserAccountRole_User)}, @@ -876,8 +872,6 @@ func TestVirtualLogin(t *testing.T) { // Add 30 minutes to now to simulate time passing. now = now.Add(time.Minute * 5) - tknGen := &auth.MockTokenGenerator{} - t.Log("Given the need to virtual login.") { for i, authTest := range authTests { @@ -885,7 +879,7 @@ func TestVirtualLogin(t *testing.T) { { // Verify that the user can be authenticated with the created user. var claims1 auth.Claims - tkn1, err := Authenticate(ctx, test.MasterDB, tknGen, + tkn1, err := repo.Authenticate(ctx, AuthenticateRequest{ Email: authTest.root.User.Email, Password: authTest.root.User.Password, @@ -895,7 +889,7 @@ func TestVirtualLogin(t *testing.T) { t.Fatalf("\t%s\tAuthenticate user failed.", tests.Failed) } else { // Ensure the token string was correctly generated. - claims1, err = tknGen.ParseClaims(tkn1.AccessToken) + claims1, err = repo.TknGen.ParseClaims(tkn1.AccessToken) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) @@ -915,7 +909,7 @@ func TestVirtualLogin(t *testing.T) { // Try virtual login to user 2. var claims2 auth.Claims - tkn2, err := VirtualLogin(ctx, test.MasterDB, tknGen, claims1, authTest.login1Req, time.Hour, now) + tkn2, err := repo.VirtualLogin(ctx, claims1, authTest.login1Req, time.Hour, now) if err != authTest.login1Err { if errors.Cause(err) != authTest.login1Err { t.Log("\t\tExpected :", authTest.login1Err) @@ -924,7 +918,7 @@ func TestVirtualLogin(t *testing.T) { } } else { // Ensure the token string was correctly generated. - claims2, err = tknGen.ParseClaims(tkn2.AccessToken) + claims2, err = repo.TknGen.ParseClaims(tkn2.AccessToken) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) @@ -948,7 +942,7 @@ func TestVirtualLogin(t *testing.T) { } // Try virtual login to user 3. - tkn3, err := VirtualLogin(ctx, test.MasterDB, tknGen, claims2, authTest.login2Req, time.Hour, now) + tkn3, err := repo.VirtualLogin(ctx, claims2, authTest.login2Req, time.Hour, now) if err != authTest.login2Err { if errors.Cause(err) != authTest.login2Err { t.Log("\t\tExpected :", authTest.login2Err) @@ -957,7 +951,7 @@ func TestVirtualLogin(t *testing.T) { } } else { // Ensure the token string was correctly generated. - claims3, err := tknGen.ParseClaims(tkn3.AccessToken) + claims3, err := repo.TknGen.ParseClaims(tkn3.AccessToken) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) @@ -976,14 +970,14 @@ func TestVirtualLogin(t *testing.T) { t.Logf("\t%s\tVirtualLogin user 2 with role %s ok.", tests.Success, authTest.login2Role) if authTest.login2Logout { - tknOut, err := VirtualLogout(ctx, test.MasterDB, tknGen, claims2, time.Hour, now) + tknOut, err := repo.VirtualLogout(ctx, claims2, time.Hour, now) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tVirtualLogout user 2 failed.", tests.Failed) } // Ensure the token string was correctly generated. - claimsOut, err := tknGen.ParseClaims(tknOut.AccessToken) + claimsOut, err := repo.TknGen.ParseClaims(tknOut.AccessToken) if err != nil { t.Log("\t\tGot :", err) t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) diff --git a/internal/user_auth/models.go b/internal/user_auth/models.go index 5990253..12e41d2 100644 --- a/internal/user_auth/models.go +++ b/internal/user_auth/models.go @@ -3,9 +3,33 @@ package user_auth import ( "time" + "geeks-accelerator/oss/saas-starter-kit/internal/account/account_preference" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" + "geeks-accelerator/oss/saas-starter-kit/internal/user" + "geeks-accelerator/oss/saas-starter-kit/internal/user_account" + "github.com/jmoiron/sqlx" ) +// Repository defines the required dependencies for User Auth. +type Repository struct { + DbConn *sqlx.DB + TknGen TokenGenerator + User *user.Repository + UserAccount *user_account.Repository + AccountPreference *account_preference.Repository +} + +// NewRepository creates a new Repository that defines dependencies for User Auth. +func NewRepository(db *sqlx.DB, tknGen TokenGenerator, user *user.Repository, usrAcc *user_account.Repository, accPref *account_preference.Repository) *Repository { + return &Repository{ + DbConn: db, + TknGen: tknGen, + User: user, + UserAccount: usrAcc, + AccountPreference: accPref, + } +} + // AuthenticateRequest defines what information is required to authenticate a user. type AuthenticateRequest struct { Email string `json:"email" validate:"required,email" example:"gabi.may@geeksinthewoods.com"` From 04e73c8f4ea9235216d779be945d17586ebc6cfa Mon Sep 17 00:00:00 2001 From: Lee Brown Date: Wed, 14 Aug 2019 12:53:40 -0800 Subject: [PATCH 04/13] completed updating web-api --- cmd/web-api/handlers/account.go | 11 +- cmd/web-api/handlers/example.go | 48 +++++++ cmd/web-api/handlers/project.go | 27 ++-- cmd/web-api/handlers/routes.go | 146 +++++++++------------ cmd/web-api/handlers/signup.go | 7 +- cmd/web-api/handlers/user.go | 41 +++--- cmd/web-api/handlers/user_account.go | 27 ++-- cmd/web-api/main.go | 138 ++++++++++++++++--- cmd/web-api/tests/project_test.go | 2 +- cmd/web-api/tests/signup_test.go | 4 +- cmd/web-api/tests/tests_test.go | 61 +++++++-- cmd/web-api/tests/user_account_test.go | 7 +- cmd/web-api/tests/user_test.go | 10 +- cmd/web-app/main.go | 105 +++++++-------- internal/platform/notify/email_disabled.go | 21 +++ internal/platform/notify/email_smtp.go | 22 ++++ internal/project/project.go | 6 +- 17 files changed, 444 insertions(+), 239 deletions(-) create mode 100644 cmd/web-api/handlers/example.go create mode 100644 internal/platform/notify/email_disabled.go diff --git a/cmd/web-api/handlers/account.go b/cmd/web-api/handlers/account.go index b26c5a1..592d962 100644 --- a/cmd/web-api/handlers/account.go +++ b/cmd/web-api/handlers/account.go @@ -10,14 +10,13 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" - "github.com/jmoiron/sqlx" "github.com/pkg/errors" "gopkg.in/go-playground/validator.v9" ) // Account represents the Account API method handler set. type Account struct { - MasterDB *sqlx.DB + *account.Repository // ADD OTHER STATE LIKE THE LOGGER AND CONFIG HERE. } @@ -35,7 +34,7 @@ type Account struct { // @Failure 404 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /accounts/{id} [get] -func (a *Account) Read(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *Account) Read(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { claims, ok := ctx.Value(auth.Key).(auth.Claims) if !ok { return errors.New("claims missing from context") @@ -52,7 +51,7 @@ func (a *Account) Read(ctx context.Context, w http.ResponseWriter, r *http.Reque includeArchived = b } - res, err := account.Read(ctx, claims, a.MasterDB, account.AccountReadRequest{ + res, err := h.Repository.Read(ctx, claims, account.AccountReadRequest{ ID: params["id"], IncludeArchived: includeArchived, }) @@ -82,7 +81,7 @@ func (a *Account) Read(ctx context.Context, w http.ResponseWriter, r *http.Reque // @Failure 403 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /accounts [patch] -func (a *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { v, err := webcontext.ContextValues(ctx) if err != nil { @@ -102,7 +101,7 @@ func (a *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req return web.RespondJsonError(ctx, w, err) } - err = account.Update(ctx, claims, a.MasterDB, req, v.Now) + err = h.Repository.Update(ctx, claims, req, v.Now) if err != nil { cause := errors.Cause(err) switch cause { diff --git a/cmd/web-api/handlers/example.go b/cmd/web-api/handlers/example.go new file mode 100644 index 0000000..fe15ddd --- /dev/null +++ b/cmd/web-api/handlers/example.go @@ -0,0 +1,48 @@ +package handlers + +import ( + "context" + "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" + "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" + "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" + "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" + "geeks-accelerator/oss/saas-starter-kit/internal/project" + "github.com/pkg/errors" + "net/http" +) + +// Example represents the Example API method handler set. +type Example struct { + Project *project.Repository + + // ADD OTHER STATE LIKE THE LOGGER AND CONFIG HERE. +} + +// ErrorResponse returns example error messages. +func (h *Example) ErrorResponse(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { + v, err := webcontext.ContextValues(ctx) + if err != nil { + return err + } + + if qv := r.URL.Query().Get("test-validation-error"); qv != "" { + _, err := h.Project.Create(ctx, auth.Claims{}, project.ProjectCreateRequest{}, v.Now) + return web.RespondJsonError(ctx, w, err) + } + + if qv := r.URL.Query().Get("test-web-error"); qv != "" { + terr := errors.New("Some random error") + terr = errors.WithMessage(terr, "Actual error message") + rerr := weberror.NewError(ctx, terr, http.StatusBadRequest).(*weberror.Error) + rerr.Message = "Test Web Error Message" + return web.RespondJsonError(ctx, w, rerr) + } + + if qv := r.URL.Query().Get("test-error"); qv != "" { + terr := errors.New("Test error") + terr = errors.WithMessage(terr, "Error message") + return web.RespondJsonError(ctx, w, terr) + } + + return nil +} diff --git a/cmd/web-api/handlers/project.go b/cmd/web-api/handlers/project.go index f2dbf6c..82c835f 100644 --- a/cmd/web-api/handlers/project.go +++ b/cmd/web-api/handlers/project.go @@ -11,14 +11,13 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" "geeks-accelerator/oss/saas-starter-kit/internal/project" - "github.com/jmoiron/sqlx" "github.com/pkg/errors" "gopkg.in/go-playground/validator.v9" ) // Project represents the Project API method handler set. type Project struct { - MasterDB *sqlx.DB + *project.Repository // ADD OTHER STATE LIKE THE LOGGER IF NEEDED. } @@ -41,7 +40,7 @@ type Project struct { // @Failure 403 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /projects [get] -func (p *Project) Find(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *Project) Find(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { claims, ok := ctx.Value(auth.Key).(auth.Claims) if !ok { return errors.New("claims missing from context") @@ -108,7 +107,7 @@ func (p *Project) Find(ctx context.Context, w http.ResponseWriter, r *http.Reque // return web.RespondJsonError(ctx, w, err) //} - res, err := project.Find(ctx, claims, p.MasterDB, req) + res, err := h.Repository.Find(ctx, claims, req) if err != nil { return err } @@ -134,7 +133,7 @@ func (p *Project) Find(ctx context.Context, w http.ResponseWriter, r *http.Reque // @Failure 404 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /projects/{id} [get] -func (p *Project) Read(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *Project) Read(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { claims, ok := ctx.Value(auth.Key).(auth.Claims) if !ok { return errors.New("claims missing from context") @@ -151,7 +150,7 @@ func (p *Project) Read(ctx context.Context, w http.ResponseWriter, r *http.Reque includeArchived = b } - res, err := project.Read(ctx, claims, p.MasterDB, project.ProjectReadRequest{ + res, err := h.Repository.Read(ctx, claims, project.ProjectReadRequest{ ID: params["id"], IncludeArchived: includeArchived, }) @@ -182,7 +181,7 @@ func (p *Project) Read(ctx context.Context, w http.ResponseWriter, r *http.Reque // @Failure 404 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /projects [post] -func (p *Project) Create(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *Project) Create(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { v, err := webcontext.ContextValues(ctx) if err != nil { return err @@ -201,7 +200,7 @@ func (p *Project) Create(ctx context.Context, w http.ResponseWriter, r *http.Req return web.RespondJsonError(ctx, w, err) } - res, err := project.Create(ctx, claims, p.MasterDB, req, v.Now) + res, err := h.Repository.Create(ctx, claims, req, v.Now) if err != nil { cause := errors.Cause(err) switch cause { @@ -232,7 +231,7 @@ func (p *Project) Create(ctx context.Context, w http.ResponseWriter, r *http.Req // @Failure 403 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /projects [patch] -func (p *Project) Update(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *Project) Update(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { v, err := webcontext.ContextValues(ctx) if err != nil { return err @@ -251,7 +250,7 @@ func (p *Project) Update(ctx context.Context, w http.ResponseWriter, r *http.Req return web.RespondJsonError(ctx, w, err) } - err = project.Update(ctx, claims, p.MasterDB, req, v.Now) + err = h.Repository.Update(ctx, claims, req, v.Now) if err != nil { cause := errors.Cause(err) switch cause { @@ -283,7 +282,7 @@ func (p *Project) Update(ctx context.Context, w http.ResponseWriter, r *http.Req // @Failure 403 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /projects/archive [patch] -func (p *Project) Archive(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *Project) Archive(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { v, err := webcontext.ContextValues(ctx) if err != nil { return err @@ -302,7 +301,7 @@ func (p *Project) Archive(ctx context.Context, w http.ResponseWriter, r *http.Re return web.RespondJsonError(ctx, w, err) } - err = project.Archive(ctx, claims, p.MasterDB, req, v.Now) + err = h.Repository.Archive(ctx, claims, req, v.Now) if err != nil { cause := errors.Cause(err) switch cause { @@ -334,13 +333,13 @@ func (p *Project) Archive(ctx context.Context, w http.ResponseWriter, r *http.Re // @Failure 403 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /projects/{id} [delete] -func (p *Project) Delete(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *Project) Delete(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { claims, err := auth.ClaimsFromContext(ctx) if err != nil { return err } - err = project.Delete(ctx, claims, p.MasterDB, + err = h.Repository.Delete(ctx, claims, project.ProjectDeleteRequest{ID: params["id"]}) if err != nil { cause := errors.Cause(err) diff --git a/cmd/web-api/handlers/routes.go b/cmd/web-api/handlers/routes.go index b68460f..00f4704 100644 --- a/cmd/web-api/handlers/routes.go +++ b/cmd/web-api/handlers/routes.go @@ -1,122 +1,134 @@ package handlers import ( - "context" - "geeks-accelerator/oss/saas-starter-kit/internal/user" "log" "net/http" "os" + "geeks-accelerator/oss/saas-starter-kit/internal/account" + "geeks-accelerator/oss/saas-starter-kit/internal/account/account_preference" "geeks-accelerator/oss/saas-starter-kit/internal/mid" saasSwagger "geeks-accelerator/oss/saas-starter-kit/internal/mid/saas-swagger" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" - "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" _ "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" "geeks-accelerator/oss/saas-starter-kit/internal/project" + "geeks-accelerator/oss/saas-starter-kit/internal/signup" _ "geeks-accelerator/oss/saas-starter-kit/internal/signup" + "geeks-accelerator/oss/saas-starter-kit/internal/user" + "geeks-accelerator/oss/saas-starter-kit/internal/user_account" + "geeks-accelerator/oss/saas-starter-kit/internal/user_account/invite" + "geeks-accelerator/oss/saas-starter-kit/internal/user_auth" "github.com/jmoiron/sqlx" - "github.com/pkg/errors" "gopkg.in/DataDog/dd-trace-go.v1/contrib/go-redis/redis" ) - type AppContext struct { - Log *log.Logger - Env webcontext.Env - Repo *user.Repository - MasterDB *sqlx.DB - Redis *redis.Client - Authenticator *auth.Authenticator - PreAppMiddleware []web.Middleware + Log *log.Logger + Env webcontext.Env + MasterDB *sqlx.DB + Redis *redis.Client + UserRepo *user.Repository + UserAccountRepo *user_account.Repository + AccountRepo *account.Repository + AccountPrefRepo *account_preference.Repository + AuthRepo *user_auth.Repository + SignupRepo *signup.Repository + InviteRepo *invite.Repository + ProjectRepo *project.Repository + Authenticator *auth.Authenticator + PreAppMiddleware []web.Middleware PostAppMiddleware []web.Middleware } - // API returns a handler for a set of routes. -func API(shutdown chan os.Signal, appContext *AppContext ) http.Handler { +func API(shutdown chan os.Signal, appCtx *AppContext) http.Handler { // Include the pre middlewares first. - middlewares := appContext.PreAppMiddleware + middlewares := appCtx.PreAppMiddleware // Define app middlewares applied to all requests. middlewares = append(middlewares, mid.Trace(), - mid.Logger(appContext.Log), - mid.Errors(appContext.Log, nil), + mid.Logger(appCtx.Log), + mid.Errors(appCtx.Log, nil), mid.Metrics(), mid.Panics()) // Append any global middlewares that should be included after the app middlewares. - if len(appContext.PostAppMiddleware) > 0 { - middlewares = append(middlewares, appContext.PostAppMiddleware...) + if len(appCtx.PostAppMiddleware) > 0 { + middlewares = append(middlewares, appCtx.PostAppMiddleware...) } // Construct the web.App which holds all routes as well as common Middleware. - app := web.NewApp(shutdown, appContext.Log, appContext.Env, middlewares...) + app := web.NewApp(shutdown, appCtx.Log, appCtx.Env, middlewares...) // Register health check endpoint. This route is not authenticated. check := Check{ - MasterDB: appContext.MasterDB, - Redis: appContext.Redis, + MasterDB: appCtx.MasterDB, + Redis: appCtx.Redis, } app.Handle("GET", "/v1/health", check.Health) app.Handle("GET", "/ping", check.Ping) + // Register example endpoints. + ex := Example{ + Project: appCtx.ProjectRepo, + } + app.Handle("GET", "/v1/examples/error-response", ex.ErrorResponse) + // Register user management and authentication endpoints. u := User{ - MasterDB: appContext.MasterDB, - TokenGenerator: authenticator, + Repository: appCtx.UserRepo, + Auth: appCtx.AuthRepo, } - app.Handle("GET", "/v1/users", u.Find, mid.AuthenticateHeader(authenticator)) - app.Handle("POST", "/v1/users", u.Create, mid.AuthenticateHeader(authenticator), mid.HasRole(auth.RoleAdmin)) - app.Handle("GET", "/v1/users/:id", u.Read, mid.AuthenticateHeader(authenticator)) - app.Handle("PATCH", "/v1/users", u.Update, mid.AuthenticateHeader(authenticator)) - app.Handle("PATCH", "/v1/users/password", u.UpdatePassword, mid.AuthenticateHeader(authenticator)) - app.Handle("PATCH", "/v1/users/archive", u.Archive, mid.AuthenticateHeader(authenticator), mid.HasRole(auth.RoleAdmin)) - app.Handle("DELETE", "/v1/users/:id", u.Delete, mid.AuthenticateHeader(authenticator), mid.HasRole(auth.RoleAdmin)) - app.Handle("PATCH", "/v1/users/switch-account/:account_id", u.SwitchAccount, mid.AuthenticateHeader(authenticator)) + app.Handle("GET", "/v1/users", u.Find, mid.AuthenticateHeader(appCtx.Authenticator)) + app.Handle("POST", "/v1/users", u.Create, mid.AuthenticateHeader(appCtx.Authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("GET", "/v1/users/:id", u.Read, mid.AuthenticateHeader(appCtx.Authenticator)) + app.Handle("PATCH", "/v1/users", u.Update, mid.AuthenticateHeader(appCtx.Authenticator)) + app.Handle("PATCH", "/v1/users/password", u.UpdatePassword, mid.AuthenticateHeader(appCtx.Authenticator)) + app.Handle("PATCH", "/v1/users/archive", u.Archive, mid.AuthenticateHeader(appCtx.Authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("DELETE", "/v1/users/:id", u.Delete, mid.AuthenticateHeader(appCtx.Authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("PATCH", "/v1/users/switch-account/:account_id", u.SwitchAccount, mid.AuthenticateHeader(appCtx.Authenticator)) // This route is not authenticated app.Handle("POST", "/v1/oauth/token", u.Token) // Register user account management endpoints. ua := UserAccount{ - MasterDB: masterDB, + Repository: appCtx.UserAccountRepo, } - app.Handle("GET", "/v1/user_accounts", ua.Find, mid.AuthenticateHeader(authenticator)) - app.Handle("POST", "/v1/user_accounts", ua.Create, mid.AuthenticateHeader(authenticator), mid.HasRole(auth.RoleAdmin)) - app.Handle("GET", "/v1/user_accounts/:user_id/:account_id", ua.Read, mid.AuthenticateHeader(authenticator)) - app.Handle("PATCH", "/v1/user_accounts", ua.Update, mid.AuthenticateHeader(authenticator)) - app.Handle("PATCH", "/v1/user_accounts/archive", ua.Archive, mid.AuthenticateHeader(authenticator), mid.HasRole(auth.RoleAdmin)) - app.Handle("DELETE", "/v1/user_accounts", ua.Delete, mid.AuthenticateHeader(authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("GET", "/v1/user_accounts", ua.Find, mid.AuthenticateHeader(appCtx.Authenticator)) + app.Handle("POST", "/v1/user_accounts", ua.Create, mid.AuthenticateHeader(appCtx.Authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("GET", "/v1/user_accounts/:user_id/:account_id", ua.Read, mid.AuthenticateHeader(appCtx.Authenticator)) + app.Handle("PATCH", "/v1/user_accounts", ua.Update, mid.AuthenticateHeader(appCtx.Authenticator)) + app.Handle("PATCH", "/v1/user_accounts/archive", ua.Archive, mid.AuthenticateHeader(appCtx.Authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("DELETE", "/v1/user_accounts", ua.Delete, mid.AuthenticateHeader(appCtx.Authenticator), mid.HasRole(auth.RoleAdmin)) // Register account endpoints. a := Account{ - MasterDB: masterDB, + Repository: appCtx.AccountRepo, } - app.Handle("GET", "/v1/accounts/:id", a.Read, mid.AuthenticateHeader(authenticator)) - app.Handle("PATCH", "/v1/accounts", a.Update, mid.AuthenticateHeader(authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("GET", "/v1/accounts/:id", a.Read, mid.AuthenticateHeader(appCtx.Authenticator)) + app.Handle("PATCH", "/v1/accounts", a.Update, mid.AuthenticateHeader(appCtx.Authenticator), mid.HasRole(auth.RoleAdmin)) // Register signup endpoints. s := Signup{ - MasterDB: masterDB, + Repository: appCtx.SignupRepo, } app.Handle("POST", "/v1/signup", s.Signup) // Register project. p := Project{ - MasterDB: masterDB, + Repository: appCtx.ProjectRepo, } - app.Handle("GET", "/v1/projects", p.Find, mid.AuthenticateHeader(authenticator)) - app.Handle("POST", "/v1/projects", p.Create, mid.AuthenticateHeader(authenticator), mid.HasRole(auth.RoleAdmin)) - app.Handle("GET", "/v1/projects/:id", p.Read, mid.AuthenticateHeader(authenticator)) - app.Handle("PATCH", "/v1/projects", p.Update, mid.AuthenticateHeader(authenticator), mid.HasRole(auth.RoleAdmin)) - app.Handle("PATCH", "/v1/projects/archive", p.Archive, mid.AuthenticateHeader(authenticator), mid.HasRole(auth.RoleAdmin)) - app.Handle("DELETE", "/v1/projects/:id", p.Delete, mid.AuthenticateHeader(authenticator), mid.HasRole(auth.RoleAdmin)) - - app.Handle("GET", "/v1/examples/error-response", ExampleErrorResponse) + app.Handle("GET", "/v1/projects", p.Find, mid.AuthenticateHeader(appCtx.Authenticator)) + app.Handle("POST", "/v1/projects", p.Create, mid.AuthenticateHeader(appCtx.Authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("GET", "/v1/projects/:id", p.Read, mid.AuthenticateHeader(appCtx.Authenticator)) + app.Handle("PATCH", "/v1/projects", p.Update, mid.AuthenticateHeader(appCtx.Authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("PATCH", "/v1/projects/archive", p.Archive, mid.AuthenticateHeader(appCtx.Authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("DELETE", "/v1/projects/:id", p.Delete, mid.AuthenticateHeader(appCtx.Authenticator), mid.HasRole(auth.RoleAdmin)) // Register swagger documentation. // TODO: Add authentication. Current authenticator requires an Authorization header @@ -127,36 +139,6 @@ func API(shutdown chan os.Signal, appContext *AppContext ) http.Handler { return app } -// ExampleErrorResponse returns example error messages. -func ExampleErrorResponse(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { - v, err := webcontext.ContextValues(ctx) - if err != nil { - return err - } - - if qv := r.URL.Query().Get("test-validation-error"); qv != "" { - _, err := project.Create(ctx, auth.Claims{}, nil, project.ProjectCreateRequest{}, v.Now) - return web.RespondJsonError(ctx, w, err) - - } - - if qv := r.URL.Query().Get("test-web-error"); qv != "" { - terr := errors.New("Some random error") - terr = errors.WithMessage(terr, "Actual error message") - rerr := weberror.NewError(ctx, terr, http.StatusBadRequest).(*weberror.Error) - rerr.Message = "Test Web Error Message" - return web.RespondJsonError(ctx, w, rerr) - } - - if qv := r.URL.Query().Get("test-error"); qv != "" { - terr := errors.New("Test error") - terr = errors.WithMessage(terr, "Error message") - return web.RespondJsonError(ctx, w, terr) - } - - return nil -} - // Types godoc // @Summary List of types. // @Param data body weberror.FieldError false "Field Error" diff --git a/cmd/web-api/handlers/signup.go b/cmd/web-api/handlers/signup.go index cd0ed21..e2472a3 100644 --- a/cmd/web-api/handlers/signup.go +++ b/cmd/web-api/handlers/signup.go @@ -10,14 +10,13 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" "geeks-accelerator/oss/saas-starter-kit/internal/signup" - "github.com/jmoiron/sqlx" "github.com/pkg/errors" "gopkg.in/go-playground/validator.v9" ) // Signup represents the Signup API method handler set. type Signup struct { - MasterDB *sqlx.DB + *signup.Repository // ADD OTHER STATE LIKE THE LOGGER AND CONFIG HERE. } @@ -33,7 +32,7 @@ type Signup struct { // @Failure 400 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /signup [post] -func (c *Signup) Signup(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *Signup) Signup(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { v, err := webcontext.ContextValues(ctx) if err != nil { return err @@ -50,7 +49,7 @@ func (c *Signup) Signup(ctx context.Context, w http.ResponseWriter, r *http.Requ return web.RespondJsonError(ctx, w, err) } - res, err := signup.Signup(ctx, claims, c.MasterDB, req, v.Now) + res, err := h.Repository.Signup(ctx, claims, req, v.Now) if err != nil { switch errors.Cause(err) { case account.ErrForbidden: diff --git a/cmd/web-api/handlers/user.go b/cmd/web-api/handlers/user.go index e1b7e7f..ddbd934 100644 --- a/cmd/web-api/handlers/user.go +++ b/cmd/web-api/handlers/user.go @@ -14,7 +14,6 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/user" "geeks-accelerator/oss/saas-starter-kit/internal/user_auth" "github.com/gorilla/schema" - "github.com/jmoiron/sqlx" "github.com/pkg/errors" "gopkg.in/go-playground/validator.v9" ) @@ -24,8 +23,8 @@ var sessionTtl = time.Hour * 24 // User represents the User API method handler set. type User struct { - MasterDB *sqlx.DB - TokenGenerator user_auth.TokenGenerator + *user.Repository + Auth *user_auth.Repository // ADD OTHER STATE LIKE THE LOGGER AND CONFIG HERE. } @@ -47,7 +46,7 @@ type User struct { // @Failure 400 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /users [get] -func (u *User) Find(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *User) Find(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { claims, ok := ctx.Value(auth.Key).(auth.Claims) if !ok { return errors.New("claims missing from context") @@ -114,7 +113,7 @@ func (u *User) Find(ctx context.Context, w http.ResponseWriter, r *http.Request, // return web.RespondJsonError(ctx, w, err) //} - res, err := user.Find(ctx, claims, u.MasterDB, req) + res, err := h.Repository.Find(ctx, claims, req) if err != nil { return err } @@ -140,7 +139,7 @@ func (u *User) Find(ctx context.Context, w http.ResponseWriter, r *http.Request, // @Failure 404 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /users/{id} [get] -func (u *User) Read(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *User) Read(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { claims, ok := ctx.Value(auth.Key).(auth.Claims) if !ok { return errors.New("claims missing from context") @@ -157,7 +156,7 @@ func (u *User) Read(ctx context.Context, w http.ResponseWriter, r *http.Request, includeArchived = b } - res, err := user.Read(ctx, claims, u.MasterDB, user.UserReadRequest{ + res, err := h.Repository.Read(ctx, claims, user.UserReadRequest{ ID: params["id"], IncludeArchived: includeArchived, }) @@ -187,7 +186,7 @@ func (u *User) Read(ctx context.Context, w http.ResponseWriter, r *http.Request, // @Failure 403 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /users [post] -func (u *User) Create(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *User) Create(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { v, err := webcontext.ContextValues(ctx) if err != nil { return err @@ -206,7 +205,7 @@ func (u *User) Create(ctx context.Context, w http.ResponseWriter, r *http.Reques return web.RespondJsonError(ctx, w, err) } - res, err := user.Create(ctx, claims, u.MasterDB, req, v.Now) + res, err := h.Repository.Create(ctx, claims, req, v.Now) if err != nil { cause := errors.Cause(err) switch cause { @@ -238,7 +237,7 @@ func (u *User) Create(ctx context.Context, w http.ResponseWriter, r *http.Reques // @Failure 403 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /users [patch] -func (u *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { v, err := webcontext.ContextValues(ctx) if err != nil { return err @@ -257,7 +256,7 @@ func (u *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Reques return web.RespondJsonError(ctx, w, err) } - err = user.Update(ctx, claims, u.MasterDB, req, v.Now) + err = h.Repository.Update(ctx, claims, req, v.Now) if err != nil { cause := errors.Cause(err) switch cause { @@ -289,7 +288,7 @@ func (u *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Reques // @Failure 403 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /users/password [patch] -func (u *User) UpdatePassword(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *User) UpdatePassword(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { v, err := webcontext.ContextValues(ctx) if err != nil { return err @@ -308,7 +307,7 @@ func (u *User) UpdatePassword(ctx context.Context, w http.ResponseWriter, r *htt return web.RespondJsonError(ctx, w, err) } - err = user.UpdatePassword(ctx, claims, u.MasterDB, req, v.Now) + err = h.Repository.UpdatePassword(ctx, claims, req, v.Now) if err != nil { cause := errors.Cause(err) switch cause { @@ -342,7 +341,7 @@ func (u *User) UpdatePassword(ctx context.Context, w http.ResponseWriter, r *htt // @Failure 403 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /users/archive [patch] -func (u *User) Archive(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *User) Archive(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { v, err := webcontext.ContextValues(ctx) if err != nil { return err @@ -361,7 +360,7 @@ func (u *User) Archive(ctx context.Context, w http.ResponseWriter, r *http.Reque return web.RespondJsonError(ctx, w, err) } - err = user.Archive(ctx, claims, u.MasterDB, req, v.Now) + err = h.Repository.Archive(ctx, claims, req, v.Now) if err != nil { cause := errors.Cause(err) switch cause { @@ -393,13 +392,13 @@ func (u *User) Archive(ctx context.Context, w http.ResponseWriter, r *http.Reque // @Failure 403 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /users/{id} [delete] -func (u *User) Delete(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *User) Delete(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { claims, err := auth.ClaimsFromContext(ctx) if err != nil { return err } - err = user.Delete(ctx, claims, u.MasterDB, + err = h.Repository.Delete(ctx, claims, user.UserDeleteRequest{ID: params["id"]}) if err != nil { cause := errors.Cause(err) @@ -432,7 +431,7 @@ func (u *User) Delete(ctx context.Context, w http.ResponseWriter, r *http.Reques // @Failure 401 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /users/switch-account/{account_id} [patch] -func (u *User) SwitchAccount(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *User) SwitchAccount(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { v, err := webcontext.ContextValues(ctx) if err != nil { return err @@ -443,7 +442,7 @@ func (u *User) SwitchAccount(ctx context.Context, w http.ResponseWriter, r *http return err } - tkn, err := user_auth.SwitchAccount(ctx, u.MasterDB, u.TokenGenerator, claims, user_auth.SwitchAccountRequest{ + tkn, err := h.Auth.SwitchAccount(ctx, claims, user_auth.SwitchAccountRequest{ AccountID: params["account_id"], }, sessionTtl, v.Now) if err != nil { @@ -479,7 +478,7 @@ func (u *User) SwitchAccount(ctx context.Context, w http.ResponseWriter, r *http // @Failure 401 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /oauth/token [post] -func (u *User) Token(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *User) Token(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { v, err := webcontext.ContextValues(ctx) if err != nil { return err @@ -534,7 +533,7 @@ func (u *User) Token(ctx context.Context, w http.ResponseWriter, r *http.Request scopes = strings.Split(qv, ",") } - tkn, err := user_auth.Authenticate(ctx, u.MasterDB, u.TokenGenerator, authReq, sessionTtl, v.Now, scopes...) + tkn, err := h.Auth.Authenticate(ctx, authReq, sessionTtl, v.Now, scopes...) if err != nil { cause := errors.Cause(err) switch cause { diff --git a/cmd/web-api/handlers/user_account.go b/cmd/web-api/handlers/user_account.go index 344ac7b..aec3075 100644 --- a/cmd/web-api/handlers/user_account.go +++ b/cmd/web-api/handlers/user_account.go @@ -11,14 +11,13 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" "geeks-accelerator/oss/saas-starter-kit/internal/user_account" - "github.com/jmoiron/sqlx" "github.com/pkg/errors" "gopkg.in/go-playground/validator.v9" ) // UserAccount represents the UserAccount API method handler set. type UserAccount struct { - MasterDB *sqlx.DB + *user_account.Repository // ADD OTHER STATE LIKE THE LOGGER AND CONFIG HERE. } @@ -41,7 +40,7 @@ type UserAccount struct { // @Failure 403 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /user_accounts [get] -func (u *UserAccount) Find(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *UserAccount) Find(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { claims, ok := ctx.Value(auth.Key).(auth.Claims) if !ok { return errors.New("claims missing from context") @@ -108,7 +107,7 @@ func (u *UserAccount) Find(ctx context.Context, w http.ResponseWriter, r *http.R // return web.RespondJsonError(ctx, w, err) //} - res, err := user_account.Find(ctx, claims, u.MasterDB, req) + res, err := h.Repository.Find(ctx, claims, req) if err != nil { return err } @@ -134,7 +133,7 @@ func (u *UserAccount) Find(ctx context.Context, w http.ResponseWriter, r *http.R // @Failure 404 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /user_accounts/{user_id}/{account_id} [get] -func (u *UserAccount) Read(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *UserAccount) Read(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { claims, ok := ctx.Value(auth.Key).(auth.Claims) if !ok { return errors.New("claims missing from context") @@ -151,7 +150,7 @@ func (u *UserAccount) Read(ctx context.Context, w http.ResponseWriter, r *http.R includeArchived = b } - res, err := user_account.Read(ctx, claims, u.MasterDB, user_account.UserAccountReadRequest{ + res, err := h.Repository.Read(ctx, claims, user_account.UserAccountReadRequest{ UserID: params["user_id"], AccountID: params["account_id"], IncludeArchived: includeArchived, @@ -183,7 +182,7 @@ func (u *UserAccount) Read(ctx context.Context, w http.ResponseWriter, r *http.R // @Failure 404 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /user_accounts [post] -func (u *UserAccount) Create(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *UserAccount) Create(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { v, err := webcontext.ContextValues(ctx) if err != nil { return err @@ -202,7 +201,7 @@ func (u *UserAccount) Create(ctx context.Context, w http.ResponseWriter, r *http return web.RespondJsonError(ctx, w, err) } - res, err := user_account.Create(ctx, claims, u.MasterDB, req, v.Now) + res, err := h.Repository.Create(ctx, claims, req, v.Now) if err != nil { cause := errors.Cause(err) switch cause { @@ -234,7 +233,7 @@ func (u *UserAccount) Create(ctx context.Context, w http.ResponseWriter, r *http // @Failure 403 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /user_accounts [patch] -func (u *UserAccount) Update(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *UserAccount) Update(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { v, err := webcontext.ContextValues(ctx) if err != nil { return err @@ -253,7 +252,7 @@ func (u *UserAccount) Update(ctx context.Context, w http.ResponseWriter, r *http return web.RespondJsonError(ctx, w, err) } - err = user_account.Update(ctx, claims, u.MasterDB, req, v.Now) + err = h.Repository.Update(ctx, claims, req, v.Now) if err != nil { cause := errors.Cause(err) switch cause { @@ -285,7 +284,7 @@ func (u *UserAccount) Update(ctx context.Context, w http.ResponseWriter, r *http // @Failure 403 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /user_accounts/archive [patch] -func (u *UserAccount) Archive(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *UserAccount) Archive(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { v, err := webcontext.ContextValues(ctx) if err != nil { return err @@ -304,7 +303,7 @@ func (u *UserAccount) Archive(ctx context.Context, w http.ResponseWriter, r *htt return web.RespondJsonError(ctx, w, err) } - err = user_account.Archive(ctx, claims, u.MasterDB, req, v.Now) + err = h.Repository.Archive(ctx, claims, req, v.Now) if err != nil { cause := errors.Cause(err) switch cause { @@ -336,7 +335,7 @@ func (u *UserAccount) Archive(ctx context.Context, w http.ResponseWriter, r *htt // @Failure 403 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /user_accounts [delete] -func (u *UserAccount) Delete(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *UserAccount) Delete(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { claims, err := auth.ClaimsFromContext(ctx) if err != nil { return err @@ -350,7 +349,7 @@ func (u *UserAccount) Delete(ctx context.Context, w http.ResponseWriter, r *http return web.RespondJsonError(ctx, w, err) } - err = user_account.Delete(ctx, claims, u.MasterDB, req) + err = h.Repository.Delete(ctx, claims, req) if err != nil { cause := errors.Cause(err) switch cause { diff --git a/cmd/web-api/main.go b/cmd/web-api/main.go index e331ead..f64f953 100644 --- a/cmd/web-api/main.go +++ b/cmd/web-api/main.go @@ -6,7 +6,6 @@ import ( "encoding/json" "expvar" "fmt" - "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "log" "net" "net/http" @@ -21,18 +20,30 @@ import ( "geeks-accelerator/oss/saas-starter-kit/cmd/web-api/docs" "geeks-accelerator/oss/saas-starter-kit/cmd/web-api/handlers" + "geeks-accelerator/oss/saas-starter-kit/internal/account" + "geeks-accelerator/oss/saas-starter-kit/internal/account/account_preference" "geeks-accelerator/oss/saas-starter-kit/internal/mid" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/devops" "geeks-accelerator/oss/saas-starter-kit/internal/platform/flag" - "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" + "geeks-accelerator/oss/saas-starter-kit/internal/platform/notify" + "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" + "geeks-accelerator/oss/saas-starter-kit/internal/project" + "geeks-accelerator/oss/saas-starter-kit/internal/project_route" + "geeks-accelerator/oss/saas-starter-kit/internal/signup" + "geeks-accelerator/oss/saas-starter-kit/internal/user" + "geeks-accelerator/oss/saas-starter-kit/internal/user_account" + "geeks-accelerator/oss/saas-starter-kit/internal/user_account/invite" + "geeks-accelerator/oss/saas-starter-kit/internal/user_auth" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/ec2metadata" "github.com/aws/aws-sdk-go/aws/session" "github.com/go-redis/redis" + "github.com/gorilla/securecookie" "github.com/kelseyhightower/envconfig" "github.com/lib/pq" + "github.com/pkg/errors" "golang.org/x/crypto/acme" "golang.org/x/crypto/acme/autocert" awstrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/aws/aws-sdk-go/aws" @@ -66,10 +77,9 @@ func main() { // ========================================================================= // Logging - log.SetFlags(log.LstdFlags|log.Lmicroseconds|log.Lshortfile) - log.SetPrefix(service+" : ") - log := log.New(os.Stdout, log.Prefix() , log.Flags()) - + log.SetFlags(log.LstdFlags | log.Lmicroseconds | log.Lshortfile) + log.SetPrefix(service + " : ") + log := log.New(os.Stdout, log.Prefix(), log.Flags()) // ========================================================================= // Configuration @@ -87,16 +97,21 @@ func main() { DisableHTTP2 bool `default:"false" envconfig:"DISABLE_HTTP2"` } Service struct { - Name string `default:"web-api" envconfig:"NAME"` - Project string `default:"" envconfig:"PROJECT"` + Name string `default:"web-api" envconfig:"SERVICE"` BaseUrl string `default:"" envconfig:"BASE_URL" example:"http://api.example.saasstartupkit.com"` HostNames []string `envconfig:"HOST_NAMES" example:"alternative-subdomain.example.saasstartupkit.com"` EnableHTTPS bool `default:"false" envconfig:"ENABLE_HTTPS"` TemplateDir string `default:"./templates" envconfig:"TEMPLATE_DIR"` - WebAppBaseUrl string `default:"http://127.0.0.1:3000" envconfig:"WEB_APP_BASE_URL" example:"www.example.saasstartupkit.com"` DebugHost string `default:"0.0.0.0:4000" envconfig:"DEBUG_HOST"` ShutdownTimeout time.Duration `default:"5s" envconfig:"SHUTDOWN_TIMEOUT"` } + Project struct { + Name string `default:"" envconfig:"PROJECT"` + SharedTemplateDir string `default:"../../resources/templates/shared" envconfig:"SHARED_TEMPLATE_DIR"` + SharedSecretKey string `default:"" envconfig:"SHARED_SECRET_KEY"` + EmailSender string `default:"test@example.saasstartupkit.com" envconfig:"EMAIL_SENDER"` + WebAppBaseUrl string `default:"http://127.0.0.1:3000" envconfig:"WEB_APP_BASE_URL" example:"www.example.saasstartupkit.com"` + } Redis struct { Host string `default:":6379" envconfig:"HOST"` DB int `default:"1" envconfig:"DB"` @@ -185,8 +200,8 @@ func main() { // deployments and distributed to each instance of the service running. if cfg.Aws.SecretsManagerConfigPrefix == "" { var pts []string - if cfg.Service.Project != "" { - pts = append(pts, cfg.Service.Project) + if cfg.Project.Name != "" { + pts = append(pts, cfg.Project.Name) } pts = append(pts, cfg.Env, cfg.Service.Name) @@ -276,6 +291,37 @@ func main() { awsSession = awstrace.WrapSession(awsSession) } + // ========================================================================= + // Shared Secret Key used for encrypting sessions and links. + + // Set the secret key if not provided in the config. + if cfg.Project.SharedSecretKey == "" { + + // AWS secrets manager ID for storing the session key. This is optional and only will be used + // if a valid AWS session is provided. + secretID := filepath.Join(cfg.Aws.SecretsManagerConfigPrefix, "sharedSecretKey") + + // If AWS is enabled, check the Secrets Manager for the session key. + if awsSession != nil { + cfg.Project.SharedSecretKey, err = devops.SecretManagerGetString(awsSession, secretID) + if err != nil && errors.Cause(err) != devops.ErrSecreteNotFound { + log.Fatalf("main : Session : %+v", err) + } + } + + // If the session key is still empty, generate a new key. + if cfg.Project.SharedSecretKey == "" { + cfg.Project.SharedSecretKey = string(securecookie.GenerateRandomKey(32)) + + if awsSession != nil { + err = devops.SecretManagerPutString(awsSession, secretID, cfg.Project.SharedSecretKey) + if err != nil { + log.Fatalf("main : Session : %+v", err) + } + } + } + } + // ========================================================================= // Start Redis // Ensure the eviction policy on the redis cluster is set correctly. @@ -346,6 +392,31 @@ func main() { } defer masterDb.Close() + // ========================================================================= + // Notify Email + var notifyEmail notify.Email + if awsSession != nil { + // Send emails with AWS SES. Alternative to use SMTP with notify.NewEmailSmtp. + notifyEmail, err = notify.NewEmailAws(awsSession, cfg.Project.SharedTemplateDir, cfg.Project.EmailSender) + if err != nil { + log.Fatalf("main : Notify Email : %+v", err) + } + + err = notifyEmail.Verify() + if err != nil { + switch errors.Cause(err) { + case notify.ErrAwsSesIdentityNotVerified: + log.Printf("main : Notify Email : %s\n", err) + case notify.ErrAwsSesSendingDisabled: + log.Printf("main : Notify Email : %s\n", err) + default: + log.Fatalf("main : Notify Email Verify : %+v", err) + } + } + } else { + notifyEmail = notify.NewEmailDisabled() + } + // ========================================================================= // Init new Authenticator var authenticator *auth.Authenticator @@ -360,11 +431,41 @@ func main() { } // ========================================================================= - // Load middlewares that need to be configured specific for the service. - var serviceMiddlewares = []web.Middleware{ - mid.Translator(webcontext.UniversalTranslator()), + // Init repositories and AppContext + + projectRoute, err := project_route.New(cfg.Service.BaseUrl, cfg.Project.WebAppBaseUrl) + if err != nil { + log.Fatalf("main : project routes : %+v", cfg.Service.BaseUrl, err) } + usrRepo := user.NewRepository(masterDb, projectRoute.UserResetPassword, notifyEmail, cfg.Project.SharedSecretKey) + usrAccRepo := user_account.NewRepository(masterDb) + accRepo := account.NewRepository(masterDb) + accPrefRepo := account_preference.NewRepository(masterDb) + authRepo := user_auth.NewRepository(masterDb, authenticator, usrRepo, usrAccRepo, accPrefRepo) + signupRepo := signup.NewRepository(masterDb, usrRepo, usrAccRepo, accRepo) + inviteRepo := invite.NewRepository(masterDb, usrRepo, usrAccRepo, accRepo, projectRoute.UserInviteAccept, notifyEmail, cfg.Project.SharedSecretKey) + prjRepo := project.NewRepository(masterDb) + + appCtx := &handlers.AppContext{ + Log: log, + Env: cfg.Env, + MasterDB: masterDb, + Redis: redisClient, + UserRepo: usrRepo, + UserAccountRepo: usrAccRepo, + AccountRepo: accRepo, + AccountPrefRepo: accPrefRepo, + AuthRepo: authRepo, + SignupRepo: signupRepo, + InviteRepo: inviteRepo, + ProjectRepo: prjRepo, + Authenticator: authenticator, + } + + // ========================================================================= + // Load middlewares that need to be configured specific for the service. + // Init redirect middleware to ensure all requests go to the primary domain contained in the base URL. if primaryServiceHost != "127.0.0.1" && primaryServiceHost != "localhost" { redirect := mid.DomainNameRedirect(mid.DomainNameRedirectConfig{ @@ -380,9 +481,12 @@ func main() { DomainName: primaryServiceHost, HTTPSEnabled: cfg.Service.EnableHTTPS, }) - serviceMiddlewares = append(serviceMiddlewares, redirect) + appCtx.PostAppMiddleware = append(appCtx.PostAppMiddleware, redirect) } + // Add the translator middleware for localization. + appCtx.PostAppMiddleware = append(appCtx.PostAppMiddleware, mid.Translator(webcontext.UniversalTranslator())) + // ========================================================================= // Start Tracing Support th := fmt.Sprintf("%s:%d", cfg.Trace.Host, cfg.Trace.Port) @@ -443,7 +547,7 @@ func main() { if cfg.HTTP.Host != "" { api := http.Server{ Addr: cfg.HTTP.Host, - Handler: handlers.API(shutdown, log, cfg.Env, masterDb, redisClient, authenticator, serviceMiddlewares...), + Handler: handlers.API(shutdown, appCtx), ReadTimeout: cfg.HTTP.ReadTimeout, WriteTimeout: cfg.HTTP.WriteTimeout, MaxHeaderBytes: 1 << 20, @@ -460,7 +564,7 @@ func main() { if cfg.HTTPS.Host != "" { api := http.Server{ Addr: cfg.HTTPS.Host, - Handler: handlers.API(shutdown, log, cfg.Env, masterDb, redisClient, authenticator, serviceMiddlewares...), + Handler: handlers.API(shutdown, appCtx), ReadTimeout: cfg.HTTPS.ReadTimeout, WriteTimeout: cfg.HTTPS.WriteTimeout, MaxHeaderBytes: 1 << 20, diff --git a/cmd/web-api/tests/project_test.go b/cmd/web-api/tests/project_test.go index fbefebd..504cdd4 100644 --- a/cmd/web-api/tests/project_test.go +++ b/cmd/web-api/tests/project_test.go @@ -27,7 +27,7 @@ func mockProjectCreateRequest(accountID string) project.ProjectCreateRequest { // mockProject creates a new project for testing and associates it with the supplied account ID. func newMockProject(accountID string) *project.Project { req := mockProjectCreateRequest(accountID) - p, err := project.Create(tests.Context(), auth.Claims{}, test.MasterDB, req, time.Now().UTC().AddDate(-1, -1, -1)) + p, err := appCtx.ProjectRepo.Create(tests.Context(), auth.Claims{}, req, time.Now().UTC().AddDate(-1, -1, -1)) if err != nil { panic(err) } diff --git a/cmd/web-api/tests/signup_test.go b/cmd/web-api/tests/signup_test.go index 686f2a7..e1593f5 100644 --- a/cmd/web-api/tests/signup_test.go +++ b/cmd/web-api/tests/signup_test.go @@ -50,13 +50,13 @@ func mockSignupRequest() signup.SignupRequest { func newMockSignup() mockSignup { req := mockSignupRequest() now := time.Now().UTC().AddDate(-1, -1, -1) - s, err := signup.Signup(tests.Context(), auth.Claims{}, test.MasterDB, req, now) + s, err := appCtx.SignupRepo.Signup(tests.Context(), auth.Claims{}, req, now) if err != nil { panic(err) } expires := time.Now().UTC().Sub(s.User.CreatedAt) + time.Hour - tkn, err := user_auth.Authenticate(tests.Context(), test.MasterDB, authenticator, user_auth.AuthenticateRequest{ + tkn, err := appCtx.AuthRepo.Authenticate(tests.Context(), user_auth.AuthenticateRequest{ Email: req.User.Email, Password: req.User.Password, }, expires, now) diff --git a/cmd/web-api/tests/tests_test.go b/cmd/web-api/tests/tests_test.go index 83608f5..f968787 100644 --- a/cmd/web-api/tests/tests_test.go +++ b/cmd/web-api/tests/tests_test.go @@ -5,6 +5,11 @@ import ( "context" "encoding/json" "fmt" + "geeks-accelerator/oss/saas-starter-kit/internal/account/account_preference" + "geeks-accelerator/oss/saas-starter-kit/internal/platform/notify" + "geeks-accelerator/oss/saas-starter-kit/internal/project" + "geeks-accelerator/oss/saas-starter-kit/internal/project_route" + "geeks-accelerator/oss/saas-starter-kit/internal/user_account/invite" "io" "io/ioutil" "net/http" @@ -31,9 +36,12 @@ import ( "github.com/pkg/errors" ) -var a http.Handler -var test *tests.Test -var authenticator *auth.Authenticator +var ( + a http.Handler + test *tests.Test + authenticator *auth.Authenticator + appCtx *handlers.AppContext +) // Information about the users we have created for testing. type roleTest struct { @@ -84,18 +92,51 @@ func testMain(m *testing.M) int { log := test.Log log.SetOutput(ioutil.Discard) - a = handlers.API(shutdown, log, webcontext.Env_Dev, test.MasterDB, nil, authenticator) + + projectRoute, err := project_route.New("http://web-api.com", "http://web-app.com") + if err != nil { + panic(err) + } + + notifyEmail := notify.NewEmailDisabled() + + usrRepo := user.MockRepository(test.MasterDB) + usrAccRepo := user_account.NewRepository(test.MasterDB) + accRepo := account.NewRepository(test.MasterDB) + accPrefRepo := account_preference.NewRepository(test.MasterDB) + authRepo := user_auth.NewRepository(test.MasterDB, authenticator, usrRepo, usrAccRepo, accPrefRepo) + signupRepo := signup.NewRepository(test.MasterDB, usrRepo, usrAccRepo, accRepo) + inviteRepo := invite.NewRepository(test.MasterDB, usrRepo, usrAccRepo, accRepo, projectRoute.UserInviteAccept, notifyEmail, "6368616e676520746869732070613434") + prjRepo := project.NewRepository(test.MasterDB) + + appCtx = &handlers.AppContext{ + Log: log, + Env: webcontext.Env_Dev, + MasterDB: test.MasterDB, + Redis: nil, + UserRepo: usrRepo, + UserAccountRepo: usrAccRepo, + AccountRepo: accRepo, + AccountPrefRepo: accPrefRepo, + AuthRepo: authRepo, + SignupRepo: signupRepo, + InviteRepo: inviteRepo, + ProjectRepo: prjRepo, + Authenticator: authenticator, + } + + a = handlers.API(shutdown, appCtx) // Create a new account directly business logic. This creates an // initial account and user that we will use for admin validated endpoints. signupReq1 := mockSignupRequest() - signup1, err := signup.Signup(tests.Context(), auth.Claims{}, test.MasterDB, signupReq1, now) + signup1, err := signupRepo.Signup(tests.Context(), auth.Claims{}, signupReq1, now) if err != nil { panic(err) } expires := time.Now().UTC().Sub(signup1.User.CreatedAt) + time.Hour - adminTkn, err := user_auth.Authenticate(tests.Context(), test.MasterDB, authenticator, user_auth.AuthenticateRequest{ + adminTkn, err := authRepo.Authenticate(tests.Context(), user_auth.AuthenticateRequest{ Email: signupReq1.User.Email, Password: signupReq1.User.Password, }, expires, now) @@ -110,7 +151,7 @@ func testMain(m *testing.M) int { // Create a second account that the first account user should not have access to. signupReq2 := mockSignupRequest() - signup2, err := signup.Signup(tests.Context(), auth.Claims{}, test.MasterDB, signupReq2, now) + signup2, err := signupRepo.Signup(tests.Context(), auth.Claims{}, signupReq2, now) if err != nil { panic(err) } @@ -134,12 +175,12 @@ func testMain(m *testing.M) int { Password: "akTechFr0n!ier", PasswordConfirm: "akTechFr0n!ier", } - usr, err := user.Create(tests.Context(), adminClaims, test.MasterDB, userReq, now) + usr, err := usrRepo.Create(tests.Context(), adminClaims, userReq, now) if err != nil { panic(err) } - _, err = user_account.Create(tests.Context(), adminClaims, test.MasterDB, user_account.UserAccountCreateRequest{ + _, err = usrAccRepo.Create(tests.Context(), adminClaims, user_account.UserAccountCreateRequest{ UserID: usr.ID, AccountID: signup1.Account.ID, Roles: []user_account.UserAccountRole{user_account.UserAccountRole_User}, @@ -149,7 +190,7 @@ func testMain(m *testing.M) int { panic(err) } - userTkn, err := user_auth.Authenticate(tests.Context(), test.MasterDB, authenticator, user_auth.AuthenticateRequest{ + userTkn, err := authRepo.Authenticate(tests.Context(), user_auth.AuthenticateRequest{ Email: usr.Email, Password: userReq.Password, }, expires, now) diff --git a/cmd/web-api/tests/user_account_test.go b/cmd/web-api/tests/user_account_test.go index e1c650c..ff499be 100644 --- a/cmd/web-api/tests/user_account_test.go +++ b/cmd/web-api/tests/user_account_test.go @@ -14,7 +14,6 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/platform/tests" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" - "geeks-accelerator/oss/saas-starter-kit/internal/user" "geeks-accelerator/oss/saas-starter-kit/internal/user_account" "github.com/pborman/uuid" ) @@ -22,12 +21,12 @@ import ( // newMockUserAccount creates a new user user for testing and associates it with the supplied account ID. func newMockUserAccount(accountID string, role user_account.UserAccountRole) *user_account.UserAccount { req := mockUserCreateRequest() - u, err := user.Create(tests.Context(), auth.Claims{}, test.MasterDB, req, time.Now().UTC().AddDate(-1, -1, -1)) + u, err := appCtx.UserRepo.Create(tests.Context(), auth.Claims{}, req, time.Now().UTC().AddDate(-1, -1, -1)) if err != nil { panic(err) } - ua, err := user_account.Create(tests.Context(), auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + ua, err := appCtx.UserAccountRepo.Create(tests.Context(), auth.Claims{}, user_account.UserAccountCreateRequest{ UserID: u.ID, AccountID: accountID, Roles: []user_account.UserAccountRole{role}, @@ -65,7 +64,7 @@ func TestUserAccountCRUDAdmin(t *testing.T) { } t.Logf("\tTest: %s - %s %s", rt.name, rt.method, rt.url) - newUser, err := user.Create(tests.Context(), auth.Claims{}, test.MasterDB, mockUserCreateRequest(), time.Now().UTC().AddDate(-1, -1, -1)) + newUser, err := appCtx.UserRepo.Create(tests.Context(), auth.Claims{}, mockUserCreateRequest(), time.Now().UTC().AddDate(-1, -1, -1)) if err != nil { t.Fatalf("\t%s\tCreate new user failed.", tests.Failed) } diff --git a/cmd/web-api/tests/user_test.go b/cmd/web-api/tests/user_test.go index 4b55dd4..f823c6b 100644 --- a/cmd/web-api/tests/user_test.go +++ b/cmd/web-api/tests/user_test.go @@ -38,12 +38,12 @@ func mockUserCreateRequest() user.UserCreateRequest { // mockUser creates a new user for testing and associates it with the supplied account ID. func newMockUser(accountID string, role user_account.UserAccountRole) mockUser { req := mockUserCreateRequest() - u, err := user.Create(tests.Context(), auth.Claims{}, test.MasterDB, req, time.Now().UTC().AddDate(-1, -1, -1)) + u, err := appCtx.UserRepo.Create(tests.Context(), auth.Claims{}, req, time.Now().UTC().AddDate(-1, -1, -1)) if err != nil { panic(err) } - _, err = user_account.Create(tests.Context(), auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + _, err = appCtx.UserAccountRepo.Create(tests.Context(), auth.Claims{}, user_account.UserAccountCreateRequest{ UserID: u.ID, AccountID: accountID, Roles: []user_account.UserAccountRole{role}, @@ -126,7 +126,7 @@ func TestUserCRUDAdmin(t *testing.T) { t.Logf("\t%s\tReceived expected result.", tests.Success) // Only for user creation do we need to do this. - _, err := user_account.Create(tests.Context(), auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + _, err := appCtx.UserAccountRepo.Create(tests.Context(), auth.Claims{}, user_account.UserAccountCreateRequest{ UserID: actual.ID, AccountID: tr.Account.ID, Roles: []user_account.UserAccountRole{user_account.UserAccountRole_User}, @@ -401,7 +401,7 @@ func TestUserCRUDAdmin(t *testing.T) { } t.Logf("\tTest: %s - %s %s", rt.name, rt.method, rt.url) - _, err := user_account.Create(tests.Context(), auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + _, err := appCtx.UserAccountRepo.Create(tests.Context(), auth.Claims{}, user_account.UserAccountCreateRequest{ UserID: tr.User.ID, AccountID: newAccount.ID, Roles: []user_account.UserAccountRole{user_account.UserAccountRole_User}, @@ -805,7 +805,7 @@ func TestUserCRUDUser(t *testing.T) { } t.Logf("\tTest: %s - %s %s", rt.name, rt.method, rt.url) - _, err := user_account.Create(tests.Context(), auth.Claims{}, test.MasterDB, user_account.UserAccountCreateRequest{ + _, err := appCtx.UserAccountRepo.Create(tests.Context(), auth.Claims{}, user_account.UserAccountCreateRequest{ UserID: tr.User.ID, AccountID: newAccount.ID, Roles: []user_account.UserAccountRole{user_account.UserAccountRole_User}, diff --git a/cmd/web-app/main.go b/cmd/web-app/main.go index 046b695..a89843f 100644 --- a/cmd/web-app/main.go +++ b/cmd/web-app/main.go @@ -6,6 +6,7 @@ import ( "encoding/json" "expvar" "fmt" + "geeks-accelerator/oss/saas-starter-kit/internal/project_route" "html/template" "log" "net" @@ -88,12 +89,10 @@ func main() { } Service struct { Name string `default:"web-app" envconfig:"NAME"` - Project string `default:"" envconfig:"PROJECT"` BaseUrl string `default:"" envconfig:"BASE_URL" example:"http://example.saasstartupkit.com"` HostNames []string `envconfig:"HOST_NAMES" example:"www.example.saasstartupkit.com"` EnableHTTPS bool `default:"false" envconfig:"ENABLE_HTTPS"` TemplateDir string `default:"./templates" envconfig:"TEMPLATE_DIR"` - SharedTemplateDir string `default:"../../resources/templates/shared" envconfig:"SHARED_TEMPLATE_DIR"` StaticFiles struct { Dir string `default:"./static" envconfig:"STATIC_DIR"` S3Enabled bool `envconfig:"S3_ENABLED"` @@ -101,13 +100,17 @@ func main() { CloudFrontEnabled bool `envconfig:"CLOUDFRONT_ENABLED"` ImgResizeEnabled bool `envconfig:"IMG_RESIZE_ENABLED"` } - WebApiBaseUrl string `default:"http://127.0.0.1:3001" envconfig:"WEB_API_BASE_URL" example:"http://api.example.saasstartupkit.com"` - SessionKey string `default:"" envconfig:"SESSION_KEY"` SessionName string `default:"" envconfig:"SESSION_NAME"` - EmailSender string `default:"test@example.saasstartupkit.com" envconfig:"EMAIL_SENDER"` DebugHost string `default:"0.0.0.0:4000" envconfig:"DEBUG_HOST"` ShutdownTimeout time.Duration `default:"5s" envconfig:"SHUTDOWN_TIMEOUT"` } + Project struct { + Name string `default:"" envconfig:"PROJECT"` + SharedTemplateDir string `default:"../../resources/templates/shared" envconfig:"SHARED_TEMPLATE_DIR"` + SharedSecretKey string `default:"" envconfig:"SHARED_SECRET_KEY"` + EmailSender string `default:"test@example.saasstartupkit.com" envconfig:"EMAIL_SENDER"` + WebApiBaseUrl string `default:"http://127.0.0.1:3001" envconfig:"WEB_API_BASE_URL" example:"http://api.example.saasstartupkit.com"` + } Redis struct { Host string `default:":6379" envconfig:"HOST"` DB int `default:"1" envconfig:"DB"` @@ -145,12 +148,6 @@ func main() { UseAwsSecretManager bool `default:"false" envconfig:"USE_AWS_SECRET_MANAGER"` KeyExpiration time.Duration `default:"3600s" envconfig:"KEY_EXPIRATION"` } - STMP struct { - Host string `default:"localhost" envconfig:"HOST"` - Port int `default:"25" envconfig:"PORT"` - User string `default:"" envconfig:"USER"` - Pass string `default:"" envconfig:"PASS" json:"-"` // don't print - } BuildInfo struct { CiCommitRefName string `envconfig:"CI_COMMIT_REF_NAME"` CiCommitShortSha string `envconfig:"CI_COMMIT_SHORT_SHA"` @@ -202,8 +199,8 @@ func main() { // deployments and distributed to each instance of the service running. if cfg.Aws.SecretsManagerConfigPrefix == "" { var pts []string - if cfg.Service.Project != "" { - pts = append(pts, cfg.Service.Project) + if cfg.Project.Name != "" { + pts = append(pts, cfg.Project.Name) } pts = append(pts, cfg.Env, cfg.Service.Name) @@ -293,6 +290,37 @@ func main() { awsSession = awstrace.WrapSession(awsSession) } + // ========================================================================= + // Shared Secret Key used for encrypting sessions and links. + + // Set the secret key if not provided in the config. + if cfg.Project.SharedSecretKey == "" { + + // AWS secrets manager ID for storing the session key. This is optional and only will be used + // if a valid AWS session is provided. + secretID := filepath.Join(cfg.Aws.SecretsManagerConfigPrefix, "sharedSecretKey") + + // If AWS is enabled, check the Secrets Manager for the session key. + if awsSession != nil { + cfg.Project.SharedSecretKey, err = devops.SecretManagerGetString(awsSession, secretID) + if err != nil && errors.Cause(err) != devops.ErrSecreteNotFound { + log.Fatalf("main : Session : %+v", err) + } + } + + // If the session key is still empty, generate a new key. + if cfg.Project.SharedSecretKey == "" { + cfg.Project.SharedSecretKey = string(securecookie.GenerateRandomKey(32)) + + if awsSession != nil { + err = devops.SecretManagerPutString(awsSession, secretID, cfg.Service.SecretKey) + if err != nil { + log.Fatalf("main : Session : %+v", err) + } + } + } + } + // ========================================================================= // Start Redis // Ensure the eviction policy on the redis cluster is set correctly. @@ -367,6 +395,7 @@ func main() { // Notify Email var notifyEmail notify.Email if awsSession != nil { + // Send emails with AWS SES. Alternative to use SMTP with notify.NewEmailSmtp. notifyEmail, err = notify.NewEmailAws(awsSession, cfg.Service.SharedTemplateDir, cfg.Service.EmailSender) if err != nil { log.Fatalf("main : Notify Email : %+v", err) @@ -384,15 +413,7 @@ func main() { } } } else { - d := gomail.Dialer{ - Host: cfg.STMP.Host, - Port: cfg.STMP.Port, - Username: cfg.STMP.User, - Password: cfg.STMP.Pass} - notifyEmail, err = notify.NewEmailSmtp(d, cfg.Service.SharedTemplateDir, cfg.Service.EmailSender) - if err != nil { - log.Fatalf("main : Notify Email : %+v", err) - } + notifyEmail = notify.NewEmailDisabled() } // ========================================================================= @@ -433,46 +454,18 @@ func main() { serviceMiddlewares = append(serviceMiddlewares, redirect) } + // Generate the new session store and append it to the global list of middlewares. + // Init session store if cfg.Service.SessionName == "" { cfg.Service.SessionName = fmt.Sprintf("%s-session", cfg.Service.Name) } - - // Set the session key if not provided in the config. - if cfg.Service.SessionKey == "" { - - // AWS secrets manager ID for storing the session key. This is optional and only will be used - // if a valid AWS session is provided. - secretID := filepath.Join(cfg.Aws.SecretsManagerConfigPrefix, "session") - - // If AWS is enabled, check the Secrets Manager for the session key. - if awsSession != nil { - cfg.Service.SessionKey, err = devops.SecretManagerGetString(awsSession, secretID) - if err != nil && errors.Cause(err) != devops.ErrSecreteNotFound { - log.Fatalf("main : Session : %+v", err) - } - } - - // If the session key is still empty, generate a new key. - if cfg.Service.SessionKey == "" { - cfg.Service.SessionKey = string(securecookie.GenerateRandomKey(32)) - - if awsSession != nil { - err = devops.SecretManagerPutString(awsSession, secretID, cfg.Service.SessionKey) - if err != nil { - log.Fatalf("main : Session : %+v", err) - } - } - } - } - - // Generate the new session store and append it to the global list of middlewares. - sessionStore := sessions.NewCookieStore([]byte(cfg.Service.SessionKey)) + sessionStore := sessions.NewCookieStore([]byte(cfg.Service.SecretKey)) serviceMiddlewares = append(serviceMiddlewares, mid.Session(sessionStore, cfg.Service.SessionName)) // ========================================================================= // URL Formatter - projectRoutes, err := project_routes.New(cfg.Service.WebApiBaseUrl, cfg.Service.BaseUrl) + projectRoutes, err := project_route.New(cfg.Service.WebApiBaseUrl, cfg.Service.BaseUrl) if err != nil { log.Fatalf("main : project routes : %+v", cfg.Service.BaseUrl, err) } @@ -926,7 +919,7 @@ func main() { if cfg.HTTP.Host != "" { api := http.Server{ Addr: cfg.HTTP.Host, - Handler: handlers.APP(shutdown, log, cfg.Env, cfg.Service.StaticFiles.Dir, cfg.Service.TemplateDir, masterDb, redisClient, authenticator, projectRoutes, cfg.Service.SessionKey, notifyEmail, renderer, serviceMiddlewares...), + Handler: handlers.APP(shutdown, log, cfg.Env, cfg.Service.StaticFiles.Dir, cfg.Service.TemplateDir, masterDb, redisClient, authenticator, projectRoutes, cfg.Service.SecretKey, notifyEmail, renderer, serviceMiddlewares...), ReadTimeout: cfg.HTTP.ReadTimeout, WriteTimeout: cfg.HTTP.WriteTimeout, MaxHeaderBytes: 1 << 20, @@ -943,7 +936,7 @@ func main() { if cfg.HTTPS.Host != "" { api := http.Server{ Addr: cfg.HTTPS.Host, - Handler: handlers.APP(shutdown, log, cfg.Env, cfg.Service.StaticFiles.Dir, cfg.Service.TemplateDir, masterDb, redisClient, authenticator, projectRoutes, cfg.Service.SessionKey, notifyEmail, renderer, serviceMiddlewares...), + Handler: handlers.APP(shutdown, log, cfg.Env, cfg.Service.StaticFiles.Dir, cfg.Service.TemplateDir, masterDb, redisClient, authenticator, projectRoutes, cfg.Service.SecretKey, notifyEmail, renderer, serviceMiddlewares...), ReadTimeout: cfg.HTTPS.ReadTimeout, WriteTimeout: cfg.HTTPS.WriteTimeout, MaxHeaderBytes: 1 << 20, diff --git a/internal/platform/notify/email_disabled.go b/internal/platform/notify/email_disabled.go new file mode 100644 index 0000000..85345a1 --- /dev/null +++ b/internal/platform/notify/email_disabled.go @@ -0,0 +1,21 @@ +package notify + +import "context" + +// DisableEmail defines an implementation of the email interface that doesn't send any email. +type DisableEmail struct{} + +// NewEmailDisabled disables sending any emails with an empty implementation of the email interface. +func NewEmailDisabled() *DisableEmail { + return &DisableEmail{} +} + +// Send does nothing. +func (n *DisableEmail) Send(ctx context.Context, toEmail, subject, templateName string, data map[string]interface{}) error { + return nil +} + +// Verify does nothing. +func (n *DisableEmail) Verify() error { + return nil +} diff --git a/internal/platform/notify/email_smtp.go b/internal/platform/notify/email_smtp.go index 748a7d3..ea56c4a 100644 --- a/internal/platform/notify/email_smtp.go +++ b/internal/platform/notify/email_smtp.go @@ -1,5 +1,27 @@ package notify +/* + // Alternative to use AWS SES with SMTP + import "gopkg.in/gomail.v2" + + var cfg struct { + ... + SMTP struct { + Host string `default:"localhost" envconfig:"HOST"` + Port int `default:"25" envconfig:"PORT"` + User string `default:"" envconfig:"USER"` + Pass string `default:"" envconfig:"PASS" json:"-"` // don't print + }, + } + + d := gomail.Dialer{ + Host: cfg.SMTP.Host, + Port: cfg.SMTP.Port, + Username: cfg.SMTP.User, + Password: cfg.SMTP.Pass} + notifyEmail, err = notify.NewEmailSmtp(d, cfg.Service.SharedTemplateDir, cfg.Service.EmailSender) + */ + import ( "context" "github.com/pkg/errors" diff --git a/internal/project/project.go b/internal/project/project.go index b4fe3e2..b130702 100644 --- a/internal/project/project.go +++ b/internal/project/project.go @@ -414,7 +414,7 @@ func (repo *Repository) Archive(ctx context.Context, claims auth.Claims, req Pro } // Delete removes an project from the database. -func (repo *Repository) Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectDeleteRequest) error { +func (repo *Repository) Delete(ctx context.Context, claims auth.Claims, req ProjectDeleteRequest) error { span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Delete") defer span.Finish() @@ -437,8 +437,8 @@ func (repo *Repository) Delete(ctx context.Context, claims auth.Claims, dbConn * query.Where(query.Equal("id", req.ID)) // Execute the query with the provided context. sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) + sql = repo.DbConn.Rebind(sql) + _, err = repo.DbConn.ExecContext(ctx, sql, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessagef(err, "delete project %s failed", req.ID) From 102ca821255134c8cf066479317dbb4be40e178e Mon Sep 17 00:00:00 2001 From: Lee Brown Date: Wed, 14 Aug 2019 17:59:47 -0800 Subject: [PATCH 05/13] Completed web-api and web-app updates --- cmd/web-api/ecs-task-definition.json | 7 +- cmd/web-api/main.go | 8 +- cmd/web-app/ecs-task-definition.json | 7 +- cmd/web-app/handlers/account.go | 24 +-- cmd/web-app/handlers/projects.go | 19 ++- cmd/web-app/handlers/root.go | 18 +-- cmd/web-app/handlers/routes.go | 204 +++++++++++++++---------- cmd/web-app/handlers/signup.go | 13 +- cmd/web-app/handlers/user.go | 66 ++++---- cmd/web-app/handlers/users.go | 61 ++++---- cmd/web-app/main.go | 112 +++++++++----- internal/platform/notify/email_smtp.go | 2 +- internal/user/models.go | 5 + 13 files changed, 312 insertions(+), 234 deletions(-) diff --git a/cmd/web-api/ecs-task-definition.json b/cmd/web-api/ecs-task-definition.json index dc78985..59ffe40 100644 --- a/cmd/web-api/ecs-task-definition.json +++ b/cmd/web-api/ecs-task-definition.json @@ -34,12 +34,13 @@ {"name": "ECS_SERVICE", "value": "{ECS_SERVICE}"}, {"name": "WEB_API_HTTP_HOST", "value": "{HTTP_HOST}"}, {"name": "WEB_API_HTTPS_HOST", "value": "{HTTPS_HOST}"}, - {"name": "WEB_API_SERVICE_PROJECT", "value": "{APP_PROJECT}"}, + {"name": "WEB_API_SERVICE_SERVICE_NAME", "value": "{SERVICE}"}, {"name": "WEB_API_SERVICE_BASE_URL", "value": "{APP_BASE_URL}"}, {"name": "WEB_API_SERVICE_HOST_NAMES", "value": "{HOST_NAMES}"}, {"name": "WEB_API_SERVICE_ENABLE_HTTPS", "value": "{HTTPS_ENABLED}"}, - {"name": "WEB_API_SERVICE_EMAIL_SENDER", "value": "{EMAIL_SENDER}"}, - {"name": "WEB_API_SERVICE_WEB_APP_BASE_URL", "value": "{WEB_APP_BASE_URL}"}, + {"name": "WEB_API_PROJECT_PROJECT_NAME", "value": "{APP_PROJECT}"}, + {"name": "WEB_API_PROJECT_EMAIL_SENDER", "value": "{EMAIL_SENDER}"}, + {"name": "WEB_API_PROJECT_WEB_APP_BASE_URL", "value": "{WEB_APP_BASE_URL}"}, {"name": "WEB_API_REDIS_HOST", "value": "{CACHE_HOST}"}, {"name": "WEB_API_DB_HOST", "value": "{DB_HOST}"}, {"name": "WEB_API_DB_USER", "value": "{DB_USER}"}, diff --git a/cmd/web-api/main.go b/cmd/web-api/main.go index f64f953..9a48ec4 100644 --- a/cmd/web-api/main.go +++ b/cmd/web-api/main.go @@ -97,7 +97,7 @@ func main() { DisableHTTP2 bool `default:"false" envconfig:"DISABLE_HTTP2"` } Service struct { - Name string `default:"web-api" envconfig:"SERVICE"` + Name string `default:"web-api" envconfig:"SERVICE_NAME"` BaseUrl string `default:"" envconfig:"BASE_URL" example:"http://api.example.saasstartupkit.com"` HostNames []string `envconfig:"HOST_NAMES" example:"alternative-subdomain.example.saasstartupkit.com"` EnableHTTPS bool `default:"false" envconfig:"ENABLE_HTTPS"` @@ -106,7 +106,7 @@ func main() { ShutdownTimeout time.Duration `default:"5s" envconfig:"SHUTDOWN_TIMEOUT"` } Project struct { - Name string `default:"" envconfig:"PROJECT"` + Name string `default:"" envconfig:"PROJECT_NAME"` SharedTemplateDir string `default:"../../resources/templates/shared" envconfig:"SHARED_TEMPLATE_DIR"` SharedSecretKey string `default:"" envconfig:"SHARED_SECRET_KEY"` EmailSender string `default:"test@example.saasstartupkit.com" envconfig:"EMAIL_SENDER"` @@ -203,7 +203,7 @@ func main() { if cfg.Project.Name != "" { pts = append(pts, cfg.Project.Name) } - pts = append(pts, cfg.Env, cfg.Service.Name) + pts = append(pts, cfg.Env) cfg.Aws.SecretsManagerConfigPrefix = filepath.Join(pts...) } @@ -299,7 +299,7 @@ func main() { // AWS secrets manager ID for storing the session key. This is optional and only will be used // if a valid AWS session is provided. - secretID := filepath.Join(cfg.Aws.SecretsManagerConfigPrefix, "sharedSecretKey") + secretID := filepath.Join(cfg.Aws.SecretsManagerConfigPrefix, "SharedSecretKey") // If AWS is enabled, check the Secrets Manager for the session key. if awsSession != nil { diff --git a/cmd/web-app/ecs-task-definition.json b/cmd/web-app/ecs-task-definition.json index 1132f7a..88b63ef 100644 --- a/cmd/web-app/ecs-task-definition.json +++ b/cmd/web-app/ecs-task-definition.json @@ -34,7 +34,7 @@ {"name": "ECS_SERVICE", "value": "{ECS_SERVICE}"}, {"name": "WEB_APP_HTTP_HOST", "value": "{HTTP_HOST}"}, {"name": "WEB_APP_HTTPS_HOST", "value": "{HTTPS_HOST}"}, - {"name": "WEB_APP_SERVICE_PROJECT", "value": "{APP_PROJECT}"}, + {"name": "WEB_APP_SERVICE_SERVICE_NAME", "value": "{SERVICE}"}, {"name": "WEB_APP_SERVICE_BASE_URL", "value": "{APP_BASE_URL}"}, {"name": "WEB_APP_SERVICE_HOST_NAMES", "value": "{HOST_NAMES}"}, {"name": "WEB_APP_SERVICE_ENABLE_HTTPS", "value": "{HTTPS_ENABLED}"}, @@ -42,8 +42,9 @@ {"name": "WEB_APP_SERVICE_STATICFILES_S3_PREFIX", "value": "{STATIC_FILES_S3_PREFIX}"}, {"name": "WEB_APP_SERVICE_STATICFILES_CLOUDFRONT_ENABLED", "value": "{STATIC_FILES_CLOUDFRONT_ENABLED}"}, {"name": "WEB_APP_SERVICE_STATICFILES_IMG_RESIZE_ENABLED", "value": "{STATIC_FILES_IMG_RESIZE_ENABLED}"}, - {"name": "WEB_APP_SERVICE_EMAIL_SENDER", "value": "{EMAIL_SENDER}"}, - {"name": "WEB_APP_SERVICE_WEB_API_BASE_URL", "value": "{WEB_API_BASE_URL}"}, + {"name": "WEB_APP_PROJECT_PROJECT_NAME", "value": "{APP_PROJECT}"}, + {"name": "WEB_APP_PROJECT_EMAIL_SENDER", "value": "{EMAIL_SENDER}"}, + {"name": "WEB_APP_PROJECT_WEB_API_BASE_URL", "value": "{WEB_API_BASE_URL}"}, {"name": "WEB_APP_REDIS_HOST", "value": "{CACHE_HOST}"}, {"name": "WEB_APP_DB_HOST", "value": "{DB_HOST}"}, {"name": "WEB_APP_DB_USER", "value": "{DB_USER}"}, diff --git a/cmd/web-app/handlers/account.go b/cmd/web-app/handlers/account.go index c2e3a0c..3ba3e0f 100644 --- a/cmd/web-app/handlers/account.go +++ b/cmd/web-app/handlers/account.go @@ -12,6 +12,7 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" + "geeks-accelerator/oss/saas-starter-kit/internal/user_auth" "github.com/gorilla/schema" "github.com/jmoiron/sqlx" "github.com/pkg/errors" @@ -19,9 +20,12 @@ import ( // Account represents the Account API method handler set. type Account struct { - MasterDB *sqlx.DB - Renderer web.Renderer - Authenticator *auth.Authenticator + AccountRepo *account.Repository + AccountPrefRepo *account_preference.Repository + AuthRepo *user_auth.Repository + Authenticator *auth.Authenticator + MasterDB *sqlx.DB + Renderer web.Renderer } // View handles displaying the current account profile. @@ -35,7 +39,7 @@ func (h *Account) View(ctx context.Context, w http.ResponseWriter, r *http.Reque return err } - acc, err := account.ReadByID(ctx, claims, h.MasterDB, claims.Audience) + acc, err := h.AccountRepo.ReadByID(ctx, claims, claims.Audience) if err != nil { return err } @@ -77,7 +81,7 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req return false, err } - prefs, err := account_preference.FindByAccountID(ctx, claims, h.MasterDB, account_preference.AccountPreferenceFindByAccountIDRequest{ + prefs, err := h.AccountPrefRepo.FindByAccountID(ctx, claims, account_preference.AccountPreferenceFindByAccountIDRequest{ AccountID: claims.Audience, }) if err != nil { @@ -115,7 +119,7 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req } req.ID = claims.Audience - err = account.Update(ctx, claims, h.MasterDB, req.AccountUpdateRequest, ctxValues.Now) + err = h.AccountRepo.Update(ctx, claims, req.AccountUpdateRequest, ctxValues.Now) if err != nil { switch errors.Cause(err) { default: @@ -135,7 +139,7 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req } if preferenceDatetimeFormat != req.PreferenceDatetimeFormat { - err = account_preference.Set(ctx, claims, h.MasterDB, account_preference.AccountPreferenceSetRequest{ + err = h.AccountPrefRepo.Set(ctx, claims, account_preference.AccountPreferenceSetRequest{ AccountID: claims.Audience, Name: account_preference.AccountPreference_Datetime_Format, Value: req.PreferenceDatetimeFormat, @@ -156,7 +160,7 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req } if preferenceDateFormat != req.PreferenceDateFormat { - err = account_preference.Set(ctx, claims, h.MasterDB, account_preference.AccountPreferenceSetRequest{ + err = h.AccountPrefRepo.Set(ctx, claims, account_preference.AccountPreferenceSetRequest{ AccountID: claims.Audience, Name: account_preference.AccountPreference_Date_Format, Value: req.PreferenceDateFormat, @@ -177,7 +181,7 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req } if preferenceTimeFormat != req.PreferenceTimeFormat { - err = account_preference.Set(ctx, claims, h.MasterDB, account_preference.AccountPreferenceSetRequest{ + err = h.AccountPrefRepo.Set(ctx, claims, account_preference.AccountPreferenceSetRequest{ AccountID: claims.Audience, Name: account_preference.AccountPreference_Time_Format, Value: req.PreferenceTimeFormat, @@ -213,7 +217,7 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req return true, web.Redirect(ctx, w, r, "/account", http.StatusFound) } - acc, err := account.ReadByID(ctx, claims, h.MasterDB, claims.Audience) + acc, err := h.AccountRepo.ReadByID(ctx, claims, claims.Audience) if err != nil { return false, err } diff --git a/cmd/web-app/handlers/projects.go b/cmd/web-app/handlers/projects.go index d3bda68..8a375fe 100644 --- a/cmd/web-app/handlers/projects.go +++ b/cmd/web-app/handlers/projects.go @@ -13,16 +13,15 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" "geeks-accelerator/oss/saas-starter-kit/internal/project" "github.com/gorilla/schema" - "github.com/jmoiron/sqlx" "github.com/pkg/errors" "gopkg.in/DataDog/dd-trace-go.v1/contrib/go-redis/redis" ) // Projects represents the Projects API method handler set. type Projects struct { - MasterDB *sqlx.DB - Redis *redis.Client - Renderer web.Renderer + ProjectRepo *project.Repository + Redis *redis.Client + Renderer web.Renderer } func urlProjectsIndex() string { @@ -110,7 +109,7 @@ func (h *Projects) Index(ctx context.Context, w http.ResponseWriter, r *http.Req } loadFunc := func(ctx context.Context, sorting string, fields []datatable.DisplayField) (resp [][]datatable.ColumnValue, err error) { - res, err := project.Find(ctx, claims, h.MasterDB, project.ProjectFindRequest{ + res, err := h.ProjectRepo.Find(ctx, claims, project.ProjectFindRequest{ Where: "account_id = ?", Args: []interface{}{claims.Audience}, Order: strings.Split(sorting, ","), @@ -186,7 +185,7 @@ func (h *Projects) Create(ctx context.Context, w http.ResponseWriter, r *http.Re } req.AccountID = claims.Audience - usr, err := project.Create(ctx, claims, h.MasterDB, *req, ctxValues.Now) + usr, err := h.ProjectRepo.Create(ctx, claims, *req, ctxValues.Now) if err != nil { switch errors.Cause(err) { default: @@ -251,7 +250,7 @@ func (h *Projects) View(ctx context.Context, w http.ResponseWriter, r *http.Requ switch r.PostForm.Get("action") { case "archive": - err = project.Archive(ctx, claims, h.MasterDB, project.ProjectArchiveRequest{ + err = h.ProjectRepo.Archive(ctx, claims, project.ProjectArchiveRequest{ ID: projectID, }, ctxValues.Now) if err != nil { @@ -276,7 +275,7 @@ func (h *Projects) View(ctx context.Context, w http.ResponseWriter, r *http.Requ return nil } - prj, err := project.ReadByID(ctx, claims, h.MasterDB, projectID) + prj, err := h.ProjectRepo.ReadByID(ctx, claims, projectID) if err != nil { return err } @@ -320,7 +319,7 @@ func (h *Projects) Update(ctx context.Context, w http.ResponseWriter, r *http.Re } req.ID = projectID - err = project.Update(ctx, claims, h.MasterDB, *req, ctxValues.Now) + err = h.ProjectRepo.Update(ctx, claims, *req, ctxValues.Now) if err != nil { switch errors.Cause(err) { default: @@ -351,7 +350,7 @@ func (h *Projects) Update(ctx context.Context, w http.ResponseWriter, r *http.Re return nil } - prj, err := project.ReadByID(ctx, claims, h.MasterDB, projectID) + prj, err := h.ProjectRepo.ReadByID(ctx, claims, projectID) if err != nil { return err } diff --git a/cmd/web-app/handlers/root.go b/cmd/web-app/handlers/root.go index e4ea3e6..dad0360 100644 --- a/cmd/web-app/handlers/root.go +++ b/cmd/web-app/handlers/root.go @@ -8,9 +8,8 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" - project_routes "geeks-accelerator/oss/saas-starter-kit/internal/project-routes" + "geeks-accelerator/oss/saas-starter-kit/internal/project_route" "github.com/ikeikeikeike/go-sitemap-generator/v2/stm" - "github.com/jmoiron/sqlx" "github.com/pkg/errors" "github.com/sethgrid/pester" "io/ioutil" @@ -19,10 +18,9 @@ import ( // Root represents the Root API method handler set. type Root struct { - MasterDB *sqlx.DB - Renderer web.Renderer - Sitemap *stm.Sitemap - ProjectRoutes project_routes.ProjectRoutes + Renderer web.Renderer + Sitemap *stm.Sitemap + ProjectRoute project_route.ProjectRoute } // Index determines if the user has authentication and loads the associated page. @@ -57,7 +55,7 @@ func (h *Root) SitePage(ctx context.Context, w http.ResponseWriter, r *http.Requ tmpName = "site-api.gohtml" // http://127.0.0.1:3001/docs/doc.json - swaggerJsonUrl := h.ProjectRoutes.ApiDocsJson() + swaggerJsonUrl := h.ProjectRoute.ApiDocsJson() // Load the json file from the API service. res, err := pester.Get(swaggerJsonUrl) @@ -93,8 +91,8 @@ func (h *Root) SitePage(ctx context.Context, w http.ResponseWriter, r *http.Requ return errors.WithStack(err) } - data["urlApiBaseUri"] = h.ProjectRoutes.WebApiUrl(doc.BasePath) - data["urlApiDocs"] = h.ProjectRoutes.ApiDocs() + data["urlApiBaseUri"] = h.ProjectRoute.WebApiUrl(doc.BasePath) + data["urlApiDocs"] = h.ProjectRoute.ApiDocs() case "/pricing": tmpName = "site-pricing.gohtml" @@ -123,7 +121,7 @@ func (h *Root) RobotTxt(ctx context.Context, w http.ResponseWriter, r *http.Requ return web.RespondText(ctx, w, txt, http.StatusOK) } - sitemapUrl := h.ProjectRoutes.WebAppUrl("/sitemap.xml") + sitemapUrl := h.ProjectRoute.WebAppUrl("/sitemap.xml") txt := fmt.Sprintf("User-agent: *\nDisallow: /ping\nDisallow: /status\nDisallow: /debug/\nSitemap: %s", sitemapUrl) return web.RespondText(ctx, w, txt, http.StatusOK) diff --git a/cmd/web-app/handlers/routes.go b/cmd/web-app/handlers/routes.go index e30a2c8..01bd4b2 100644 --- a/cmd/web-app/handlers/routes.go +++ b/cmd/web-app/handlers/routes.go @@ -9,13 +9,20 @@ import ( "path/filepath" "time" + "geeks-accelerator/oss/saas-starter-kit/internal/account" + "geeks-accelerator/oss/saas-starter-kit/internal/account/account_preference" "geeks-accelerator/oss/saas-starter-kit/internal/mid" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" - "geeks-accelerator/oss/saas-starter-kit/internal/platform/notify" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" - project_routes "geeks-accelerator/oss/saas-starter-kit/internal/project-routes" + "geeks-accelerator/oss/saas-starter-kit/internal/project" + "geeks-accelerator/oss/saas-starter-kit/internal/project_route" + "geeks-accelerator/oss/saas-starter-kit/internal/signup" + "geeks-accelerator/oss/saas-starter-kit/internal/user" + "geeks-accelerator/oss/saas-starter-kit/internal/user_account" + "geeks-accelerator/oss/saas-starter-kit/internal/user_account/invite" + "geeks-accelerator/oss/saas-starter-kit/internal/user_auth" "github.com/ikeikeikeike/go-sitemap-generator/v2/stm" "github.com/jmoiron/sqlx" "gopkg.in/DataDog/dd-trace-go.v1/contrib/go-redis/redis" @@ -27,30 +34,58 @@ const ( TmplContentErrorGeneric = "error-generic.gohtml" ) +type AppContext struct { + Log *log.Logger + Env webcontext.Env + MasterDB *sqlx.DB + Redis *redis.Client + UserRepo *user.Repository + UserAccountRepo *user_account.Repository + AccountRepo *account.Repository + AccountPrefRepo *account_preference.Repository + AuthRepo *user_auth.Repository + SignupRepo *signup.Repository + InviteRepo *invite.Repository + ProjectRepo *project.Repository + Authenticator *auth.Authenticator + StaticDir string + TemplateDir string + Renderer web.Renderer + ProjectRoute project_route.ProjectRoute + PreAppMiddleware []web.Middleware + PostAppMiddleware []web.Middleware +} + // API returns a handler for a set of routes. -func APP(shutdown chan os.Signal, log *log.Logger, env webcontext.Env, staticDir, templateDir string, masterDB *sqlx.DB, redis *redis.Client, authenticator *auth.Authenticator, projectRoutes project_routes.ProjectRoutes, secretKey string, notifyEmail notify.Email, renderer web.Renderer, globalMids ...web.Middleware) http.Handler { +func APP(shutdown chan os.Signal, appCtx *AppContext) http.Handler { - // Define base middlewares applied to all requests. - middlewares := []web.Middleware{ - mid.Trace(), mid.Logger(log), mid.Errors(log, renderer), mid.Metrics(), mid.Panics(), - } + // Include the pre middlewares first. + middlewares := appCtx.PreAppMiddleware - // Append any global middlewares if they were included. - if len(globalMids) > 0 { - middlewares = append(middlewares, globalMids...) + // Define app middlewares applied to all requests. + middlewares = append(middlewares, + mid.Trace(), + mid.Logger(appCtx.Log), + mid.Errors(appCtx.Log, appCtx.Renderer), + mid.Metrics(), + mid.Panics()) + + // Append any global middlewares that should be included after the app middlewares. + if len(appCtx.PostAppMiddleware) > 0 { + middlewares = append(middlewares, appCtx.PostAppMiddleware...) } // Construct the web.App which holds all routes as well as common Middleware. - app := web.NewApp(shutdown, log, env, middlewares...) + app := web.NewApp(shutdown, appCtx.Log, appCtx.Env, middlewares...) // Build a sitemap. sm := stm.NewSitemap(1) sm.SetVerbose(false) - sm.SetDefaultHost(projectRoutes.WebAppUrl("")) + sm.SetDefaultHost(appCtx.ProjectRoute.WebAppUrl("")) sm.Create() smLocAddModified := func(loc stm.URL, filename string) { - contentPath := filepath.Join(templateDir, "content", filename) + contentPath := filepath.Join(appCtx.TemplateDir, "content", filename) file, err := os.Stat(contentPath) if err != nil { @@ -64,48 +99,48 @@ func APP(shutdown chan os.Signal, log *log.Logger, env webcontext.Env, staticDir // Register project management pages. p := Projects{ - MasterDB: masterDB, - Redis: redis, - Renderer: renderer, + ProjectRepo: appCtx.ProjectRepo, + Redis: appCtx.Redis, + Renderer: appCtx.Renderer, } - app.Handle("POST", "/projects/:project_id/update", p.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin)) - app.Handle("GET", "/projects/:project_id/update", p.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin)) - app.Handle("POST", "/projects/:project_id", p.View, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin)) - app.Handle("GET", "/projects/:project_id", p.View, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth()) - app.Handle("POST", "/projects/create", p.Create, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin)) - app.Handle("GET", "/projects/create", p.Create, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin)) - app.Handle("GET", "/projects", p.Index, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth()) + app.Handle("POST", "/projects/:project_id/update", p.Update, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("GET", "/projects/:project_id/update", p.Update, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("POST", "/projects/:project_id", p.View, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("GET", "/projects/:project_id", p.View, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasAuth()) + app.Handle("POST", "/projects/create", p.Create, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("GET", "/projects/create", p.Create, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("GET", "/projects", p.Index, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasAuth()) // Register user management pages. us := Users{ - MasterDB: masterDB, - Redis: redis, - Renderer: renderer, - Authenticator: authenticator, - ProjectRoutes: projectRoutes, - NotifyEmail: notifyEmail, - SecretKey: secretKey, + UserRepo: appCtx.UserRepo, + UserAccountRepo: appCtx.UserAccountRepo, + AuthRepo: appCtx.AuthRepo, + InviteRepo: appCtx.InviteRepo, + MasterDB: appCtx.MasterDB, + Redis: appCtx.Redis, + Renderer: appCtx.Renderer, } - app.Handle("POST", "/users/:user_id/update", us.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin)) - app.Handle("GET", "/users/:user_id/update", us.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin)) - app.Handle("POST", "/users/:user_id", us.View, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin)) - app.Handle("GET", "/users/:user_id", us.View, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth()) + app.Handle("POST", "/users/:user_id/update", us.Update, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("GET", "/users/:user_id/update", us.Update, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("POST", "/users/:user_id", us.View, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("GET", "/users/:user_id", us.View, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasAuth()) app.Handle("POST", "/users/invite/:hash", us.InviteAccept) app.Handle("GET", "/users/invite/:hash", us.InviteAccept) - app.Handle("POST", "/users/invite", us.Invite, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin)) - app.Handle("GET", "/users/invite", us.Invite, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin)) - app.Handle("POST", "/users/create", us.Create, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin)) - app.Handle("GET", "/users/create", us.Create, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin)) - app.Handle("GET", "/users", us.Index, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth()) + app.Handle("POST", "/users/invite", us.Invite, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("GET", "/users/invite", us.Invite, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("POST", "/users/create", us.Create, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("GET", "/users/create", us.Create, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("GET", "/users", us.Index, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasAuth()) // Register user management and authentication endpoints. u := User{ - MasterDB: masterDB, - Renderer: renderer, - Authenticator: authenticator, - ProjectRoutes: projectRoutes, - NotifyEmail: notifyEmail, - SecretKey: secretKey, + UserRepo: appCtx.UserRepo, + UserAccountRepo: appCtx.UserAccountRepo, + AccountRepo: appCtx.AccountRepo, + AuthRepo: appCtx.AuthRepo, + MasterDB: appCtx.MasterDB, + Renderer: appCtx.Renderer, } app.Handle("POST", "/user/login", u.Login) app.Handle("GET", "/user/login", u.Login) @@ -114,35 +149,39 @@ func APP(shutdown chan os.Signal, log *log.Logger, env webcontext.Env, staticDir app.Handle("GET", "/user/reset-password/:hash", u.ResetConfirm) app.Handle("POST", "/user/reset-password", u.ResetPassword) app.Handle("GET", "/user/reset-password", u.ResetPassword) - app.Handle("POST", "/user/update", u.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth()) - app.Handle("GET", "/user/update", u.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth()) - app.Handle("GET", "/user/account", u.Account, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth()) - app.Handle("GET", "/user/virtual-login/:user_id", u.VirtualLogin, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin)) - app.Handle("POST", "/user/virtual-login", u.VirtualLogin, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin)) - app.Handle("GET", "/user/virtual-login", u.VirtualLogin, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin)) - app.Handle("GET", "/user/virtual-logout", u.VirtualLogout, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth()) - app.Handle("GET", "/user/switch-account/:account_id", u.SwitchAccount, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth()) - app.Handle("POST", "/user/switch-account", u.SwitchAccount, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth()) - app.Handle("GET", "/user/switch-account", u.SwitchAccount, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth()) - app.Handle("POST", "/user", u.View, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth()) - app.Handle("GET", "/user", u.View, mid.AuthenticateSessionRequired(authenticator), mid.HasAuth()) + app.Handle("POST", "/user/update", u.Update, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasAuth()) + app.Handle("GET", "/user/update", u.Update, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasAuth()) + app.Handle("GET", "/user/account", u.Account, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasAuth()) + app.Handle("GET", "/user/virtual-login/:user_id", u.VirtualLogin, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("POST", "/user/virtual-login", u.VirtualLogin, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("GET", "/user/virtual-login", u.VirtualLogin, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("GET", "/user/virtual-logout", u.VirtualLogout, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasAuth()) + app.Handle("GET", "/user/switch-account/:account_id", u.SwitchAccount, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasAuth()) + app.Handle("POST", "/user/switch-account", u.SwitchAccount, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasAuth()) + app.Handle("GET", "/user/switch-account", u.SwitchAccount, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasAuth()) + app.Handle("POST", "/user", u.View, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasAuth()) + app.Handle("GET", "/user", u.View, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasAuth()) // Register account management endpoints. acc := Account{ - MasterDB: masterDB, - Renderer: renderer, - Authenticator: authenticator, + AccountRepo: appCtx.AccountRepo, + AccountPrefRepo: appCtx.AccountPrefRepo, + AuthRepo: appCtx.AuthRepo, + Authenticator: appCtx.Authenticator, + MasterDB: appCtx.MasterDB, + Renderer: appCtx.Renderer, } - app.Handle("POST", "/account/update", acc.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin)) - app.Handle("GET", "/account/update", acc.Update, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin)) - app.Handle("POST", "/account", acc.View, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin)) - app.Handle("GET", "/account", acc.View, mid.AuthenticateSessionRequired(authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("POST", "/account/update", acc.Update, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("GET", "/account/update", acc.Update, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("POST", "/account", acc.View, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasRole(auth.RoleAdmin)) + app.Handle("GET", "/account", acc.View, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasRole(auth.RoleAdmin)) - // Register user management and authentication endpoints. + // Register signup endpoints. s := Signup{ - MasterDB: masterDB, - Renderer: renderer, - Authenticator: authenticator, + SignupRepo: appCtx.SignupRepo, + AuthRepo: appCtx.AuthRepo, + MasterDB: appCtx.MasterDB, + Renderer: appCtx.Renderer, } // This route is not authenticated app.Handle("POST", "/signup", s.Step1) @@ -150,16 +189,16 @@ func APP(shutdown chan os.Signal, log *log.Logger, env webcontext.Env, staticDir // Register example endpoints. ex := Examples{ - Renderer: renderer, + Renderer: appCtx.Renderer, } - app.Handle("POST", "/examples/flash-messages", ex.FlashMessages, mid.AuthenticateSessionOptional(authenticator)) - app.Handle("GET", "/examples/flash-messages", ex.FlashMessages, mid.AuthenticateSessionOptional(authenticator)) - app.Handle("GET", "/examples/images", ex.Images, mid.AuthenticateSessionOptional(authenticator)) + app.Handle("POST", "/examples/flash-messages", ex.FlashMessages, mid.AuthenticateSessionOptional(appCtx.Authenticator)) + app.Handle("GET", "/examples/flash-messages", ex.FlashMessages, mid.AuthenticateSessionOptional(appCtx.Authenticator)) + app.Handle("GET", "/examples/images", ex.Images, mid.AuthenticateSessionOptional(appCtx.Authenticator)) // Register geo g := Geo{ - MasterDB: masterDB, - Redis: redis, + MasterDB: appCtx.MasterDB, + Redis: appCtx.Redis, } app.Handle("GET", "/geo/regions/autocomplete", g.RegionsAutocomplete) app.Handle("GET", "/geo/postal_codes/autocomplete", g.PostalCodesAutocomplete) @@ -168,17 +207,16 @@ func APP(shutdown chan os.Signal, log *log.Logger, env webcontext.Env, staticDir // Register root r := Root{ - MasterDB: masterDB, - Renderer: renderer, - ProjectRoutes: projectRoutes, - Sitemap: sm, + Renderer: appCtx.Renderer, + ProjectRoute: appCtx.ProjectRoute, + Sitemap: sm, } app.Handle("GET", "/api", r.SitePage) app.Handle("GET", "/pricing", r.SitePage) app.Handle("GET", "/support", r.SitePage) app.Handle("GET", "/legal/privacy", r.SitePage) app.Handle("GET", "/legal/terms", r.SitePage) - app.Handle("GET", "/", r.Index, mid.AuthenticateSessionOptional(authenticator)) + app.Handle("GET", "/", r.Index, mid.AuthenticateSessionOptional(appCtx.Authenticator)) app.Handle("GET", "/index.html", r.IndexHtml) app.Handle("GET", "/robots.txt", r.RobotTxt) app.Handle("GET", "/sitemap.xml", r.SitemapXml) @@ -193,14 +231,14 @@ func APP(shutdown chan os.Signal, log *log.Logger, env webcontext.Env, staticDir // Register health check endpoint. This route is not authenticated. check := Check{ - MasterDB: masterDB, - Redis: redis, + MasterDB: appCtx.MasterDB, + Redis: appCtx.Redis, } app.Handle("GET", "/v1/health", check.Health) // Handle static files/pages. Render a custom 404 page when file not found. static := func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { - err := web.StaticHandler(ctx, w, r, params, staticDir, "") + err := web.StaticHandler(ctx, w, r, params, appCtx.StaticDir, "") if err != nil { if os.IsNotExist(err) { rmsg := fmt.Sprintf("%s %s not found", r.Method, r.RequestURI) @@ -209,7 +247,7 @@ func APP(shutdown chan os.Signal, log *log.Logger, env webcontext.Env, staticDir err = weberror.NewError(ctx, err, http.StatusInternalServerError) } - return web.RenderError(ctx, w, r, err, renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8) + return web.RenderError(ctx, w, r, err, appCtx.Renderer, TmplLayoutBase, TmplContentErrorGeneric, web.MIMETextHTMLCharsetUTF8) } return nil diff --git a/cmd/web-app/handlers/signup.go b/cmd/web-app/handlers/signup.go index f24a175..cb23c48 100644 --- a/cmd/web-app/handlers/signup.go +++ b/cmd/web-app/handlers/signup.go @@ -20,9 +20,10 @@ import ( // Signup represents the Signup API method handler set. type Signup struct { - MasterDB *sqlx.DB - Renderer web.Renderer - Authenticator *auth.Authenticator + SignupRepo *signup.Repository + AuthRepo *user_auth.Repository + MasterDB *sqlx.DB + Renderer web.Renderer } // Step1 handles collecting the first detailed needed to create a new account. @@ -52,7 +53,7 @@ func (h *Signup) Step1(ctx context.Context, w http.ResponseWriter, r *http.Reque } // Execute the account / user signup. - _, err = signup.Signup(ctx, claims, h.MasterDB, *req, ctxValues.Now) + _, err = h.SignupRepo.Signup(ctx, claims, *req, ctxValues.Now) if err != nil { switch errors.Cause(err) { case account.ErrForbidden: @@ -68,7 +69,7 @@ func (h *Signup) Step1(ctx context.Context, w http.ResponseWriter, r *http.Reque } // Authenticated the new user. - token, err := user_auth.Authenticate(ctx, h.MasterDB, h.Authenticator, user_auth.AuthenticateRequest{ + token, err := h.AuthRepo.Authenticate(ctx, user_auth.AuthenticateRequest{ Email: req.User.Email, Password: req.User.Password, }, time.Hour, ctxValues.Now) @@ -77,7 +78,7 @@ func (h *Signup) Step1(ctx context.Context, w http.ResponseWriter, r *http.Reque } // Add the token to the users session. - err = handleSessionToken(ctx, h.MasterDB, w, r, token) + err = handleSessionToken(ctx, w, r, token) if err != nil { return false, err } diff --git a/cmd/web-app/handlers/user.go b/cmd/web-app/handlers/user.go index c91a3d3..b47b4c9 100644 --- a/cmd/web-app/handlers/user.go +++ b/cmd/web-app/handlers/user.go @@ -11,11 +11,9 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/account" "geeks-accelerator/oss/saas-starter-kit/internal/geonames" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" - "geeks-accelerator/oss/saas-starter-kit/internal/platform/notify" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" - project_routes "geeks-accelerator/oss/saas-starter-kit/internal/project-routes" "geeks-accelerator/oss/saas-starter-kit/internal/user" "geeks-accelerator/oss/saas-starter-kit/internal/user_account" "geeks-accelerator/oss/saas-starter-kit/internal/user_auth" @@ -27,12 +25,12 @@ import ( // User represents the User API method handler set. type User struct { - MasterDB *sqlx.DB - Renderer web.Renderer - Authenticator *auth.Authenticator - ProjectRoutes project_routes.ProjectRoutes - NotifyEmail notify.Email - SecretKey string + UserRepo *user.Repository + AuthRepo *user_auth.Repository + UserAccountRepo *user_account.Repository + AccountRepo *account.Repository + MasterDB *sqlx.DB + Renderer web.Renderer } func urlUserVirtualLogin(userID string) string { @@ -75,7 +73,7 @@ func (h *User) Login(ctx context.Context, w http.ResponseWriter, r *http.Request } // Authenticated the user. - token, err := user_auth.Authenticate(ctx, h.MasterDB, h.Authenticator, user_auth.AuthenticateRequest{ + token, err := h.AuthRepo.Authenticate(ctx, user_auth.AuthenticateRequest{ Email: req.Email, Password: req.Password, }, sessionTTL, ctxValues.Now) @@ -97,7 +95,7 @@ func (h *User) Login(ctx context.Context, w http.ResponseWriter, r *http.Request } // Add the token to the users session. - err = handleSessionToken(ctx, h.MasterDB, w, r, token) + err = handleSessionToken(ctx, w, r, token) if err != nil { return false, err } @@ -173,7 +171,7 @@ func (h *User) ResetPassword(ctx context.Context, w http.ResponseWriter, r *http return err } - _, err = user.ResetPassword(ctx, h.MasterDB, h.ProjectRoutes.UserResetPassword, h.NotifyEmail, *req, h.SecretKey, ctxValues.Now) + _, err = h.UserRepo.ResetPassword(ctx, *req, ctxValues.Now) if err != nil { switch errors.Cause(err) { default: @@ -238,7 +236,7 @@ func (h *User) ResetConfirm(ctx context.Context, w http.ResponseWriter, r *http. // Append the query param value to the request. req.ResetHash = resetHash - u, err := user.ResetConfirm(ctx, h.MasterDB, *req, h.SecretKey, ctxValues.Now) + u, err := h.UserRepo.ResetConfirm(ctx, *req, ctxValues.Now) if err != nil { switch errors.Cause(err) { case user.ErrResetExpired: @@ -257,7 +255,7 @@ func (h *User) ResetConfirm(ctx context.Context, w http.ResponseWriter, r *http. } // Authenticated the user. Probably should use the default session TTL from UserLogin. - token, err := user_auth.Authenticate(ctx, h.MasterDB, h.Authenticator, user_auth.AuthenticateRequest{ + token, err := h.AuthRepo.Authenticate(ctx, user_auth.AuthenticateRequest{ Email: u.Email, Password: req.Password, }, time.Hour, ctxValues.Now) @@ -271,7 +269,7 @@ func (h *User) ResetConfirm(ctx context.Context, w http.ResponseWriter, r *http. } // Add the token to the users session. - err = handleSessionToken(ctx, h.MasterDB, w, r, token) + err = handleSessionToken(ctx, w, r, token) if err != nil { return false, err } @@ -280,7 +278,7 @@ func (h *User) ResetConfirm(ctx context.Context, w http.ResponseWriter, r *http. return true, web.Redirect(ctx, w, r, "/", http.StatusFound) } - _, err = user.ParseResetHash(ctx, h.SecretKey, resetHash, ctxValues.Now) + _, err = h.UserRepo.ParseResetHash(ctx, resetHash, ctxValues.Now) if err != nil { switch errors.Cause(err) { case user.ErrResetExpired: @@ -328,14 +326,14 @@ func (h *User) View(ctx context.Context, w http.ResponseWriter, r *http.Request, return err } - usr, err := user.ReadByID(ctx, claims, h.MasterDB, claims.Subject) + usr, err := h.UserRepo.ReadByID(ctx, claims, claims.Subject) if err != nil { return err } data["user"] = usr.Response(ctx) - usrAccs, err := user_account.FindByUserID(ctx, claims, h.MasterDB, claims.Subject, false) + usrAccs, err := h.UserAccountRepo.FindByUserID(ctx, claims, claims.Subject, false) if err != nil { return err } @@ -388,7 +386,7 @@ func (h *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Reques } req.ID = claims.Subject - err = user.Update(ctx, claims, h.MasterDB, *req, ctxValues.Now) + err = h.UserRepo.Update(ctx, claims, *req, ctxValues.Now) if err != nil { switch errors.Cause(err) { default: @@ -409,7 +407,7 @@ func (h *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Reques } pwdReq.ID = claims.Subject - err = user.UpdatePassword(ctx, claims, h.MasterDB, *pwdReq, ctxValues.Now) + err = h.UserRepo.UpdatePassword(ctx, claims, *pwdReq, ctxValues.Now) if err != nil { switch errors.Cause(err) { default: @@ -441,7 +439,7 @@ func (h *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Reques return nil } - usr, err := user.ReadByID(ctx, claims, h.MasterDB, claims.Subject) + usr, err := h.UserRepo.ReadByID(ctx, claims, claims.Subject) if err != nil { return err } @@ -484,7 +482,7 @@ func (h *User) Account(ctx context.Context, w http.ResponseWriter, r *http.Reque return err } - acc, err := account.ReadByID(ctx, claims, h.MasterDB, claims.Audience) + acc, err := h.AccountRepo.ReadByID(ctx, claims, claims.Audience) if err != nil { return err } @@ -551,7 +549,7 @@ func (h *User) VirtualLogin(ctx context.Context, w http.ResponseWriter, r *http. } // Perform the account switch. - tkn, err := user_auth.VirtualLogin(ctx, h.MasterDB, h.Authenticator, claims, *req, expires, ctxValues.Now) + tkn, err := h.AuthRepo.VirtualLogin(ctx, claims, *req, expires, ctxValues.Now) if err != nil { if verr, ok := weberror.NewValidationError(ctx, err); ok { data["validationErrors"] = verr.(*weberror.Error) @@ -565,7 +563,7 @@ func (h *User) VirtualLogin(ctx context.Context, w http.ResponseWriter, r *http. sess = webcontext.SessionUpdateAccessToken(sess, tkn.AccessToken) // Read the account for a flash message. - usr, err := user.ReadByID(ctx, claims, h.MasterDB, tkn.UserID) + usr, err := h.UserRepo.ReadByID(ctx, claims, tkn.UserID) if err != nil { return false, err } @@ -588,7 +586,7 @@ func (h *User) VirtualLogin(ctx context.Context, w http.ResponseWriter, r *http. return nil } - usrAccs, err := user_account.Find(ctx, claims, h.MasterDB, user_account.UserAccountFindRequest{ + usrAccs, err := h.UserAccountRepo.Find(ctx, claims, user_account.UserAccountFindRequest{ Where: "account_id = ?", Args: []interface{}{claims.Audience}, }) @@ -612,7 +610,7 @@ func (h *User) VirtualLogin(ctx context.Context, w http.ResponseWriter, r *http. userPhs = append(userPhs, "?") } - users, err := user.Find(ctx, claims, h.MasterDB, user.UserFindRequest{ + users, err := h.UserRepo.Find(ctx, claims, user.UserFindRequest{ Where: fmt.Sprintf("id IN (%s)", strings.Join(userPhs, ", ")), Args: userIDs, @@ -657,7 +655,7 @@ func (h *User) VirtualLogout(ctx context.Context, w http.ResponseWriter, r *http expires = time.Hour } - tkn, err := user_auth.VirtualLogout(ctx, h.MasterDB, h.Authenticator, claims, expires, ctxValues.Now) + tkn, err := h.AuthRepo.VirtualLogout(ctx, claims, expires, ctxValues.Now) if err != nil { return err } @@ -667,11 +665,11 @@ func (h *User) VirtualLogout(ctx context.Context, w http.ResponseWriter, r *http // Display a success message to verify the user has switched contexts. if claims.Subject != tkn.UserID && claims.Audience != tkn.AccountID { - usr, err := user.ReadByID(ctx, claims, h.MasterDB, tkn.UserID) + usr, err := h.UserRepo.ReadByID(ctx, claims, tkn.UserID) if err != nil { return err } - acc, err := account.ReadByID(ctx, claims, h.MasterDB, tkn.AccountID) + acc, err := h.AccountRepo.ReadByID(ctx, claims, tkn.AccountID) if err != nil { return err } @@ -680,7 +678,7 @@ func (h *User) VirtualLogout(ctx context.Context, w http.ResponseWriter, r *http fmt.Sprintf("You are now virtually logged back into account %s user %s.", acc.Response(ctx).Name, usr.Response(ctx).Name)) } else if claims.Audience != tkn.AccountID { - acc, err := account.ReadByID(ctx, claims, h.MasterDB, tkn.AccountID) + acc, err := h.AccountRepo.ReadByID(ctx, claims, tkn.AccountID) if err != nil { return err } @@ -689,7 +687,7 @@ func (h *User) VirtualLogout(ctx context.Context, w http.ResponseWriter, r *http fmt.Sprintf("You are now virtually logged back into account %s.", acc.Response(ctx).Name)) } else { - usr, err := user.ReadByID(ctx, claims, h.MasterDB, tkn.UserID) + usr, err := h.UserRepo.ReadByID(ctx, claims, tkn.UserID) if err != nil { return err } @@ -757,7 +755,7 @@ func (h *User) SwitchAccount(ctx context.Context, w http.ResponseWriter, r *http } // Perform the account switch. - tkn, err := user_auth.SwitchAccount(ctx, h.MasterDB, h.Authenticator, claims, *req, expires, ctxValues.Now) + tkn, err := h.AuthRepo.SwitchAccount(ctx, claims, *req, expires, ctxValues.Now) if err != nil { if verr, ok := weberror.NewValidationError(ctx, err); ok { data["validationErrors"] = verr.(*weberror.Error) @@ -771,7 +769,7 @@ func (h *User) SwitchAccount(ctx context.Context, w http.ResponseWriter, r *http sess = webcontext.SessionUpdateAccessToken(sess, tkn.AccessToken) // Read the account for a flash message. - acc, err := account.ReadByID(ctx, claims, h.MasterDB, tkn.AccountID) + acc, err := h.AccountRepo.ReadByID(ctx, claims, tkn.AccountID) if err != nil { return false, err } @@ -794,7 +792,7 @@ func (h *User) SwitchAccount(ctx context.Context, w http.ResponseWriter, r *http return nil } - accounts, err := account.Find(ctx, claims, h.MasterDB, account.AccountFindRequest{ + accounts, err := h.AccountRepo.Find(ctx, claims, account.AccountFindRequest{ Order: []string{"name"}, }) if err != nil { @@ -816,7 +814,7 @@ func (h *User) SwitchAccount(ctx context.Context, w http.ResponseWriter, r *http } // handleSessionToken persists the access token to the session for request authentication. -func handleSessionToken(ctx context.Context, db *sqlx.DB, w http.ResponseWriter, r *http.Request, token user_auth.Token) error { +func handleSessionToken(ctx context.Context, w http.ResponseWriter, r *http.Request, token user_auth.Token) error { if token.AccessToken == "" { return errors.New("accessToken is required.") } diff --git a/cmd/web-app/handlers/users.go b/cmd/web-app/handlers/users.go index e0b9526..86e5f8b 100644 --- a/cmd/web-app/handlers/users.go +++ b/cmd/web-app/handlers/users.go @@ -3,14 +3,16 @@ package handlers import ( "context" "fmt" + "net/http" + "strings" + "time" + "geeks-accelerator/oss/saas-starter-kit/internal/geonames" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/datatable" - "geeks-accelerator/oss/saas-starter-kit/internal/platform/notify" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" - project_routes "geeks-accelerator/oss/saas-starter-kit/internal/project-routes" "geeks-accelerator/oss/saas-starter-kit/internal/user" "geeks-accelerator/oss/saas-starter-kit/internal/user_account" "geeks-accelerator/oss/saas-starter-kit/internal/user_account/invite" @@ -20,20 +22,17 @@ import ( "github.com/jmoiron/sqlx" "github.com/pkg/errors" "gopkg.in/DataDog/dd-trace-go.v1/contrib/go-redis/redis" - "net/http" - "strings" - "time" ) // Users represents the Users API method handler set. type Users struct { - MasterDB *sqlx.DB - Redis *redis.Client - Renderer web.Renderer - Authenticator *auth.Authenticator - ProjectRoutes project_routes.ProjectRoutes - NotifyEmail notify.Email - SecretKey string + UserRepo *user.Repository + UserAccountRepo *user_account.Repository + AuthRepo *user_auth.Repository + InviteRepo *invite.Repository + MasterDB *sqlx.DB + Redis *redis.Client + Renderer web.Renderer } func urlUsersIndex() string { @@ -144,7 +143,7 @@ func (h *Users) Index(ctx context.Context, w http.ResponseWriter, r *http.Reques } loadFunc := func(ctx context.Context, sorting string, fields []datatable.DisplayField) (resp [][]datatable.ColumnValue, err error) { - res, err := user_account.UserFindByAccount(ctx, claims, h.MasterDB, user_account.UserFindByAccountRequest{ + res, err := h.UserAccountRepo.UserFindByAccount(ctx, claims, user_account.UserFindByAccountRequest{ AccountID: claims.Audience, Order: strings.Split(sorting, ","), }) @@ -232,7 +231,7 @@ func (h *Users) Create(ctx context.Context, w http.ResponseWriter, r *http.Reque } } - usr, err := user.Create(ctx, claims, h.MasterDB, req.UserCreateRequest, ctxValues.Now) + usr, err := h.UserRepo.Create(ctx, claims, req.UserCreateRequest, ctxValues.Now) if err != nil { switch errors.Cause(err) { default: @@ -246,7 +245,7 @@ func (h *Users) Create(ctx context.Context, w http.ResponseWriter, r *http.Reque } uaStatus := user_account.UserAccountStatus_Active - _, err = user_account.Create(ctx, claims, h.MasterDB, user_account.UserAccountCreateRequest{ + _, err = h.UserAccountRepo.Create(ctx, claims, user_account.UserAccountCreateRequest{ UserID: usr.ID, AccountID: claims.Audience, Roles: req.Roles, @@ -327,7 +326,7 @@ func (h *Users) View(ctx context.Context, w http.ResponseWriter, r *http.Request switch r.PostForm.Get("action") { case "archive": - err = user.Archive(ctx, claims, h.MasterDB, user.UserArchiveRequest{ + err = h.UserRepo.Archive(ctx, claims, user.UserArchiveRequest{ ID: userID, }, ctxValues.Now) if err != nil { @@ -352,14 +351,14 @@ func (h *Users) View(ctx context.Context, w http.ResponseWriter, r *http.Request return nil } - usr, err := user.ReadByID(ctx, claims, h.MasterDB, userID) + usr, err := h.UserRepo.ReadByID(ctx, claims, userID) if err != nil { return err } data["user"] = usr.Response(ctx) - usrAccs, err := user_account.FindByUserID(ctx, claims, h.MasterDB, userID, false) + usrAccs, err := h.UserAccountRepo.FindByUserID(ctx, claims, userID, false) if err != nil { return err } @@ -425,7 +424,7 @@ func (h *Users) Update(ctx context.Context, w http.ResponseWriter, r *http.Reque } } - err = user.Update(ctx, claims, h.MasterDB, req.UserUpdateRequest, ctxValues.Now) + err = h.UserRepo.Update(ctx, claims, req.UserUpdateRequest, ctxValues.Now) if err != nil { switch errors.Cause(err) { default: @@ -439,7 +438,7 @@ func (h *Users) Update(ctx context.Context, w http.ResponseWriter, r *http.Reque } if req.Roles != nil { - err = user_account.Update(ctx, claims, h.MasterDB, user_account.UserAccountUpdateRequest{ + err = h.UserAccountRepo.Update(ctx, claims, user_account.UserAccountUpdateRequest{ UserID: userID, AccountID: claims.Audience, Roles: &req.Roles, @@ -465,7 +464,7 @@ func (h *Users) Update(ctx context.Context, w http.ResponseWriter, r *http.Reque } pwdReq.ID = userID - err = user.UpdatePassword(ctx, claims, h.MasterDB, *pwdReq, ctxValues.Now) + err = h.UserRepo.UpdatePassword(ctx, claims, *pwdReq, ctxValues.Now) if err != nil { switch errors.Cause(err) { default: @@ -497,12 +496,12 @@ func (h *Users) Update(ctx context.Context, w http.ResponseWriter, r *http.Reque return nil } - usr, err := user.ReadByID(ctx, claims, h.MasterDB, userID) + usr, err := h.UserRepo.ReadByID(ctx, claims, userID) if err != nil { return err } - usrAcc, err := user_account.Read(ctx, claims, h.MasterDB, user_account.UserAccountReadRequest{ + usrAcc, err := h.UserAccountRepo.Read(ctx, claims, user_account.UserAccountReadRequest{ UserID: userID, AccountID: claims.Audience, }) @@ -577,7 +576,7 @@ func (h *Users) Invite(ctx context.Context, w http.ResponseWriter, r *http.Reque req.UserID = claims.Subject req.AccountID = claims.Audience - res, err := invite.SendUserInvites(ctx, claims, h.MasterDB, h.ProjectRoutes.UserInviteAccept, h.NotifyEmail, *req, h.SecretKey, ctxValues.Now) + res, err := h.InviteRepo.SendUserInvites(ctx, claims, *req, ctxValues.Now) if err != nil { switch errors.Cause(err) { default: @@ -661,7 +660,7 @@ func (h *Users) InviteAccept(ctx context.Context, w http.ResponseWriter, r *http // Append the query param value to the request. req.InviteHash = inviteHash - hash, err := invite.AcceptInviteUser(ctx, h.MasterDB, *req, h.SecretKey, ctxValues.Now) + hash, err := h.InviteRepo.AcceptInviteUser(ctx, *req, ctxValues.Now) if err != nil { switch errors.Cause(err) { case invite.ErrInviteExpired: @@ -699,13 +698,13 @@ func (h *Users) InviteAccept(ctx context.Context, w http.ResponseWriter, r *http } // Load the user without any claims applied. - usr, err := user.ReadByID(ctx, auth.Claims{}, h.MasterDB, hash.UserID) + usr, err := h.UserRepo.ReadByID(ctx, auth.Claims{}, hash.UserID) if err != nil { return false, err } // Authenticated the user. Probably should use the default session TTL from UserLogin. - token, err := user_auth.Authenticate(ctx, h.MasterDB, h.Authenticator, user_auth.AuthenticateRequest{ + token, err := h.AuthRepo.Authenticate(ctx, user_auth.AuthenticateRequest{ Email: usr.Email, Password: req.Password, AccountID: hash.AccountID, @@ -720,7 +719,7 @@ func (h *Users) InviteAccept(ctx context.Context, w http.ResponseWriter, r *http } // Add the token to the users session. - err = handleSessionToken(ctx, h.MasterDB, w, r, token) + err = handleSessionToken(ctx, w, r, token) if err != nil { return false, err } @@ -729,9 +728,9 @@ func (h *Users) InviteAccept(ctx context.Context, w http.ResponseWriter, r *http return true, web.Redirect(ctx, w, r, "/", http.StatusFound) } - usrAcc, err := invite.AcceptInvite(ctx, h.MasterDB, invite.AcceptInviteRequest{ + usrAcc, err := h.InviteRepo.AcceptInvite(ctx, invite.AcceptInviteRequest{ InviteHash: inviteHash, - }, h.SecretKey, ctxValues.Now) + }, ctxValues.Now) if err != nil { switch errors.Cause(err) { @@ -776,7 +775,7 @@ func (h *Users) InviteAccept(ctx context.Context, w http.ResponseWriter, r *http } // Read user by ID with no claims. - usr, err := user.ReadByID(ctx, auth.Claims{}, h.MasterDB, usrAcc.UserID) + usr, err := h.UserRepo.ReadByID(ctx, auth.Claims{}, usrAcc.UserID) if err != nil { return false, err } diff --git a/cmd/web-app/main.go b/cmd/web-app/main.go index a89843f..659d576 100644 --- a/cmd/web-app/main.go +++ b/cmd/web-app/main.go @@ -6,7 +6,12 @@ import ( "encoding/json" "expvar" "fmt" - "geeks-accelerator/oss/saas-starter-kit/internal/project_route" + "geeks-accelerator/oss/saas-starter-kit/internal/account/account_preference" + "geeks-accelerator/oss/saas-starter-kit/internal/project" + "geeks-accelerator/oss/saas-starter-kit/internal/signup" + "geeks-accelerator/oss/saas-starter-kit/internal/user_account" + "geeks-accelerator/oss/saas-starter-kit/internal/user_account/invite" + "geeks-accelerator/oss/saas-starter-kit/internal/user_auth" "html/template" "log" "net" @@ -33,7 +38,7 @@ import ( template_renderer "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/template-renderer" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" - project_routes "geeks-accelerator/oss/saas-starter-kit/internal/project-routes" + "geeks-accelerator/oss/saas-starter-kit/internal/project_route" "geeks-accelerator/oss/saas-starter-kit/internal/user" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" @@ -52,7 +57,6 @@ import ( redistrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/go-redis/redis" sqlxtrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/jmoiron/sqlx" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" - "gopkg.in/gomail.v2" ) // build is the git version of this program. It is set using build flags in the makefile. @@ -67,10 +71,9 @@ func main() { // ========================================================================= // Logging - log.SetFlags(log.LstdFlags|log.Lmicroseconds|log.Lshortfile) - log.SetPrefix(service+" : ") - log := log.New(os.Stdout, log.Prefix() , log.Flags()) - + log.SetFlags(log.LstdFlags | log.Lmicroseconds | log.Lshortfile) + log.SetPrefix(service + " : ") + log := log.New(os.Stdout, log.Prefix(), log.Flags()) // ========================================================================= // Configuration @@ -88,12 +91,12 @@ func main() { DisableHTTP2 bool `default:"false" envconfig:"DISABLE_HTTP2"` } Service struct { - Name string `default:"web-app" envconfig:"NAME"` - BaseUrl string `default:"" envconfig:"BASE_URL" example:"http://example.saasstartupkit.com"` - HostNames []string `envconfig:"HOST_NAMES" example:"www.example.saasstartupkit.com"` - EnableHTTPS bool `default:"false" envconfig:"ENABLE_HTTPS"` - TemplateDir string `default:"./templates" envconfig:"TEMPLATE_DIR"` - StaticFiles struct { + Name string `default:"web-app" envconfig:"SERVICE_NAME"` + BaseUrl string `default:"" envconfig:"BASE_URL" example:"http://example.saasstartupkit.com"` + HostNames []string `envconfig:"HOST_NAMES" example:"www.example.saasstartupkit.com"` + EnableHTTPS bool `default:"false" envconfig:"ENABLE_HTTPS"` + TemplateDir string `default:"./templates" envconfig:"TEMPLATE_DIR"` + StaticFiles struct { Dir string `default:"./static" envconfig:"STATIC_DIR"` S3Enabled bool `envconfig:"S3_ENABLED"` S3Prefix string `default:"public/web_app/static" envconfig:"S3_PREFIX"` @@ -105,11 +108,11 @@ func main() { ShutdownTimeout time.Duration `default:"5s" envconfig:"SHUTDOWN_TIMEOUT"` } Project struct { - Name string `default:"" envconfig:"PROJECT"` - SharedTemplateDir string `default:"../../resources/templates/shared" envconfig:"SHARED_TEMPLATE_DIR"` - SharedSecretKey string `default:"" envconfig:"SHARED_SECRET_KEY"` - EmailSender string `default:"test@example.saasstartupkit.com" envconfig:"EMAIL_SENDER"` - WebApiBaseUrl string `default:"http://127.0.0.1:3001" envconfig:"WEB_API_BASE_URL" example:"http://api.example.saasstartupkit.com"` + Name string `default:"" envconfig:"PROJECT_NAME"` + SharedTemplateDir string `default:"../../resources/templates/shared" envconfig:"SHARED_TEMPLATE_DIR"` + SharedSecretKey string `default:"" envconfig:"SHARED_SECRET_KEY"` + EmailSender string `default:"test@example.saasstartupkit.com" envconfig:"EMAIL_SENDER"` + WebApiBaseUrl string `default:"http://127.0.0.1:3001" envconfig:"WEB_API_BASE_URL" example:"http://api.example.saasstartupkit.com"` } Redis struct { Host string `default:":6379" envconfig:"HOST"` @@ -202,7 +205,7 @@ func main() { if cfg.Project.Name != "" { pts = append(pts, cfg.Project.Name) } - pts = append(pts, cfg.Env, cfg.Service.Name) + pts = append(pts, cfg.Env) cfg.Aws.SecretsManagerConfigPrefix = filepath.Join(pts...) } @@ -298,7 +301,7 @@ func main() { // AWS secrets manager ID for storing the session key. This is optional and only will be used // if a valid AWS session is provided. - secretID := filepath.Join(cfg.Aws.SecretsManagerConfigPrefix, "sharedSecretKey") + secretID := filepath.Join(cfg.Aws.SecretsManagerConfigPrefix, "SharedSecretKey") // If AWS is enabled, check the Secrets Manager for the session key. if awsSession != nil { @@ -310,10 +313,10 @@ func main() { // If the session key is still empty, generate a new key. if cfg.Project.SharedSecretKey == "" { - cfg.Project.SharedSecretKey = string(securecookie.GenerateRandomKey(32)) + cfg.Project.SharedSecretKey = string(securecookie.GenerateRandomKey(32)) if awsSession != nil { - err = devops.SecretManagerPutString(awsSession, secretID, cfg.Service.SecretKey) + err = devops.SecretManagerPutString(awsSession, secretID, cfg.Project.SharedSecretKey) if err != nil { log.Fatalf("main : Session : %+v", err) } @@ -396,7 +399,7 @@ func main() { var notifyEmail notify.Email if awsSession != nil { // Send emails with AWS SES. Alternative to use SMTP with notify.NewEmailSmtp. - notifyEmail, err = notify.NewEmailAws(awsSession, cfg.Service.SharedTemplateDir, cfg.Service.EmailSender) + notifyEmail, err = notify.NewEmailAws(awsSession, cfg.Project.SharedTemplateDir, cfg.Project.EmailSender) if err != nil { log.Fatalf("main : Notify Email : %+v", err) } @@ -430,12 +433,44 @@ func main() { } // ========================================================================= - // Load middlewares that need to be configured specific for the service. + // Init repositories and AppContext - var serviceMiddlewares = []web.Middleware{ - mid.Translator(webcontext.UniversalTranslator()), + projectRoute, err := project_route.New(cfg.Project.WebApiBaseUrl, cfg.Service.BaseUrl) + if err != nil { + log.Fatalf("main : project routes : %+v", cfg.Service.BaseUrl, err) } + usrRepo := user.NewRepository(masterDb, projectRoute.UserResetPassword, notifyEmail, cfg.Project.SharedSecretKey) + usrAccRepo := user_account.NewRepository(masterDb) + accRepo := account.NewRepository(masterDb) + accPrefRepo := account_preference.NewRepository(masterDb) + authRepo := user_auth.NewRepository(masterDb, authenticator, usrRepo, usrAccRepo, accPrefRepo) + signupRepo := signup.NewRepository(masterDb, usrRepo, usrAccRepo, accRepo) + inviteRepo := invite.NewRepository(masterDb, usrRepo, usrAccRepo, accRepo, projectRoute.UserInviteAccept, notifyEmail, cfg.Project.SharedSecretKey) + prjRepo := project.NewRepository(masterDb) + + appCtx := &handlers.AppContext{ + Log: log, + Env: cfg.Env, + MasterDB: masterDb, + Redis: redisClient, + TemplateDir: cfg.Service.TemplateDir, + StaticDir: cfg.Service.StaticFiles.Dir, + ProjectRoute: projectRoute, + UserRepo: usrRepo, + UserAccountRepo: usrAccRepo, + AccountRepo: accRepo, + AccountPrefRepo: accPrefRepo, + AuthRepo: authRepo, + SignupRepo: signupRepo, + InviteRepo: inviteRepo, + ProjectRepo: prjRepo, + Authenticator: authenticator, + } + + // ========================================================================= + // Load middlewares that need to be configured specific for the service. + // Init redirect middleware to ensure all requests go to the primary domain contained in the base URL. if primaryServiceHost != "127.0.0.1" && primaryServiceHost != "localhost" { redirect := mid.DomainNameRedirect(mid.DomainNameRedirectConfig{ @@ -451,24 +486,23 @@ func main() { DomainName: primaryServiceHost, HTTPSEnabled: cfg.Service.EnableHTTPS, }) - serviceMiddlewares = append(serviceMiddlewares, redirect) + appCtx.PostAppMiddleware = append(appCtx.PostAppMiddleware, redirect) } + // Add the translator middleware for localization. + appCtx.PostAppMiddleware = append(appCtx.PostAppMiddleware, mid.Translator(webcontext.UniversalTranslator())) + // Generate the new session store and append it to the global list of middlewares. // Init session store if cfg.Service.SessionName == "" { cfg.Service.SessionName = fmt.Sprintf("%s-session", cfg.Service.Name) } - sessionStore := sessions.NewCookieStore([]byte(cfg.Service.SecretKey)) - serviceMiddlewares = append(serviceMiddlewares, mid.Session(sessionStore, cfg.Service.SessionName)) + sessionStore := sessions.NewCookieStore([]byte(cfg.Project.SharedSecretKey)) + appCtx.PostAppMiddleware = append(appCtx.PostAppMiddleware, mid.Session(sessionStore, cfg.Service.SessionName)) // ========================================================================= // URL Formatter - projectRoutes, err := project_route.New(cfg.Service.WebApiBaseUrl, cfg.Service.BaseUrl) - if err != nil { - log.Fatalf("main : project routes : %+v", cfg.Service.BaseUrl, err) - } // s3UrlFormatter is a help function used by to convert an s3 key to // a publicly available image URL. @@ -488,7 +522,7 @@ func main() { return s3UrlFormatter(p) } } else { - staticS3UrlFormatter = projectRoutes.WebAppUrl + staticS3UrlFormatter = projectRoute.WebAppUrl } // staticUrlFormatter is a help function used by template functions defined below. @@ -691,7 +725,7 @@ func main() { return nil } - usr, err := user.ReadByID(ctx, auth.Claims{}, masterDb, claims.Subject) + usr, err := usrRepo.ReadByID(ctx, auth.Claims{}, claims.Subject) if err != nil { return nil } @@ -726,7 +760,7 @@ func main() { return nil } - acc, err := account.ReadByID(ctx, auth.Claims{}, masterDb, claims.Audience) + acc, err := accRepo.ReadByID(ctx, auth.Claims{}, claims.Audience) if err != nil { return nil } @@ -867,7 +901,7 @@ func main() { enableHotReload := cfg.Env == "dev" // Template Renderer used to generate HTML response for web experience. - renderer, err := template_renderer.NewTemplateRenderer(cfg.Service.TemplateDir, enableHotReload, gvd, t, eh) + appCtx.Renderer, err = template_renderer.NewTemplateRenderer(cfg.Service.TemplateDir, enableHotReload, gvd, t, eh) if err != nil { log.Fatalf("main : Marshalling Config to JSON : %+v", err) } @@ -919,7 +953,7 @@ func main() { if cfg.HTTP.Host != "" { api := http.Server{ Addr: cfg.HTTP.Host, - Handler: handlers.APP(shutdown, log, cfg.Env, cfg.Service.StaticFiles.Dir, cfg.Service.TemplateDir, masterDb, redisClient, authenticator, projectRoutes, cfg.Service.SecretKey, notifyEmail, renderer, serviceMiddlewares...), + Handler: handlers.APP(shutdown, appCtx), ReadTimeout: cfg.HTTP.ReadTimeout, WriteTimeout: cfg.HTTP.WriteTimeout, MaxHeaderBytes: 1 << 20, @@ -936,7 +970,7 @@ func main() { if cfg.HTTPS.Host != "" { api := http.Server{ Addr: cfg.HTTPS.Host, - Handler: handlers.APP(shutdown, log, cfg.Env, cfg.Service.StaticFiles.Dir, cfg.Service.TemplateDir, masterDb, redisClient, authenticator, projectRoutes, cfg.Service.SecretKey, notifyEmail, renderer, serviceMiddlewares...), + Handler: handlers.APP(shutdown, appCtx), ReadTimeout: cfg.HTTPS.ReadTimeout, WriteTimeout: cfg.HTTPS.WriteTimeout, MaxHeaderBytes: 1 << 20, diff --git a/internal/platform/notify/email_smtp.go b/internal/platform/notify/email_smtp.go index ea56c4a..9be78c8 100644 --- a/internal/platform/notify/email_smtp.go +++ b/internal/platform/notify/email_smtp.go @@ -20,7 +20,7 @@ package notify Username: cfg.SMTP.User, Password: cfg.SMTP.Pass} notifyEmail, err = notify.NewEmailSmtp(d, cfg.Service.SharedTemplateDir, cfg.Service.EmailSender) - */ +*/ import ( "context" diff --git a/internal/user/models.go b/internal/user/models.go index 7860dd3..b0c3cb7 100644 --- a/internal/user/models.go +++ b/internal/user/models.go @@ -272,3 +272,8 @@ func ParseResetHash(ctx context.Context, secretKey string, str string, now time. return &hash, nil } + +// ParseResetHash extracts the details encrypted in the hash string. +func (repo *Repository) ParseResetHash(ctx context.Context, str string, now time.Time) (*ResetHash, error) { + return ParseResetHash(ctx, repo.secretKey, str, now) +} From 8c28261fee80292826f0b87d7b9a74692c76b5b7 Mon Sep 17 00:00:00 2001 From: huyng Date: Thu, 15 Aug 2019 14:27:05 +0700 Subject: [PATCH 06/13] Update GetGeoNames and Migration functions. --- internal/geonames/geonames.go | 134 ++++++++++++++++++++++++++++++++++ internal/schema/migrations.go | 98 ++++++++++++++++++++----- 2 files changed, 214 insertions(+), 18 deletions(-) diff --git a/internal/geonames/geonames.go b/internal/geonames/geonames.go index 47a4e48..f184425 100644 --- a/internal/geonames/geonames.go +++ b/internal/geonames/geonames.go @@ -8,10 +8,13 @@ import ( "encoding/csv" "fmt" "io" + "net/http" "strconv" "strings" + "time" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" + "github.com/huandu/go-sqlbuilder" "github.com/jmoiron/sqlx" "github.com/pkg/errors" @@ -325,3 +328,134 @@ func loadGeonameCountry(ctx context.Context, rr chan<- interface{}, country stri } } } + +// GetGeonameCountry downloads geoname data for the country. +// Parses data and returns slice of Geoname +func GetGeonameCountry(ctx context.Context, country string) ([]Geoname, error) { + res := make([]Geoname, 0) + var err error + var resp *http.Response + + u := fmt.Sprintf("http://download.geonames.org/export/zip/%s.zip", country) + resp, err = pester.Get(u) + if err != nil { + for i := 0; i < 3; i++ { + resp, err = pester.Get(u) + if err == nil { + break + } + time.Sleep(time.Second * 1) + } + if err != nil { + err = errors.WithMessagef(err, "Failed to read countries from '%s'", u) + return res, err + } + } + defer resp.Body.Close() + + br := bufio.NewReader(resp.Body) + + buff := bytes.NewBuffer([]byte{}) + size, err := io.Copy(buff, br) + if err != nil { + err = errors.WithStack(err) + return res, err + } + + b := bytes.NewReader(buff.Bytes()) + zr, err := zip.NewReader(b, size) + if err != nil { + err = errors.WithStack(err) + return res, err + } + + for _, f := range zr.File { + if f.Name == "readme.txt" { + continue + } + + fh, err := f.Open() + if err != nil { + err = errors.WithStack(err) + return res, err + } + + scanner := bufio.NewScanner(fh) + for scanner.Scan() { + line := scanner.Text() + + if strings.Contains(line, "\"") { + line = strings.Replace(line, "\"", "\\\"", -1) + } + + r := csv.NewReader(strings.NewReader(line)) + r.Comma = '\t' // Use tab-delimited instead of comma <---- here! + r.LazyQuotes = true + r.FieldsPerRecord = -1 + + lines, err := r.ReadAll() + if err != nil { + err = errors.WithStack(err) + continue + } + + for _, row := range lines { + + /* + fmt.Println("CountryCode: row[0]", row[0]) + fmt.Println("PostalCode: row[1]", row[1]) + fmt.Println("PlaceName: row[2]", row[2]) + fmt.Println("StateName: row[3]", row[3]) + fmt.Println("StateCode : row[4]", row[4]) + fmt.Println("CountyName: row[5]", row[5]) + fmt.Println("CountyCode : row[6]", row[6]) + fmt.Println("CommunityName: row[7]", row[7]) + fmt.Println("CommunityCode: row[8]", row[8]) + fmt.Println("Latitude: row[9]", row[9]) + fmt.Println("Longitude: row[10]", row[10]) + fmt.Println("Accuracy: row[11]", row[11]) + */ + + gn := Geoname{ + CountryCode: row[0], + PostalCode: row[1], + PlaceName: row[2], + StateName: row[3], + StateCode: row[4], + CountyName: row[5], + CountyCode: row[6], + CommunityName: row[7], + CommunityCode: row[8], + } + if row[9] != "" { + gn.Latitude, err = decimal.NewFromString(row[9]) + if err != nil { + err = errors.WithStack(err) + } + } + + if row[10] != "" { + gn.Longitude, err = decimal.NewFromString(row[10]) + if err != nil { + err = errors.WithStack(err) + } + } + + if row[11] != "" { + gn.Accuracy, err = strconv.Atoi(row[11]) + if err != nil { + err = errors.WithStack(err) + } + } + + res = append(res, gn) + } + } + + if err := scanner.Err(); err != nil { + err = errors.WithStack(err) + } + } + + return res, err +} diff --git a/internal/schema/migrations.go b/internal/schema/migrations.go index fe6a3c9..4523de1 100644 --- a/internal/schema/migrations.go +++ b/internal/schema/migrations.go @@ -9,6 +9,10 @@ import ( "strings" "geeks-accelerator/oss/saas-starter-kit/internal/geonames" + + "fmt" + "time" + "github.com/geeks-accelerator/sqlxmigrate" "github.com/jmoiron/sqlx" _ "github.com/lib/pq" @@ -240,33 +244,91 @@ func migrationList(ctx context.Context, db *sqlx.DB, log *log.Logger, isUnittest } } - q := "insert into geonames " + - "(country_code,postal_code,place_name,state_name,state_code,county_name,county_code,community_name,community_code,latitude,longitude,accuracy) " + - "values(?,?,?,?,?,?,?,?,?,?,?,?)" - q = db.Rebind(q) - stmt, err := db.Prepare(q) - if err != nil { - return errors.WithMessagef(err, "Failed to prepare sql query '%s'", q) - } - + countries := geonames.ValidGeonameCountries(context.Background()) if isUnittest { - } else { - resChan := make(chan interface{}) - go geonames.LoadGeonames(ctx, resChan) + } - for r := range resChan { - switch v := r.(type) { - case geonames.Geoname: - _, err = stmt.Exec(v.CountryCode, v.PostalCode, v.PlaceName, v.StateName, v.StateCode, v.CountyName, v.CountyCode, v.CommunityName, v.CommunityCode, v.Latitude, v.Longitude, v.Accuracy) + ncol := 12 + fn := func(geoNames []geonames.Geoname) error { + valueStrings := make([]string, 0, len(geoNames)) + valueArgs := make([]interface{}, 0, len(geoNames)*ncol) + for _, geoname := range geoNames { + valueStrings = append(valueStrings, "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)") + + valueArgs = append(valueArgs, geoname.CountryCode) + valueArgs = append(valueArgs, geoname.PostalCode) + valueArgs = append(valueArgs, geoname.PlaceName) + + valueArgs = append(valueArgs, geoname.StateName) + valueArgs = append(valueArgs, geoname.StateCode) + valueArgs = append(valueArgs, geoname.CountyName) + + valueArgs = append(valueArgs, geoname.CountyCode) + valueArgs = append(valueArgs, geoname.CommunityName) + valueArgs = append(valueArgs, geoname.CommunityCode) + + valueArgs = append(valueArgs, geoname.Latitude) + valueArgs = append(valueArgs, geoname.Longitude) + valueArgs = append(valueArgs, geoname.Accuracy) + } + insertStmt := fmt.Sprintf("insert into geonames "+ + "(country_code,postal_code,place_name,state_name,state_code,county_name,county_code,community_name,community_code,latitude,longitude,accuracy) "+ + "VALUES %s", strings.Join(valueStrings, ",")) + insertStmt = db.Rebind(insertStmt) + + stmt, err := db.Prepare(insertStmt) + if err != nil { + return errors.WithMessagef(err, "Failed to prepare sql query '%s'", insertStmt) + } + + _, err = stmt.Exec(valueArgs...) + return err + } + start := time.Now() + for _, country := range countries { + //fmt.Println("LoadGeonames: start country: ", country) + v, err := geonames.GetGeonameCountry(context.Background(), country) + if err != nil { + return errors.WithStack(err) + } + //fmt.Println("Geoname records: ", len(v)) + + batch := 4500 + n := len(v) / batch + + //fmt.Println("Number of batch: ", n) + + if n == 0 { + err := fn(v) + if err != nil { + return errors.WithStack(err) + } + } else { + for i := 0; i < n; i++ { + vn := v[i*batch : (i+1)*batch] + err := fn(vn) + if err != nil { + return errors.WithStack(err) + } + if n > 0 && n%25 == 0 { + time.Sleep(200) + } + } + if len(v)%batch > 0 { + fmt.Println("Remain part: ", len(v)-n*batch) + vn := v[n*batch:] + err := fn(vn) if err != nil { return errors.WithStack(err) } - case error: - return v } } + + //fmt.Println("Insert Geoname took: ", time.Since(start)) + //fmt.Println("LoadGeonames: end country: ", country) } + fmt.Println("Total Geonames population took: ", time.Since(start)) queries := []string{ `create index idx_geonames_country_code on geonames (country_code)`, From 71713280729d79ee4207a3771cc96d574daa30de Mon Sep 17 00:00:00 2001 From: huyng Date: Thu, 15 Aug 2019 14:40:22 +0700 Subject: [PATCH 07/13] Use ctx param from outer function --- internal/schema/migrations.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/schema/migrations.go b/internal/schema/migrations.go index 4523de1..bab4c17 100644 --- a/internal/schema/migrations.go +++ b/internal/schema/migrations.go @@ -244,7 +244,7 @@ func migrationList(ctx context.Context, db *sqlx.DB, log *log.Logger, isUnittest } } - countries := geonames.ValidGeonameCountries(context.Background()) + countries := geonames.ValidGeonameCountries(ctx) if isUnittest { } From c61a934279b03fb77c87a2afdbff9fe7f7e46d79 Mon Sep 17 00:00:00 2001 From: huyng Date: Thu, 15 Aug 2019 14:46:44 +0700 Subject: [PATCH 08/13] Add more comment --- internal/geonames/geonames.go | 2 ++ internal/schema/migrations.go | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/internal/geonames/geonames.go b/internal/geonames/geonames.go index f184425..ae1c9c8 100644 --- a/internal/geonames/geonames.go +++ b/internal/geonames/geonames.go @@ -339,6 +339,8 @@ func GetGeonameCountry(ctx context.Context, country string) ([]Geoname, error) { u := fmt.Sprintf("http://download.geonames.org/export/zip/%s.zip", country) resp, err = pester.Get(u) if err != nil { + // Add re-try three times after failing first time + // This reduces the risk when network is lagy, we still have chance to re-try. for i := 0; i < 3; i++ { resp, err = pester.Get(u) if err == nil { diff --git a/internal/schema/migrations.go b/internal/schema/migrations.go index bab4c17..2ec3bec 100644 --- a/internal/schema/migrations.go +++ b/internal/schema/migrations.go @@ -293,7 +293,7 @@ func migrationList(ctx context.Context, db *sqlx.DB, log *log.Logger, isUnittest return errors.WithStack(err) } //fmt.Println("Geoname records: ", len(v)) - + // Max argument values of Postgres is about 54460. So the batch size for bulk insert is selected 4500*12 (ncol) batch := 4500 n := len(v) / batch From 83118e85ca67aac4d5858f2f276437a616c310e3 Mon Sep 17 00:00:00 2001 From: Lee Brown Date: Fri, 16 Aug 2019 14:52:57 -0800 Subject: [PATCH 09/13] Remove POD architecture from readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c5fb683..fc58cd3 100644 --- a/README.md +++ b/README.md @@ -489,7 +489,7 @@ For more details on this service, read [web-app readme](https://gitlab.com/geeks Schema is a minimalistic database migration helper that can manually be invoked via CLI. It provides schema versioning and migration rollback. -To support POD architecture, the schema for the entire project is defined globally and is located inside internal: +The schema for the entire project is defined globally and is located inside internal: [internal/schema](https://gitlab.com/geeks-accelerator/oss/saas-starter-kit/tree/master/internal/schema) Keeping a global schema helps ensure business logic can be decoupled across multiple packages. It is a firm belief that From d277b0ec254797b66d23e0a291457ac60fa24367 Mon Sep 17 00:00:00 2001 From: huyng Date: Sat, 17 Aug 2019 11:03:48 +0700 Subject: [PATCH 10/13] Use interface in the handlers of web-api/web-app --- cmd/web-api/handlers/account.go | 30 ++++++-- cmd/web-api/handlers/example.go | 5 +- cmd/web-api/handlers/project.go | 28 +++++--- cmd/web-api/handlers/routes.go | 35 ++++------ cmd/web-api/handlers/signup.go | 8 ++- cmd/web-api/handlers/user.go | 70 +++++++++++++------ cmd/web-api/handlers/user_account.go | 25 ++++++- cmd/web-api/main.go | 3 +- cmd/web-api/tests/account_test.go | 1 + cmd/web-app/handlers/account.go | 22 +++--- cmd/web-app/handlers/api_geo.go | 25 +++++-- cmd/web-app/handlers/projects.go | 4 +- cmd/web-app/handlers/routes.go | 50 +++++++------ cmd/web-app/handlers/signup.go | 9 ++- cmd/web-app/handlers/user.go | 40 ++++++----- cmd/web-app/handlers/users.go | 19 ++--- cmd/web-app/main.go | 10 ++- docker-compose.yaml | 4 +- .../account_preference/account_preference.go | 1 + internal/geonames/countries.go | 8 +-- internal/geonames/country_timezones.go | 12 ++-- internal/geonames/geonames.go | 23 +++--- internal/geonames/models.go | 12 ++++ internal/schema/migrations.go | 4 +- internal/user_account/invite/invite.go | 5 +- internal/user_auth/auth.go | 4 +- 26 files changed, 294 insertions(+), 163 deletions(-) diff --git a/cmd/web-api/handlers/account.go b/cmd/web-api/handlers/account.go index 592d962..f81bba0 100644 --- a/cmd/web-api/handlers/account.go +++ b/cmd/web-api/handlers/account.go @@ -4,23 +4,45 @@ import ( "context" "net/http" "strconv" + "time" "geeks-accelerator/oss/saas-starter-kit/internal/account" + accountref "geeks-accelerator/oss/saas-starter-kit/internal/account/account_preference" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" + "github.com/pkg/errors" "gopkg.in/go-playground/validator.v9" ) // Account represents the Account API method handler set. -type Account struct { - *account.Repository +type Accounts struct { + Repository AccountRepository // ADD OTHER STATE LIKE THE LOGGER AND CONFIG HERE. } +type AccountRepository interface { + //CanReadAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, accountID string) error + Find(ctx context.Context, claims auth.Claims, req account.AccountFindRequest) (account.Accounts, error) + Create(ctx context.Context, claims auth.Claims, req account.AccountCreateRequest, now time.Time) (*account.Account, error) + ReadByID(ctx context.Context, claims auth.Claims, id string) (*account.Account, error) + Read(ctx context.Context, claims auth.Claims, req account.AccountReadRequest) (*account.Account, error) + Update(ctx context.Context, claims auth.Claims, req account.AccountUpdateRequest, now time.Time) error + Archive(ctx context.Context, claims auth.Claims, req account.AccountArchiveRequest, now time.Time) error + Delete(ctx context.Context, claims auth.Claims, req account.AccountDeleteRequest) error +} +type AccountPrefRepository interface { + Find(ctx context.Context, claims auth.Claims, req accountref.AccountPreferenceFindRequest) ([]*accountref.AccountPreference, error) + FindByAccountID(ctx context.Context, claims auth.Claims, req accountref.AccountPreferenceFindByAccountIDRequest) ([]*accountref.AccountPreference, error) + Read(ctx context.Context, claims auth.Claims, req accountref.AccountPreferenceReadRequest) (*accountref.AccountPreference, error) + Set(ctx context.Context, claims auth.Claims, req accountref.AccountPreferenceSetRequest, now time.Time) error + Archive(ctx context.Context, claims auth.Claims, req accountref.AccountPreferenceArchiveRequest, now time.Time) error + Delete(ctx context.Context, claims auth.Claims, req accountref.AccountPreferenceDeleteRequest) error +} + // Read godoc // @Summary Get account by ID // @Description Read returns the specified account from the system. @@ -34,7 +56,7 @@ type Account struct { // @Failure 404 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /accounts/{id} [get] -func (h *Account) Read(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *Accounts) Read(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { claims, ok := ctx.Value(auth.Key).(auth.Claims) if !ok { return errors.New("claims missing from context") @@ -81,7 +103,7 @@ func (h *Account) Read(ctx context.Context, w http.ResponseWriter, r *http.Reque // @Failure 403 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /accounts [patch] -func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *Accounts) Update(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { v, err := webcontext.ContextValues(ctx) if err != nil { diff --git a/cmd/web-api/handlers/example.go b/cmd/web-api/handlers/example.go index fe15ddd..b866eff 100644 --- a/cmd/web-api/handlers/example.go +++ b/cmd/web-api/handlers/example.go @@ -7,13 +7,14 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" "geeks-accelerator/oss/saas-starter-kit/internal/project" - "github.com/pkg/errors" "net/http" + + "github.com/pkg/errors" ) // Example represents the Example API method handler set. type Example struct { - Project *project.Repository + Project ProjectRepository // ADD OTHER STATE LIKE THE LOGGER AND CONFIG HERE. } diff --git a/cmd/web-api/handlers/project.go b/cmd/web-api/handlers/project.go index 82c835f..b0cd946 100644 --- a/cmd/web-api/handlers/project.go +++ b/cmd/web-api/handlers/project.go @@ -5,23 +5,35 @@ import ( "net/http" "strconv" "strings" + "time" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" "geeks-accelerator/oss/saas-starter-kit/internal/project" + "github.com/pkg/errors" "gopkg.in/go-playground/validator.v9" ) // Project represents the Project API method handler set. -type Project struct { - *project.Repository +type Projects struct { + Repository ProjectRepository // ADD OTHER STATE LIKE THE LOGGER IF NEEDED. } +type ProjectRepository interface { + ReadByID(ctx context.Context, claims auth.Claims, id string) (*project.Project, error) + Find(ctx context.Context, claims auth.Claims, req project.ProjectFindRequest) (project.Projects, error) + Read(ctx context.Context, claims auth.Claims, req project.ProjectReadRequest) (*project.Project, error) + Create(ctx context.Context, claims auth.Claims, req project.ProjectCreateRequest, now time.Time) (*project.Project, error) + Update(ctx context.Context, claims auth.Claims, req project.ProjectUpdateRequest, now time.Time) error + Archive(ctx context.Context, claims auth.Claims, req project.ProjectArchiveRequest, now time.Time) error + Delete(ctx context.Context, claims auth.Claims, req project.ProjectDeleteRequest) error +} + // Find godoc // TODO: Need to implement unittests on projects/find endpoint. There are none. // @Summary List projects @@ -40,7 +52,7 @@ type Project struct { // @Failure 403 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /projects [get] -func (h *Project) Find(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *Projects) Find(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { claims, ok := ctx.Value(auth.Key).(auth.Claims) if !ok { return errors.New("claims missing from context") @@ -133,7 +145,7 @@ func (h *Project) Find(ctx context.Context, w http.ResponseWriter, r *http.Reque // @Failure 404 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /projects/{id} [get] -func (h *Project) Read(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *Projects) Read(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { claims, ok := ctx.Value(auth.Key).(auth.Claims) if !ok { return errors.New("claims missing from context") @@ -181,7 +193,7 @@ func (h *Project) Read(ctx context.Context, w http.ResponseWriter, r *http.Reque // @Failure 404 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /projects [post] -func (h *Project) Create(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *Projects) Create(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { v, err := webcontext.ContextValues(ctx) if err != nil { return err @@ -231,7 +243,7 @@ func (h *Project) Create(ctx context.Context, w http.ResponseWriter, r *http.Req // @Failure 403 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /projects [patch] -func (h *Project) Update(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *Projects) Update(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { v, err := webcontext.ContextValues(ctx) if err != nil { return err @@ -282,7 +294,7 @@ func (h *Project) Update(ctx context.Context, w http.ResponseWriter, r *http.Req // @Failure 403 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /projects/archive [patch] -func (h *Project) Archive(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *Projects) Archive(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { v, err := webcontext.ContextValues(ctx) if err != nil { return err @@ -333,7 +345,7 @@ func (h *Project) Archive(ctx context.Context, w http.ResponseWriter, r *http.Re // @Failure 403 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /projects/{id} [delete] -func (h *Project) Delete(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *Projects) Delete(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { claims, err := auth.ClaimsFromContext(ctx) if err != nil { return err diff --git a/cmd/web-api/handlers/routes.go b/cmd/web-api/handlers/routes.go index 00f4704..cfec550 100644 --- a/cmd/web-api/handlers/routes.go +++ b/cmd/web-api/handlers/routes.go @@ -5,21 +5,14 @@ import ( "net/http" "os" - "geeks-accelerator/oss/saas-starter-kit/internal/account" - "geeks-accelerator/oss/saas-starter-kit/internal/account/account_preference" "geeks-accelerator/oss/saas-starter-kit/internal/mid" saasSwagger "geeks-accelerator/oss/saas-starter-kit/internal/mid/saas-swagger" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" _ "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" - "geeks-accelerator/oss/saas-starter-kit/internal/project" - "geeks-accelerator/oss/saas-starter-kit/internal/signup" _ "geeks-accelerator/oss/saas-starter-kit/internal/signup" - "geeks-accelerator/oss/saas-starter-kit/internal/user" - "geeks-accelerator/oss/saas-starter-kit/internal/user_account" - "geeks-accelerator/oss/saas-starter-kit/internal/user_account/invite" - "geeks-accelerator/oss/saas-starter-kit/internal/user_auth" + "github.com/jmoiron/sqlx" "gopkg.in/DataDog/dd-trace-go.v1/contrib/go-redis/redis" ) @@ -29,14 +22,14 @@ type AppContext struct { Env webcontext.Env MasterDB *sqlx.DB Redis *redis.Client - UserRepo *user.Repository - UserAccountRepo *user_account.Repository - AccountRepo *account.Repository - AccountPrefRepo *account_preference.Repository - AuthRepo *user_auth.Repository - SignupRepo *signup.Repository - InviteRepo *invite.Repository - ProjectRepo *project.Repository + UserRepo UserRepository + UserAccountRepo UserAccountRepository + AccountRepo AccountRepository + AccountPrefRepo AccountPrefRepository + AuthRepo UserAuthRepository + SignupRepo SignupRepository + InviteRepo UserInviteRepository + ProjectRepo ProjectRepository Authenticator *auth.Authenticator PreAppMiddleware []web.Middleware PostAppMiddleware []web.Middleware @@ -79,9 +72,9 @@ func API(shutdown chan os.Signal, appCtx *AppContext) http.Handler { app.Handle("GET", "/v1/examples/error-response", ex.ErrorResponse) // Register user management and authentication endpoints. - u := User{ - Repository: appCtx.UserRepo, - Auth: appCtx.AuthRepo, + u := Users{ + UserRepo: appCtx.UserRepo, + AuthRepo: appCtx.AuthRepo, } app.Handle("GET", "/v1/users", u.Find, mid.AuthenticateHeader(appCtx.Authenticator)) app.Handle("POST", "/v1/users", u.Create, mid.AuthenticateHeader(appCtx.Authenticator), mid.HasRole(auth.RoleAdmin)) @@ -107,7 +100,7 @@ func API(shutdown chan os.Signal, appCtx *AppContext) http.Handler { app.Handle("DELETE", "/v1/user_accounts", ua.Delete, mid.AuthenticateHeader(appCtx.Authenticator), mid.HasRole(auth.RoleAdmin)) // Register account endpoints. - a := Account{ + a := Accounts{ Repository: appCtx.AccountRepo, } app.Handle("GET", "/v1/accounts/:id", a.Read, mid.AuthenticateHeader(appCtx.Authenticator)) @@ -120,7 +113,7 @@ func API(shutdown chan os.Signal, appCtx *AppContext) http.Handler { app.Handle("POST", "/v1/signup", s.Signup) // Register project. - p := Project{ + p := Projects{ Repository: appCtx.ProjectRepo, } app.Handle("GET", "/v1/projects", p.Find, mid.AuthenticateHeader(appCtx.Authenticator)) diff --git a/cmd/web-api/handlers/signup.go b/cmd/web-api/handlers/signup.go index e2472a3..a29afe2 100644 --- a/cmd/web-api/handlers/signup.go +++ b/cmd/web-api/handlers/signup.go @@ -3,6 +3,7 @@ package handlers import ( "context" "net/http" + "time" "geeks-accelerator/oss/saas-starter-kit/internal/account" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" @@ -10,17 +11,22 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" "geeks-accelerator/oss/saas-starter-kit/internal/signup" + "github.com/pkg/errors" "gopkg.in/go-playground/validator.v9" ) // Signup represents the Signup API method handler set. type Signup struct { - *signup.Repository + Repository SignupRepository // ADD OTHER STATE LIKE THE LOGGER AND CONFIG HERE. } +type SignupRepository interface { + Signup(ctx context.Context, claims auth.Claims, req signup.SignupRequest, now time.Time) (*signup.SignupResult, error) +} + // Signup godoc // @Summary Signup handles new account creation. // @Description Signup creates a new account and user in the system. diff --git a/cmd/web-api/handlers/user.go b/cmd/web-api/handlers/user.go index ddbd934..93551fe 100644 --- a/cmd/web-api/handlers/user.go +++ b/cmd/web-api/handlers/user.go @@ -13,6 +13,7 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" "geeks-accelerator/oss/saas-starter-kit/internal/user" "geeks-accelerator/oss/saas-starter-kit/internal/user_auth" + "github.com/gorilla/schema" "github.com/pkg/errors" "gopkg.in/go-playground/validator.v9" @@ -22,13 +23,36 @@ import ( var sessionTtl = time.Hour * 24 // User represents the User API method handler set. -type User struct { - *user.Repository - Auth *user_auth.Repository - +type Users struct { + AuthRepo UserAuthRepository + UserRepo UserRepository // ADD OTHER STATE LIKE THE LOGGER AND CONFIG HERE. } +type UserAuthRepository interface { + SwitchAccount(ctx context.Context, claims auth.Claims, req user_auth.SwitchAccountRequest, expires time.Duration, + now time.Time, scopes ...string) (user_auth.Token, error) + Authenticate(ctx context.Context, req user_auth.AuthenticateRequest, expires time.Duration, now time.Time, scopes ...string) (user_auth.Token, error) + VirtualLogin(ctx context.Context, claims auth.Claims, req user_auth.VirtualLoginRequest, + expires time.Duration, now time.Time, scopes ...string) (user_auth.Token, error) + VirtualLogout(ctx context.Context, claims auth.Claims, expires time.Duration, now time.Time, scopes ...string) (user_auth.Token, error) +} + +type UserRepository interface { + Find(ctx context.Context, claims auth.Claims, req user.UserFindRequest) (user.Users, error) + //FindByAccount(ctx context.Context, claims auth.Claims, req user.UserFindByAccountRequest) (user.Users, error) + Read(ctx context.Context, claims auth.Claims, req user.UserReadRequest) (*user.User, error) + ReadByID(ctx context.Context, claims auth.Claims, id string) (*user.User, error) + Create(ctx context.Context, claims auth.Claims, req user.UserCreateRequest, now time.Time) (*user.User, error) + Update(ctx context.Context, claims auth.Claims, req user.UserUpdateRequest, now time.Time) error + UpdatePassword(ctx context.Context, claims auth.Claims, req user.UserUpdatePasswordRequest, now time.Time) error + Archive(ctx context.Context, claims auth.Claims, req user.UserArchiveRequest, now time.Time) error + Restore(ctx context.Context, claims auth.Claims, req user.UserRestoreRequest, now time.Time) error + Delete(ctx context.Context, claims auth.Claims, req user.UserDeleteRequest) error + ResetPassword(ctx context.Context, req user.UserResetPasswordRequest, now time.Time) (string, error) + ResetConfirm(ctx context.Context, req user.UserResetConfirmRequest, now time.Time) (*user.User, error) +} + // Find godoc // TODO: Need to implement unittests on users/find endpoint. There are none. // @Summary List users @@ -46,7 +70,7 @@ type User struct { // @Failure 400 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /users [get] -func (h *User) Find(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *Users) Find(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { claims, ok := ctx.Value(auth.Key).(auth.Claims) if !ok { return errors.New("claims missing from context") @@ -113,7 +137,7 @@ func (h *User) Find(ctx context.Context, w http.ResponseWriter, r *http.Request, // return web.RespondJsonError(ctx, w, err) //} - res, err := h.Repository.Find(ctx, claims, req) + res, err := h.UserRepo.Find(ctx, claims, req) if err != nil { return err } @@ -139,7 +163,7 @@ func (h *User) Find(ctx context.Context, w http.ResponseWriter, r *http.Request, // @Failure 404 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /users/{id} [get] -func (h *User) Read(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *Users) Read(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { claims, ok := ctx.Value(auth.Key).(auth.Claims) if !ok { return errors.New("claims missing from context") @@ -156,7 +180,7 @@ func (h *User) Read(ctx context.Context, w http.ResponseWriter, r *http.Request, includeArchived = b } - res, err := h.Repository.Read(ctx, claims, user.UserReadRequest{ + res, err := h.UserRepo.Read(ctx, claims, user.UserReadRequest{ ID: params["id"], IncludeArchived: includeArchived, }) @@ -186,7 +210,7 @@ func (h *User) Read(ctx context.Context, w http.ResponseWriter, r *http.Request, // @Failure 403 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /users [post] -func (h *User) Create(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *Users) Create(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { v, err := webcontext.ContextValues(ctx) if err != nil { return err @@ -205,7 +229,7 @@ func (h *User) Create(ctx context.Context, w http.ResponseWriter, r *http.Reques return web.RespondJsonError(ctx, w, err) } - res, err := h.Repository.Create(ctx, claims, req, v.Now) + usr, err := h.UserRepo.Create(ctx, claims, req, v.Now) if err != nil { cause := errors.Cause(err) switch cause { @@ -221,7 +245,7 @@ func (h *User) Create(ctx context.Context, w http.ResponseWriter, r *http.Reques } } - return web.RespondJson(ctx, w, res.Response(ctx), http.StatusCreated) + return web.RespondJson(ctx, w, usr.Response(ctx), http.StatusCreated) } // Read godoc @@ -237,7 +261,7 @@ func (h *User) Create(ctx context.Context, w http.ResponseWriter, r *http.Reques // @Failure 403 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /users [patch] -func (h *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *Users) Update(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { v, err := webcontext.ContextValues(ctx) if err != nil { return err @@ -256,7 +280,7 @@ func (h *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Reques return web.RespondJsonError(ctx, w, err) } - err = h.Repository.Update(ctx, claims, req, v.Now) + err = h.UserRepo.Update(ctx, claims, req, v.Now) if err != nil { cause := errors.Cause(err) switch cause { @@ -288,7 +312,7 @@ func (h *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Reques // @Failure 403 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /users/password [patch] -func (h *User) UpdatePassword(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *Users) UpdatePassword(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { v, err := webcontext.ContextValues(ctx) if err != nil { return err @@ -307,7 +331,7 @@ func (h *User) UpdatePassword(ctx context.Context, w http.ResponseWriter, r *htt return web.RespondJsonError(ctx, w, err) } - err = h.Repository.UpdatePassword(ctx, claims, req, v.Now) + err = h.UserRepo.UpdatePassword(ctx, claims, req, v.Now) if err != nil { cause := errors.Cause(err) switch cause { @@ -341,7 +365,7 @@ func (h *User) UpdatePassword(ctx context.Context, w http.ResponseWriter, r *htt // @Failure 403 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /users/archive [patch] -func (h *User) Archive(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *Users) Archive(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { v, err := webcontext.ContextValues(ctx) if err != nil { return err @@ -360,7 +384,7 @@ func (h *User) Archive(ctx context.Context, w http.ResponseWriter, r *http.Reque return web.RespondJsonError(ctx, w, err) } - err = h.Repository.Archive(ctx, claims, req, v.Now) + err = h.UserRepo.Archive(ctx, claims, req, v.Now) if err != nil { cause := errors.Cause(err) switch cause { @@ -392,13 +416,13 @@ func (h *User) Archive(ctx context.Context, w http.ResponseWriter, r *http.Reque // @Failure 403 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /users/{id} [delete] -func (h *User) Delete(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *Users) Delete(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { claims, err := auth.ClaimsFromContext(ctx) if err != nil { return err } - err = h.Repository.Delete(ctx, claims, + err = h.UserRepo.Delete(ctx, claims, user.UserDeleteRequest{ID: params["id"]}) if err != nil { cause := errors.Cause(err) @@ -431,7 +455,7 @@ func (h *User) Delete(ctx context.Context, w http.ResponseWriter, r *http.Reques // @Failure 401 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /users/switch-account/{account_id} [patch] -func (h *User) SwitchAccount(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *Users) SwitchAccount(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { v, err := webcontext.ContextValues(ctx) if err != nil { return err @@ -442,7 +466,7 @@ func (h *User) SwitchAccount(ctx context.Context, w http.ResponseWriter, r *http return err } - tkn, err := h.Auth.SwitchAccount(ctx, claims, user_auth.SwitchAccountRequest{ + tkn, err := h.AuthRepo.SwitchAccount(ctx, claims, user_auth.SwitchAccountRequest{ AccountID: params["account_id"], }, sessionTtl, v.Now) if err != nil { @@ -478,7 +502,7 @@ func (h *User) SwitchAccount(ctx context.Context, w http.ResponseWriter, r *http // @Failure 401 {object} weberror.ErrorResponse // @Failure 500 {object} weberror.ErrorResponse // @Router /oauth/token [post] -func (h *User) Token(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *Users) Token(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { v, err := webcontext.ContextValues(ctx) if err != nil { return err @@ -533,7 +557,7 @@ func (h *User) Token(ctx context.Context, w http.ResponseWriter, r *http.Request scopes = strings.Split(qv, ",") } - tkn, err := h.Auth.Authenticate(ctx, authReq, sessionTtl, v.Now, scopes...) + tkn, err := h.AuthRepo.Authenticate(ctx, authReq, sessionTtl, v.Now, scopes...) if err != nil { cause := errors.Cause(err) switch cause { diff --git a/cmd/web-api/handlers/user_account.go b/cmd/web-api/handlers/user_account.go index aec3075..57c0890 100644 --- a/cmd/web-api/handlers/user_account.go +++ b/cmd/web-api/handlers/user_account.go @@ -5,23 +5,44 @@ import ( "net/http" "strconv" "strings" + "time" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" + "geeks-accelerator/oss/saas-starter-kit/internal/user_account" + "geeks-accelerator/oss/saas-starter-kit/internal/user_account/invite" + "github.com/pkg/errors" "gopkg.in/go-playground/validator.v9" ) // UserAccount represents the UserAccount API method handler set. type UserAccount struct { - *user_account.Repository - + UserInvite UserInviteRepository + Repository UserAccountRepository // ADD OTHER STATE LIKE THE LOGGER AND CONFIG HERE. } +type UserAccountRepository interface { + Find(ctx context.Context, claims auth.Claims, req user_account.UserAccountFindRequest) (user_account.UserAccounts, error) + FindByUserID(ctx context.Context, claims auth.Claims, userID string, includedArchived bool) (user_account.UserAccounts, error) + UserFindByAccount(ctx context.Context, claims auth.Claims, req user_account.UserFindByAccountRequest) (user_account.Users, error) + Create(ctx context.Context, claims auth.Claims, req user_account.UserAccountCreateRequest, now time.Time) (*user_account.UserAccount, error) + Read(ctx context.Context, claims auth.Claims, req user_account.UserAccountReadRequest) (*user_account.UserAccount, error) + Update(ctx context.Context, claims auth.Claims, req user_account.UserAccountUpdateRequest, now time.Time) error + Archive(ctx context.Context, claims auth.Claims, req user_account.UserAccountArchiveRequest, now time.Time) error + Delete(ctx context.Context, claims auth.Claims, req user_account.UserAccountDeleteRequest) error +} + +type UserInviteRepository interface { + SendUserInvites(ctx context.Context, claims auth.Claims, req invite.SendUserInvitesRequest, now time.Time) ([]string, error) + AcceptInvite(ctx context.Context, req invite.AcceptInviteRequest, now time.Time) (*user_account.UserAccount, error) + AcceptInviteUser(ctx context.Context, req invite.AcceptInviteUserRequest, now time.Time) (*user_account.UserAccount, error) +} + // Find godoc // TODO: Need to implement unittests on user_accounts/find endpoint. There are none. // @Summary List user accounts diff --git a/cmd/web-api/main.go b/cmd/web-api/main.go index 9a48ec4..ede15da 100644 --- a/cmd/web-api/main.go +++ b/cmd/web-api/main.go @@ -35,6 +35,7 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/user_account" "geeks-accelerator/oss/saas-starter-kit/internal/user_account/invite" "geeks-accelerator/oss/saas-starter-kit/internal/user_auth" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/ec2metadata" @@ -435,7 +436,7 @@ func main() { projectRoute, err := project_route.New(cfg.Service.BaseUrl, cfg.Project.WebAppBaseUrl) if err != nil { - log.Fatalf("main : project routes : %+v", cfg.Service.BaseUrl, err) + log.Fatalf("main : project routes : %s: %+v", cfg.Service.BaseUrl, err) } usrRepo := user.NewRepository(masterDb, projectRoute.UserResetPassword, notifyEmail, cfg.Project.SharedSecretKey) diff --git a/cmd/web-api/tests/account_test.go b/cmd/web-api/tests/account_test.go index 918aa6c..f7fecfd 100644 --- a/cmd/web-api/tests/account_test.go +++ b/cmd/web-api/tests/account_test.go @@ -13,6 +13,7 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/platform/tests" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" + "github.com/pborman/uuid" ) diff --git a/cmd/web-app/handlers/account.go b/cmd/web-app/handlers/account.go index 3ba3e0f..0656fb0 100644 --- a/cmd/web-app/handlers/account.go +++ b/cmd/web-app/handlers/account.go @@ -2,9 +2,7 @@ package handlers import ( "context" - "net/http" - "time" - + "geeks-accelerator/oss/saas-starter-kit/cmd/web-api/handlers" "geeks-accelerator/oss/saas-starter-kit/internal/account" "geeks-accelerator/oss/saas-starter-kit/internal/account/account_preference" "geeks-accelerator/oss/saas-starter-kit/internal/geonames" @@ -12,19 +10,21 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" - "geeks-accelerator/oss/saas-starter-kit/internal/user_auth" + + "net/http" + "time" + "github.com/gorilla/schema" - "github.com/jmoiron/sqlx" "github.com/pkg/errors" ) // Account represents the Account API method handler set. type Account struct { - AccountRepo *account.Repository - AccountPrefRepo *account_preference.Repository - AuthRepo *user_auth.Repository + AccountRepo handlers.AccountRepository + AccountPrefRepo handlers.AccountPrefRepository + AuthRepo handlers.UserAuthRepository + GeoRepo GeoRepository Authenticator *auth.Authenticator - MasterDB *sqlx.DB Renderer web.Renderer } @@ -248,14 +248,14 @@ func (h *Account) Update(ctx context.Context, w http.ResponseWriter, r *http.Req data["account"] = acc.Response(ctx) - data["timezones"], err = geonames.ListTimezones(ctx, h.MasterDB) + data["timezones"], err = h.GeoRepo.ListTimezones(ctx) if err != nil { return false, err } data["geonameCountries"] = geonames.ValidGeonameCountries(ctx) - data["countries"], err = geonames.FindCountries(ctx, h.MasterDB, "name", "") + data["countries"], err = h.GeoRepo.FindCountries(ctx, "name", "") if err != nil { return false, err } diff --git a/cmd/web-app/handlers/api_geo.go b/cmd/web-app/handlers/api_geo.go index 3e4c9af..dfc3ad6 100644 --- a/cmd/web-app/handlers/api_geo.go +++ b/cmd/web-app/handlers/api_geo.go @@ -8,14 +8,25 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/geonames" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" - "github.com/jmoiron/sqlx" + + //"github.com/jmoiron/sqlx" "gopkg.in/DataDog/dd-trace-go.v1/contrib/go-redis/redis" ) // Check provides support for orchestration geo endpoints. type Geo struct { - MasterDB *sqlx.DB - Redis *redis.Client + Redis *redis.Client + GeoRepo GeoRepository +} + +type GeoRepository interface { + FindGeonames(ctx context.Context, orderBy, where string, args ...interface{}) ([]*geonames.Geoname, error) + FindGeonamePostalCodes(ctx context.Context, where string, args ...interface{}) ([]string, error) + FindGeonameRegions(ctx context.Context, orderBy, where string, args ...interface{}) ([]*geonames.Region, error) + FindCountries(ctx context.Context, orderBy, where string, args ...interface{}) ([]*geonames.Country, error) + FindCountryTimezones(ctx context.Context, orderBy, where string, args ...interface{}) ([]*geonames.CountryTimezone, error) + ListTimezones(ctx context.Context) ([]string, error) + LoadGeonames(ctx context.Context, rr chan<- interface{}, countries ...string) } // GeonameByPostalCode... @@ -39,7 +50,7 @@ func (h *Geo) GeonameByPostalCode(ctx context.Context, w http.ResponseWriter, r where := strings.Join(filters, " AND ") - res, err := geonames.FindGeonames(ctx, h.MasterDB, "postal_code", where, args...) + res, err := h.GeoRepo.FindGeonames(ctx, "postal_code", where, args...) if err != nil { fmt.Printf("%+v", err) return web.RespondJsonError(ctx, w, err) @@ -74,7 +85,7 @@ func (h *Geo) PostalCodesAutocomplete(ctx context.Context, w http.ResponseWriter where := strings.Join(filters, " AND ") - res, err := geonames.FindGeonamePostalCodes(ctx, h.MasterDB, where, args...) + res, err := h.GeoRepo.FindGeonamePostalCodes(ctx, where, args...) if err != nil { return web.RespondJsonError(ctx, w, err) } @@ -101,7 +112,7 @@ func (h *Geo) RegionsAutocomplete(ctx context.Context, w http.ResponseWriter, r where := strings.Join(filters, " AND ") - res, err := geonames.FindGeonameRegions(ctx, h.MasterDB, "state_name", where, args...) + res, err := h.GeoRepo.FindGeonameRegions(ctx, "state_name", where, args...) if err != nil { fmt.Printf("%+v", err) return web.RespondJsonError(ctx, w, err) @@ -144,7 +155,7 @@ func (h *Geo) CountryTimezones(ctx context.Context, w http.ResponseWriter, r *ht where := strings.Join(filters, " AND ") - res, err := geonames.FindCountryTimezones(ctx, h.MasterDB, "timezone_id", where, args...) + res, err := h.GeoRepo.FindCountryTimezones(ctx, "timezone_id", where, args...) if err != nil { return web.RespondJsonError(ctx, w, err) } diff --git a/cmd/web-app/handlers/projects.go b/cmd/web-app/handlers/projects.go index 8a375fe..9c54fa1 100644 --- a/cmd/web-app/handlers/projects.go +++ b/cmd/web-app/handlers/projects.go @@ -3,6 +3,7 @@ package handlers import ( "context" "fmt" + "geeks-accelerator/oss/saas-starter-kit/cmd/web-api/handlers" "net/http" "strings" @@ -12,6 +13,7 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" "geeks-accelerator/oss/saas-starter-kit/internal/project" + "github.com/gorilla/schema" "github.com/pkg/errors" "gopkg.in/DataDog/dd-trace-go.v1/contrib/go-redis/redis" @@ -19,7 +21,7 @@ import ( // Projects represents the Projects API method handler set. type Projects struct { - ProjectRepo *project.Repository + ProjectRepo handlers.ProjectRepository Redis *redis.Client Renderer web.Renderer } diff --git a/cmd/web-app/handlers/routes.go b/cmd/web-app/handlers/routes.go index 01bd4b2..0aee9d7 100644 --- a/cmd/web-app/handlers/routes.go +++ b/cmd/web-app/handlers/routes.go @@ -9,20 +9,23 @@ import ( "path/filepath" "time" - "geeks-accelerator/oss/saas-starter-kit/internal/account" - "geeks-accelerator/oss/saas-starter-kit/internal/account/account_preference" + "geeks-accelerator/oss/saas-starter-kit/cmd/web-api/handlers" + //"geeks-accelerator/oss/saas-starter-kit/internal/account" + //"geeks-accelerator/oss/saas-starter-kit/internal/account/account_preference" "geeks-accelerator/oss/saas-starter-kit/internal/mid" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" - "geeks-accelerator/oss/saas-starter-kit/internal/project" + + //"geeks-accelerator/oss/saas-starter-kit/internal/project" "geeks-accelerator/oss/saas-starter-kit/internal/project_route" - "geeks-accelerator/oss/saas-starter-kit/internal/signup" - "geeks-accelerator/oss/saas-starter-kit/internal/user" - "geeks-accelerator/oss/saas-starter-kit/internal/user_account" - "geeks-accelerator/oss/saas-starter-kit/internal/user_account/invite" - "geeks-accelerator/oss/saas-starter-kit/internal/user_auth" + // "geeks-accelerator/oss/saas-starter-kit/internal/signup" + // "geeks-accelerator/oss/saas-starter-kit/internal/user" + // "geeks-accelerator/oss/saas-starter-kit/internal/user_account" + // "geeks-accelerator/oss/saas-starter-kit/internal/user_account/invite" + // "geeks-accelerator/oss/saas-starter-kit/internal/user_auth" + "github.com/ikeikeikeike/go-sitemap-generator/v2/stm" "github.com/jmoiron/sqlx" "gopkg.in/DataDog/dd-trace-go.v1/contrib/go-redis/redis" @@ -39,14 +42,15 @@ type AppContext struct { Env webcontext.Env MasterDB *sqlx.DB Redis *redis.Client - UserRepo *user.Repository - UserAccountRepo *user_account.Repository - AccountRepo *account.Repository - AccountPrefRepo *account_preference.Repository - AuthRepo *user_auth.Repository - SignupRepo *signup.Repository - InviteRepo *invite.Repository - ProjectRepo *project.Repository + UserRepo handlers.UserRepository + UserAccountRepo handlers.UserAccountRepository + AccountRepo handlers.AccountRepository + AccountPrefRepo handlers.AccountPrefRepository + AuthRepo handlers.UserAuthRepository + SignupRepo handlers.SignupRepository + InviteRepo handlers.UserInviteRepository + ProjectRepo handlers.ProjectRepository + GeoRepo GeoRepository Authenticator *auth.Authenticator StaticDir string TemplateDir string @@ -117,7 +121,7 @@ func APP(shutdown chan os.Signal, appCtx *AppContext) http.Handler { UserAccountRepo: appCtx.UserAccountRepo, AuthRepo: appCtx.AuthRepo, InviteRepo: appCtx.InviteRepo, - MasterDB: appCtx.MasterDB, + GeoRepo: appCtx.GeoRepo, Redis: appCtx.Redis, Renderer: appCtx.Renderer, } @@ -134,12 +138,12 @@ func APP(shutdown chan os.Signal, appCtx *AppContext) http.Handler { app.Handle("GET", "/users", us.Index, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasAuth()) // Register user management and authentication endpoints. - u := User{ + u := UserRepos{ UserRepo: appCtx.UserRepo, UserAccountRepo: appCtx.UserAccountRepo, AccountRepo: appCtx.AccountRepo, AuthRepo: appCtx.AuthRepo, - MasterDB: appCtx.MasterDB, + GeoRepo: appCtx.GeoRepo, Renderer: appCtx.Renderer, } app.Handle("POST", "/user/login", u.Login) @@ -168,7 +172,7 @@ func APP(shutdown chan os.Signal, appCtx *AppContext) http.Handler { AccountPrefRepo: appCtx.AccountPrefRepo, AuthRepo: appCtx.AuthRepo, Authenticator: appCtx.Authenticator, - MasterDB: appCtx.MasterDB, + GeoRepo: appCtx.GeoRepo, Renderer: appCtx.Renderer, } app.Handle("POST", "/account/update", acc.Update, mid.AuthenticateSessionRequired(appCtx.Authenticator), mid.HasRole(auth.RoleAdmin)) @@ -180,7 +184,7 @@ func APP(shutdown chan os.Signal, appCtx *AppContext) http.Handler { s := Signup{ SignupRepo: appCtx.SignupRepo, AuthRepo: appCtx.AuthRepo, - MasterDB: appCtx.MasterDB, + GeoRepo: appCtx.GeoRepo, Renderer: appCtx.Renderer, } // This route is not authenticated @@ -197,8 +201,8 @@ func APP(shutdown chan os.Signal, appCtx *AppContext) http.Handler { // Register geo g := Geo{ - MasterDB: appCtx.MasterDB, - Redis: appCtx.Redis, + GeoRepo: appCtx.GeoRepo, + Redis: appCtx.Redis, } app.Handle("GET", "/geo/regions/autocomplete", g.RegionsAutocomplete) app.Handle("GET", "/geo/postal_codes/autocomplete", g.PostalCodesAutocomplete) diff --git a/cmd/web-app/handlers/signup.go b/cmd/web-app/handlers/signup.go index cb23c48..964d90c 100644 --- a/cmd/web-app/handlers/signup.go +++ b/cmd/web-app/handlers/signup.go @@ -2,6 +2,7 @@ package handlers import ( "context" + "geeks-accelerator/oss/saas-starter-kit/cmd/web-api/handlers" "net/http" "time" @@ -13,6 +14,7 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" "geeks-accelerator/oss/saas-starter-kit/internal/signup" "geeks-accelerator/oss/saas-starter-kit/internal/user_auth" + "github.com/gorilla/schema" "github.com/jmoiron/sqlx" "github.com/pkg/errors" @@ -20,8 +22,9 @@ import ( // Signup represents the Signup API method handler set. type Signup struct { - SignupRepo *signup.Repository - AuthRepo *user_auth.Repository + SignupRepo handlers.SignupRepository + AuthRepo handlers.UserAuthRepository + GeoRepo GeoRepository MasterDB *sqlx.DB Renderer web.Renderer } @@ -108,7 +111,7 @@ func (h *Signup) Step1(ctx context.Context, w http.ResponseWriter, r *http.Reque data["geonameCountries"] = geonames.ValidGeonameCountries(ctx) - data["countries"], err = geonames.FindCountries(ctx, h.MasterDB, "name", "") + data["countries"], err = h.GeoRepo.FindCountries(ctx, "name", "") if err != nil { return err } diff --git a/cmd/web-app/handlers/user.go b/cmd/web-app/handlers/user.go index b47b4c9..57b92f4 100644 --- a/cmd/web-app/handlers/user.go +++ b/cmd/web-app/handlers/user.go @@ -8,8 +8,9 @@ import ( "strings" "time" + "geeks-accelerator/oss/saas-starter-kit/cmd/web-api/handlers" "geeks-accelerator/oss/saas-starter-kit/internal/account" - "geeks-accelerator/oss/saas-starter-kit/internal/geonames" + "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" @@ -17,6 +18,7 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/user" "geeks-accelerator/oss/saas-starter-kit/internal/user_account" "geeks-accelerator/oss/saas-starter-kit/internal/user_auth" + "github.com/gorilla/schema" "github.com/gorilla/sessions" "github.com/jmoiron/sqlx" @@ -24,13 +26,15 @@ import ( ) // User represents the User API method handler set. -type User struct { - UserRepo *user.Repository - AuthRepo *user_auth.Repository - UserAccountRepo *user_account.Repository - AccountRepo *account.Repository +type UserRepos struct { + UserRepo handlers.UserRepository + AuthRepo handlers.UserAuthRepository + UserAccountRepo handlers.UserAccountRepository + AccountRepo handlers.AccountRepository + GeoRepo GeoRepository MasterDB *sqlx.DB Renderer web.Renderer + SecretKey string } func urlUserVirtualLogin(userID string) string { @@ -44,7 +48,7 @@ type UserLoginRequest struct { } // Login handles authenticating a user into the system. -func (h *User) Login(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h UserRepos) Login(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { ctxValues, err := webcontext.ContextValues(ctx) if err != nil { @@ -132,7 +136,7 @@ func (h *User) Login(ctx context.Context, w http.ResponseWriter, r *http.Request } // Logout handles removing authentication for the user. -func (h *User) Logout(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *UserRepos) Logout(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { sess := webcontext.ContextSession(ctx) @@ -148,7 +152,7 @@ func (h *User) Logout(ctx context.Context, w http.ResponseWriter, r *http.Reques } // ResetPassword allows a user to perform forgot password. -func (h *User) ResetPassword(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *UserRepos) ResetPassword(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { ctxValues, err := webcontext.ContextValues(ctx) if err != nil { @@ -208,7 +212,7 @@ func (h *User) ResetPassword(ctx context.Context, w http.ResponseWriter, r *http } // ResetConfirm handles changing a users password after they have clicked on the link emailed. -func (h *User) ResetConfirm(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *UserRepos) ResetConfirm(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { resetHash := params["hash"] @@ -278,7 +282,7 @@ func (h *User) ResetConfirm(ctx context.Context, w http.ResponseWriter, r *http. return true, web.Redirect(ctx, w, r, "/", http.StatusFound) } - _, err = h.UserRepo.ParseResetHash(ctx, resetHash, ctxValues.Now) + _, err = user.ParseResetHash(ctx, h.SecretKey, resetHash, ctxValues.Now) if err != nil { switch errors.Cause(err) { case user.ErrResetExpired: @@ -316,7 +320,7 @@ func (h *User) ResetConfirm(ctx context.Context, w http.ResponseWriter, r *http. } // View handles displaying the current user profile. -func (h *User) View(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *UserRepos) View(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { data := make(map[string]interface{}) f := func() error { @@ -356,7 +360,7 @@ func (h *User) View(ctx context.Context, w http.ResponseWriter, r *http.Request, } // Update handles allowing the current user to update their profile. -func (h *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *UserRepos) Update(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { ctxValues, err := webcontext.ContextValues(ctx) if err != nil { @@ -453,7 +457,7 @@ func (h *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Reques data["user"] = usr.Response(ctx) - data["timezones"], err = geonames.ListTimezones(ctx, h.MasterDB) + data["timezones"], err = h.GeoRepo.ListTimezones(ctx) if err != nil { return err } @@ -472,7 +476,7 @@ func (h *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Reques } // Account handles displaying the Account for the current user. -func (h *User) Account(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *UserRepos) Account(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { data := make(map[string]interface{}) f := func() error { @@ -499,7 +503,7 @@ func (h *User) Account(ctx context.Context, w http.ResponseWriter, r *http.Reque } // VirtualLogin handles switching the scope of the context to another user. -func (h *User) VirtualLogin(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *UserRepos) VirtualLogin(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { ctxValues, err := webcontext.ContextValues(ctx) if err != nil { @@ -634,7 +638,7 @@ func (h *User) VirtualLogin(ctx context.Context, w http.ResponseWriter, r *http. } // VirtualLogout handles switching the scope back to the user who initiated the virtual login. -func (h *User) VirtualLogout(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *UserRepos) VirtualLogout(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { ctxValues, err := webcontext.ContextValues(ctx) if err != nil { @@ -708,7 +712,7 @@ func (h *User) VirtualLogout(ctx context.Context, w http.ResponseWriter, r *http } // VirtualLogin handles switching the scope of the context to another user. -func (h *User) SwitchAccount(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { +func (h *UserRepos) SwitchAccount(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { ctxValues, err := webcontext.ContextValues(ctx) if err != nil { diff --git a/cmd/web-app/handlers/users.go b/cmd/web-app/handlers/users.go index 86e5f8b..8b322e7 100644 --- a/cmd/web-app/handlers/users.go +++ b/cmd/web-app/handlers/users.go @@ -3,11 +3,11 @@ package handlers import ( "context" "fmt" + "geeks-accelerator/oss/saas-starter-kit/cmd/web-api/handlers" "net/http" "strings" "time" - "geeks-accelerator/oss/saas-starter-kit/internal/geonames" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/datatable" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" @@ -17,6 +17,7 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/user_account" "geeks-accelerator/oss/saas-starter-kit/internal/user_account/invite" "geeks-accelerator/oss/saas-starter-kit/internal/user_auth" + "github.com/dustin/go-humanize/english" "github.com/gorilla/schema" "github.com/jmoiron/sqlx" @@ -26,10 +27,12 @@ import ( // Users represents the Users API method handler set. type Users struct { - UserRepo *user.Repository - UserAccountRepo *user_account.Repository - AuthRepo *user_auth.Repository - InviteRepo *invite.Repository + UserRepo handlers.UserRepository + AccountRepo handlers.AccountRepository + UserAccountRepo handlers.UserAccountRepository + AuthRepo handlers.UserAuthRepository + InviteRepo handlers.UserInviteRepository + GeoRepo GeoRepository MasterDB *sqlx.DB Redis *redis.Client Renderer web.Renderer @@ -281,7 +284,7 @@ func (h *Users) Create(ctx context.Context, w http.ResponseWriter, r *http.Reque return nil } - data["timezones"], err = geonames.ListTimezones(ctx, h.MasterDB) + data["timezones"], err = h.GeoRepo.ListTimezones(ctx) if err != nil { return err } @@ -519,7 +522,7 @@ func (h *Users) Update(ctx context.Context, w http.ResponseWriter, r *http.Reque data["user"] = usr.Response(ctx) - data["timezones"], err = geonames.ListTimezones(ctx, h.MasterDB) + data["timezones"], err = h.GeoRepo.ListTimezones(ctx) if err != nil { return err } @@ -798,7 +801,7 @@ func (h *Users) InviteAccept(ctx context.Context, w http.ResponseWriter, r *http return nil } - data["timezones"], err = geonames.ListTimezones(ctx, h.MasterDB) + data["timezones"], err = h.GeoRepo.ListTimezones(ctx) if err != nil { return err } diff --git a/cmd/web-app/main.go b/cmd/web-app/main.go index 659d576..8fe14f4 100644 --- a/cmd/web-app/main.go +++ b/cmd/web-app/main.go @@ -7,6 +7,7 @@ import ( "expvar" "fmt" "geeks-accelerator/oss/saas-starter-kit/internal/account/account_preference" + "geeks-accelerator/oss/saas-starter-kit/internal/geonames" "geeks-accelerator/oss/saas-starter-kit/internal/project" "geeks-accelerator/oss/saas-starter-kit/internal/signup" "geeks-accelerator/oss/saas-starter-kit/internal/user_account" @@ -40,6 +41,7 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror" "geeks-accelerator/oss/saas-starter-kit/internal/project_route" "geeks-accelerator/oss/saas-starter-kit/internal/user" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/ec2metadata" @@ -443,6 +445,7 @@ func main() { usrRepo := user.NewRepository(masterDb, projectRoute.UserResetPassword, notifyEmail, cfg.Project.SharedSecretKey) usrAccRepo := user_account.NewRepository(masterDb) accRepo := account.NewRepository(masterDb) + geoRepo := geonames.NewRepository(masterDb) accPrefRepo := account_preference.NewRepository(masterDb) authRepo := user_auth.NewRepository(masterDb, authenticator, usrRepo, usrAccRepo, accPrefRepo) signupRepo := signup.NewRepository(masterDb, usrRepo, usrAccRepo, accRepo) @@ -450,9 +453,9 @@ func main() { prjRepo := project.NewRepository(masterDb) appCtx := &handlers.AppContext{ - Log: log, - Env: cfg.Env, - MasterDB: masterDb, + Log: log, + Env: cfg.Env, + //MasterDB: masterDb, Redis: redisClient, TemplateDir: cfg.Service.TemplateDir, StaticDir: cfg.Service.StaticFiles.Dir, @@ -462,6 +465,7 @@ func main() { AccountRepo: accRepo, AccountPrefRepo: accPrefRepo, AuthRepo: authRepo, + GeoRepo: geoRepo, SignupRepo: signupRepo, InviteRepo: inviteRepo, ProjectRepo: prjRepo, diff --git a/docker-compose.yaml b/docker-compose.yaml index 7c538bb..128c059 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -26,9 +26,9 @@ services: redis: image: redis:latest expose: - - "6379" + - "6378" ports: - - "6379:6379" + - "6378:6379" networks: main: aliases: diff --git a/internal/account/account_preference/account_preference.go b/internal/account/account_preference/account_preference.go index 3dd0e41..6fb3bad 100644 --- a/internal/account/account_preference/account_preference.go +++ b/internal/account/account_preference/account_preference.go @@ -7,6 +7,7 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/account" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" + "github.com/huandu/go-sqlbuilder" "github.com/jmoiron/sqlx" "github.com/pborman/uuid" diff --git a/internal/geonames/countries.go b/internal/geonames/countries.go index 4eaaf65..342ef5e 100644 --- a/internal/geonames/countries.go +++ b/internal/geonames/countries.go @@ -2,8 +2,8 @@ package geonames import ( "context" + "github.com/huandu/go-sqlbuilder" - "github.com/jmoiron/sqlx" "github.com/pkg/errors" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" ) @@ -14,7 +14,7 @@ const ( ) // FindCountries .... -func FindCountries(ctx context.Context, dbConn *sqlx.DB, orderBy, where string, args ...interface{}) ([]*Country, error) { +func (repo *Repository) FindCountries(ctx context.Context, orderBy, where string, args ...interface{}) ([]*Country, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.geonames.FindCountries") defer span.Finish() @@ -32,11 +32,11 @@ func FindCountries(ctx context.Context, dbConn *sqlx.DB, orderBy, where string, } queryStr, queryArgs := query.Build() - queryStr = dbConn.Rebind(queryStr) + queryStr = repo.DbConn.Rebind(queryStr) args = append(args, queryArgs...) // fetch all places from the db - rows, err := dbConn.QueryContext(ctx, queryStr, args...) + rows, err := repo.DbConn.QueryContext(ctx, queryStr, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessage(err, "find countries failed") diff --git a/internal/geonames/country_timezones.go b/internal/geonames/country_timezones.go index d16d92d..1e4546a 100644 --- a/internal/geonames/country_timezones.go +++ b/internal/geonames/country_timezones.go @@ -2,8 +2,8 @@ package geonames import ( "context" + "github.com/huandu/go-sqlbuilder" - "github.com/jmoiron/sqlx" "github.com/pkg/errors" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" ) @@ -14,7 +14,7 @@ const ( ) // FindCountryTimezones .... -func FindCountryTimezones(ctx context.Context, dbConn *sqlx.DB, orderBy, where string, args ...interface{}) ([]*CountryTimezone, error) { +func (repo *Repository) FindCountryTimezones(ctx context.Context, orderBy, where string, args ...interface{}) ([]*CountryTimezone, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.geonames.FindCountryTimezones") defer span.Finish() @@ -32,11 +32,11 @@ func FindCountryTimezones(ctx context.Context, dbConn *sqlx.DB, orderBy, where s } queryStr, queryArgs := query.Build() - queryStr = dbConn.Rebind(queryStr) + queryStr = repo.DbConn.Rebind(queryStr) args = append(args, queryArgs...) // Fetch all country timezones from the db. - rows, err := dbConn.QueryContext(ctx, queryStr, args...) + rows, err := repo.DbConn.QueryContext(ctx, queryStr, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessage(err, "find country timezones failed") @@ -64,8 +64,8 @@ func FindCountryTimezones(ctx context.Context, dbConn *sqlx.DB, orderBy, where s return resp, nil } -func ListTimezones(ctx context.Context, dbConn *sqlx.DB) ([]string, error) { - res, err := FindCountryTimezones(ctx, dbConn, "timezone_id", "") +func (repo *Repository) ListTimezones(ctx context.Context) ([]string, error) { + res, err := repo.FindCountryTimezones(ctx, "timezone_id", "") if err != nil { return nil, err } diff --git a/internal/geonames/geonames.go b/internal/geonames/geonames.go index 47a4e48..9fba47c 100644 --- a/internal/geonames/geonames.go +++ b/internal/geonames/geonames.go @@ -12,8 +12,9 @@ import ( "strings" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" + "github.com/huandu/go-sqlbuilder" - "github.com/jmoiron/sqlx" + // "github.com/jmoiron/sqlx" "github.com/pkg/errors" "github.com/sethgrid/pester" "github.com/shopspring/decimal" @@ -43,7 +44,7 @@ func ValidGeonameCountries(ctx context.Context) []string { } // FindGeonames .... -func FindGeonames(ctx context.Context, dbConn *sqlx.DB, orderBy, where string, args ...interface{}) ([]*Geoname, error) { +func (repo *Repository) FindGeonames(ctx context.Context, orderBy, where string, args ...interface{}) ([]*Geoname, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.geonames.FindGeonames") defer span.Finish() @@ -61,11 +62,11 @@ func FindGeonames(ctx context.Context, dbConn *sqlx.DB, orderBy, where string, a } queryStr, queryArgs := query.Build() - queryStr = dbConn.Rebind(queryStr) + queryStr = repo.DbConn.Rebind(queryStr) args = append(args, queryArgs...) // fetch all places from the db - rows, err := dbConn.QueryContext(ctx, queryStr, args...) + rows, err := repo.DbConn.QueryContext(ctx, queryStr, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessage(err, "find regions failed") @@ -93,7 +94,7 @@ func FindGeonames(ctx context.Context, dbConn *sqlx.DB, orderBy, where string, a } // FindGeonamePostalCodes .... -func FindGeonamePostalCodes(ctx context.Context, dbConn *sqlx.DB, where string, args ...interface{}) ([]string, error) { +func (repo *Repository) FindGeonamePostalCodes(ctx context.Context, where string, args ...interface{}) ([]string, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.geonames.FindGeonamePostalCodes") defer span.Finish() @@ -106,11 +107,11 @@ func FindGeonamePostalCodes(ctx context.Context, dbConn *sqlx.DB, where string, } queryStr, queryArgs := query.Build() - queryStr = dbConn.Rebind(queryStr) + queryStr = repo.DbConn.Rebind(queryStr) args = append(args, queryArgs...) // fetch all places from the db - rows, err := dbConn.QueryContext(ctx, queryStr, args...) + rows, err := repo.DbConn.QueryContext(ctx, queryStr, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessage(err, "find regions failed") @@ -138,7 +139,7 @@ func FindGeonamePostalCodes(ctx context.Context, dbConn *sqlx.DB, where string, } // FindGeonameRegions .... -func FindGeonameRegions(ctx context.Context, dbConn *sqlx.DB, orderBy, where string, args ...interface{}) ([]*Region, error) { +func (repo *Repository) FindGeonameRegions(ctx context.Context, orderBy, where string, args ...interface{}) ([]*Region, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.geonames.FindGeonameRegions") defer span.Finish() @@ -156,11 +157,11 @@ func FindGeonameRegions(ctx context.Context, dbConn *sqlx.DB, orderBy, where str } queryStr, queryArgs := query.Build() - queryStr = dbConn.Rebind(queryStr) + queryStr = repo.DbConn.Rebind(queryStr) args = append(args, queryArgs...) // fetch all places from the db - rows, err := dbConn.QueryContext(ctx, queryStr, args...) + rows, err := repo.DbConn.QueryContext(ctx, queryStr, args...) if err != nil { err = errors.Wrapf(err, "query - %s", query.String()) err = errors.WithMessage(err, "find regions failed") @@ -194,7 +195,7 @@ func FindGeonameRegions(ctx context.Context, dbConn *sqlx.DB, orderBy, where str // Possible types sent to the channel are limited to: // - error // - GeoName -func LoadGeonames(ctx context.Context, rr chan<- interface{}, countries ...string) { +func (repo *Repository) LoadGeonames(ctx context.Context, rr chan<- interface{}, countries ...string) { defer close(rr) if len(countries) == 0 { diff --git a/internal/geonames/models.go b/internal/geonames/models.go index a70274d..68204e8 100644 --- a/internal/geonames/models.go +++ b/internal/geonames/models.go @@ -1,6 +1,18 @@ package geonames import "github.com/shopspring/decimal" +import "github.com/jmoiron/sqlx" + +type Repository struct { + DbConn *sqlx.DB +} + +// NewRepository creates a new Repository that defines dependencies for Project. +func NewRepository(db *sqlx.DB) *Repository { + return &Repository{ + DbConn: db, + } +} type Geoname struct { CountryCode string // US diff --git a/internal/schema/migrations.go b/internal/schema/migrations.go index fe6a3c9..14485c6 100644 --- a/internal/schema/migrations.go +++ b/internal/schema/migrations.go @@ -9,6 +9,7 @@ import ( "strings" "geeks-accelerator/oss/saas-starter-kit/internal/geonames" + "github.com/geeks-accelerator/sqlxmigrate" "github.com/jmoiron/sqlx" _ "github.com/lib/pq" @@ -19,6 +20,7 @@ import ( // migrationList returns a list of migrations to be executed. If the id of the // migration already exists in the migrations table it will be skipped. func migrationList(ctx context.Context, db *sqlx.DB, log *log.Logger, isUnittest bool) []*sqlxmigrate.Migration { + geoRepo := geonames.NewRepository(db) return []*sqlxmigrate.Migration{ // Create table users. { @@ -253,7 +255,7 @@ func migrationList(ctx context.Context, db *sqlx.DB, log *log.Logger, isUnittest } else { resChan := make(chan interface{}) - go geonames.LoadGeonames(ctx, resChan) + go geoRepo.LoadGeonames(ctx, resChan) for r := range resChan { switch v := r.(type) { diff --git a/internal/user_account/invite/invite.go b/internal/user_account/invite/invite.go index 0ee8262..274b441 100644 --- a/internal/user_account/invite/invite.go +++ b/internal/user_account/invite/invite.go @@ -6,11 +6,12 @@ import ( "strings" "time" - "geeks-accelerator/oss/saas-starter-kit/internal/account" + //"geeks-accelerator/oss/saas-starter-kit/internal/account" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/user" "geeks-accelerator/oss/saas-starter-kit/internal/user_account" + "github.com/pkg/errors" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" ) @@ -40,7 +41,7 @@ func (repo *Repository) SendUserInvites(ctx context.Context, claims auth.Claims, } // Ensure the claims can modify the account specified in the request. - err = account.CanModifyAccount(ctx, claims, repo.DbConn, req.AccountID) + err = repo.Account.CanModifyAccount(ctx, claims, req.AccountID) if err != nil { return nil, err } diff --git a/internal/user_auth/auth.go b/internal/user_auth/auth.go index 4e90b10..dcb9dce 100644 --- a/internal/user_auth/auth.go +++ b/internal/user_auth/auth.go @@ -11,6 +11,7 @@ import ( "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/user" "geeks-accelerator/oss/saas-starter-kit/internal/user_account" + "github.com/huandu/go-sqlbuilder" "github.com/lib/pq" "github.com/pkg/errors" @@ -100,7 +101,8 @@ func (repo *Repository) SwitchAccount(ctx context.Context, claims auth.Claims, r } // VirtualLogin allows users to mock being logged in as other users. -func (repo *Repository) VirtualLogin(ctx context.Context, claims auth.Claims, req VirtualLoginRequest, expires time.Duration, now time.Time, scopes ...string) (Token, error) { +func (repo *Repository) VirtualLogin(ctx context.Context, claims auth.Claims, req VirtualLoginRequest, + expires time.Duration, now time.Time, scopes ...string) (Token, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_auth.VirtualLogin") defer span.Finish() From c7106f089fae9638d91815244e2ac6299c91e757 Mon Sep 17 00:00:00 2001 From: Lee Brown Date: Fri, 16 Aug 2019 20:40:48 -0800 Subject: [PATCH 11/13] Copper Valley was here --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index fc58cd3..bfd2bdb 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ Copyright 2019, Geeks Accelerator twins@geeksaccelerator.com +Sponsored by Copper Valley Telecom The SaaS Starter Kit is a set of libraries for building scalable software-as-a-service (SaaS) applications that helps preventing both misuse and fraud. The goal of this project is to provide a proven starting point for new From 666eafceec47f4bd92897ec3f3bf9c0afbfb8da6 Mon Sep 17 00:00:00 2001 From: Lee Brown Date: Sat, 17 Aug 2019 10:58:45 -0800 Subject: [PATCH 12/13] Fix random errors from tests --- cmd/web-api/main.go | 7 +++---- cmd/web-app/handlers/projects.go | 2 +- cmd/web-app/handlers/users.go | 2 +- cmd/web-app/main.go | 9 ++++----- internal/mid/saas-swagger/example/main.go | 3 ++- internal/mid/saas-swagger/swagger_test.go | 3 ++- internal/platform/logger/log.go | 5 +++-- 7 files changed, 16 insertions(+), 15 deletions(-) diff --git a/cmd/web-api/main.go b/cmd/web-api/main.go index e331ead..08b3f51 100644 --- a/cmd/web-api/main.go +++ b/cmd/web-api/main.go @@ -66,10 +66,9 @@ func main() { // ========================================================================= // Logging - log.SetFlags(log.LstdFlags|log.Lmicroseconds|log.Lshortfile) - log.SetPrefix(service+" : ") - log := log.New(os.Stdout, log.Prefix() , log.Flags()) - + log.SetFlags(log.LstdFlags | log.Lmicroseconds | log.Lshortfile) + log.SetPrefix(service + " : ") + log := log.New(os.Stdout, log.Prefix(), log.Flags()) // ========================================================================= // Configuration diff --git a/cmd/web-app/handlers/projects.go b/cmd/web-app/handlers/projects.go index d3bda68..b10ed69 100644 --- a/cmd/web-app/handlers/projects.go +++ b/cmd/web-app/handlers/projects.go @@ -73,7 +73,7 @@ func (h *Projects) Index(ctx context.Context, w http.ResponseWriter, r *http.Req var v datatable.ColumnValue switch col.Field { case "id": - v.Value = fmt.Sprintf("%d", q.ID) + v.Value = fmt.Sprintf("%s", q.ID) case "name": v.Value = q.Name v.Formatted = fmt.Sprintf("%s", urlProjectsView(q.ID), v.Value) diff --git a/cmd/web-app/handlers/users.go b/cmd/web-app/handlers/users.go index e0b9526..637e16a 100644 --- a/cmd/web-app/handlers/users.go +++ b/cmd/web-app/handlers/users.go @@ -100,7 +100,7 @@ func (h *Users) Index(ctx context.Context, w http.ResponseWriter, r *http.Reques var v datatable.ColumnValue switch col.Field { case "id": - v.Value = fmt.Sprintf("%d", q.ID) + v.Value = fmt.Sprintf("%s", q.ID) case "name": if strings.TrimSpace(q.Name) == "" { v.Value = q.Email diff --git a/cmd/web-app/main.go b/cmd/web-app/main.go index 046b695..9b4e215 100644 --- a/cmd/web-app/main.go +++ b/cmd/web-app/main.go @@ -66,10 +66,9 @@ func main() { // ========================================================================= // Logging - log.SetFlags(log.LstdFlags|log.Lmicroseconds|log.Lshortfile) - log.SetPrefix(service+" : ") - log := log.New(os.Stdout, log.Prefix() , log.Flags()) - + log.SetFlags(log.LstdFlags | log.Lmicroseconds | log.Lshortfile) + log.SetPrefix(service + " : ") + log := log.New(os.Stdout, log.Prefix(), log.Flags()) // ========================================================================= // Configuration @@ -474,7 +473,7 @@ func main() { // URL Formatter projectRoutes, err := project_routes.New(cfg.Service.WebApiBaseUrl, cfg.Service.BaseUrl) if err != nil { - log.Fatalf("main : project routes : %+v", cfg.Service.BaseUrl, err) + log.Fatalf("main : project routes : %s : %+v", cfg.Service.BaseUrl, err) } // s3UrlFormatter is a help function used by to convert an s3 key to diff --git a/internal/mid/saas-swagger/example/main.go b/internal/mid/saas-swagger/example/main.go index 5f9f2ae..70e9625 100644 --- a/internal/mid/saas-swagger/example/main.go +++ b/internal/mid/saas-swagger/example/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "log" "net/http" "os" @@ -135,7 +136,7 @@ func main() { func API(shutdown chan os.Signal, log *log.Logger) http.Handler { // Construct the web.App which holds all routes as well as common Middleware. - app := web.NewApp(shutdown, log, mid.Trace(), mid.Logger(log), mid.Errors(log), mid.Metrics(), mid.Panics()) + app := web.NewApp(shutdown, log, webcontext.Env_Dev, mid.Logger(log)) app.Handle("GET", "/swagger/", saasSwagger.WrapHandler) app.Handle("GET", "/swagger/*", saasSwagger.WrapHandler) diff --git a/internal/mid/saas-swagger/swagger_test.go b/internal/mid/saas-swagger/swagger_test.go index e533de0..ea2e037 100644 --- a/internal/mid/saas-swagger/swagger_test.go +++ b/internal/mid/saas-swagger/swagger_test.go @@ -9,6 +9,7 @@ import ( _ "geeks-accelerator/oss/saas-starter-kit/internal/mid/saas-swagger/example/docs" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" + "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "github.com/stretchr/testify/assert" ) @@ -17,7 +18,7 @@ func TestWrapHandler(t *testing.T) { log := log.New(os.Stdout, "", log.LstdFlags|log.Lmicroseconds|log.Lshortfile) log.SetOutput(ioutil.Discard) - app := web.NewApp(nil, log) + app := web.NewApp(nil, log, webcontext.Env_Dev) app.Handle("GET", "/swagger/*", WrapHandler) w1 := performRequest("GET", "/swagger/index.html", app) diff --git a/internal/platform/logger/log.go b/internal/platform/logger/log.go index 58a7fb5..ba5b8e5 100644 --- a/internal/platform/logger/log.go +++ b/internal/platform/logger/log.go @@ -3,12 +3,13 @@ package logger import ( "context" "fmt" - "geeks-accelerator/oss/saas-starter-kit/internal/platform/web" + + "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" ) // WithContext manual injects context values to log message including Trace ID func WithContext(ctx context.Context, msg string) string { - v, ok := ctx.Value(web.KeyValues).(*web.Values) + v, ok := ctx.Value(webcontext.KeyValues).(*webcontext.Values) if !ok { return msg } From 295e46a885eed2328b5a0b8f00d83cdb7d281d5a Mon Sep 17 00:00:00 2001 From: Lee Brown Date: Sat, 17 Aug 2019 11:15:45 -0800 Subject: [PATCH 13/13] Cache geonames download --- internal/geonames/geonames.go | 218 ++++++++-------------------------- internal/schema/migrations.go | 12 +- 2 files changed, 51 insertions(+), 179 deletions(-) diff --git a/internal/geonames/geonames.go b/internal/geonames/geonames.go index ae1c9c8..c7a7bc9 100644 --- a/internal/geonames/geonames.go +++ b/internal/geonames/geonames.go @@ -5,10 +5,13 @@ import ( "bufio" "bytes" "context" + "crypto/md5" "encoding/csv" "fmt" "io" "net/http" + "os" + "path/filepath" "strconv" "strings" "time" @@ -191,144 +194,6 @@ func FindGeonameRegions(ctx context.Context, dbConn *sqlx.DB, orderBy, where str return resp, nil } -// LoadGeonames enables streaming retrieval of GeoNames. The downloaded results -// will be written to the interface{} resultReceiver channel enabling processing the results while -// they're still being fetched. After all pages have been processed the channel is closed. -// Possible types sent to the channel are limited to: -// - error -// - GeoName -func LoadGeonames(ctx context.Context, rr chan<- interface{}, countries ...string) { - defer close(rr) - - if len(countries) == 0 { - countries = ValidGeonameCountries(ctx) - } - - for _, country := range countries { - loadGeonameCountry(ctx, rr, country) - } -} - -// loadGeonameCountry enables streaming retrieval of GeoNames. The downloaded results -// will be written to the interface{} resultReceiver channel enabling processing the results while -// they're still being fetched. -// Possible types sent to the channel are limited to: -// - error -// - GeoName -func loadGeonameCountry(ctx context.Context, rr chan<- interface{}, country string) { - u := fmt.Sprintf("http://download.geonames.org/export/zip/%s.zip", country) - resp, err := pester.Get(u) - if err != nil { - rr <- errors.WithMessagef(err, "Failed to read countries from '%s'", u) - return - } - defer resp.Body.Close() - - br := bufio.NewReader(resp.Body) - - buff := bytes.NewBuffer([]byte{}) - size, err := io.Copy(buff, br) - if err != nil { - rr <- errors.WithStack(err) - return - } - - b := bytes.NewReader(buff.Bytes()) - zr, err := zip.NewReader(b, size) - if err != nil { - rr <- errors.WithStack(err) - return - } - - for _, f := range zr.File { - if f.Name == "readme.txt" { - continue - } - - fh, err := f.Open() - if err != nil { - rr <- errors.WithStack(err) - return - } - - scanner := bufio.NewScanner(fh) - for scanner.Scan() { - line := scanner.Text() - - if strings.Contains(line, "\"") { - line = strings.Replace(line, "\"", "\\\"", -1) - } - - r := csv.NewReader(strings.NewReader(line)) - r.Comma = '\t' // Use tab-delimited instead of comma <---- here! - r.LazyQuotes = true - r.FieldsPerRecord = -1 - - lines, err := r.ReadAll() - if err != nil { - rr <- errors.WithStack(err) - continue - } - - for _, row := range lines { - - /* - fmt.Println("CountryCode: row[0]", row[0]) - fmt.Println("PostalCode: row[1]", row[1]) - fmt.Println("PlaceName: row[2]", row[2]) - fmt.Println("StateName: row[3]", row[3]) - fmt.Println("StateCode : row[4]", row[4]) - fmt.Println("CountyName: row[5]", row[5]) - fmt.Println("CountyCode : row[6]", row[6]) - fmt.Println("CommunityName: row[7]", row[7]) - fmt.Println("CommunityCode: row[8]", row[8]) - fmt.Println("Latitude: row[9]", row[9]) - fmt.Println("Longitude: row[10]", row[10]) - fmt.Println("Accuracy: row[11]", row[11]) - */ - - gn := Geoname{ - CountryCode: row[0], - PostalCode: row[1], - PlaceName: row[2], - StateName: row[3], - StateCode: row[4], - CountyName: row[5], - CountyCode: row[6], - CommunityName: row[7], - CommunityCode: row[8], - } - if row[9] != "" { - gn.Latitude, err = decimal.NewFromString(row[9]) - if err != nil { - rr <- errors.WithStack(err) - } - } - - if row[10] != "" { - gn.Longitude, err = decimal.NewFromString(row[10]) - if err != nil { - rr <- errors.WithStack(err) - } - } - - if row[11] != "" { - gn.Accuracy, err = strconv.Atoi(row[11]) - if err != nil { - rr <- errors.WithStack(err) - } - } - - rr <- gn - } - } - - if err := scanner.Err(); err != nil { - rr <- errors.WithStack(err) - } - } -} - // GetGeonameCountry downloads geoname data for the country. // Parses data and returns slice of Geoname func GetGeonameCountry(ctx context.Context, country string) ([]Geoname, error) { @@ -337,25 +202,51 @@ func GetGeonameCountry(ctx context.Context, country string) ([]Geoname, error) { var resp *http.Response u := fmt.Sprintf("http://download.geonames.org/export/zip/%s.zip", country) - resp, err = pester.Get(u) - if err != nil { - // Add re-try three times after failing first time - // This reduces the risk when network is lagy, we still have chance to re-try. - for i := 0; i < 3; i++ { - resp, err = pester.Get(u) - if err == nil { - break - } - time.Sleep(time.Second * 1) - } - if err != nil { - err = errors.WithMessagef(err, "Failed to read countries from '%s'", u) - return res, err - } - } - defer resp.Body.Close() - br := bufio.NewReader(resp.Body) + h := fmt.Sprintf("%x", md5.Sum([]byte(u))) + cp := filepath.Join(os.TempDir(), h+".zip") + + if _, err := os.Stat(cp); err != nil { + resp, err = pester.Get(u) + if err != nil { + // Add re-try three times after failing first time + // This reduces the risk when network is lagy, we still have chance to re-try. + for i := 0; i < 3; i++ { + resp, err = pester.Get(u) + if err == nil { + break + } + time.Sleep(time.Second * 1) + } + if err != nil { + err = errors.WithMessagef(err, "Failed to read countries from '%s'", u) + return res, err + } + } + defer resp.Body.Close() + + // Create the file + out, err := os.Create(cp) + if err != nil { + return nil, err + } + defer out.Close() + + // Write the body to file + _, err = io.Copy(out, resp.Body) + if err != nil { + return nil, err + } + + out.Close() + } + + f, err := os.Open(cp) + if err != nil { + return nil, err + } + defer f.Close() + br := bufio.NewReader(f) buff := bytes.NewBuffer([]byte{}) size, err := io.Copy(buff, br) @@ -403,21 +294,6 @@ func GetGeonameCountry(ctx context.Context, country string) ([]Geoname, error) { for _, row := range lines { - /* - fmt.Println("CountryCode: row[0]", row[0]) - fmt.Println("PostalCode: row[1]", row[1]) - fmt.Println("PlaceName: row[2]", row[2]) - fmt.Println("StateName: row[3]", row[3]) - fmt.Println("StateCode : row[4]", row[4]) - fmt.Println("CountyName: row[5]", row[5]) - fmt.Println("CountyCode : row[6]", row[6]) - fmt.Println("CommunityName: row[7]", row[7]) - fmt.Println("CommunityCode: row[8]", row[8]) - fmt.Println("Latitude: row[9]", row[9]) - fmt.Println("Longitude: row[10]", row[10]) - fmt.Println("Accuracy: row[11]", row[11]) - */ - gn := Geoname{ CountryCode: row[0], PostalCode: row[1], diff --git a/internal/schema/migrations.go b/internal/schema/migrations.go index 2ec3bec..3fa9e87 100644 --- a/internal/schema/migrations.go +++ b/internal/schema/migrations.go @@ -217,7 +217,7 @@ func migrationList(ctx context.Context, db *sqlx.DB, log *log.Logger, isUnittest }, // Load new geonames table. { - ID: "20190731-02h", + ID: "20190731-02l", Migrate: func(tx *sql.Tx) error { schemas := []string{ @@ -246,7 +246,7 @@ func migrationList(ctx context.Context, db *sqlx.DB, log *log.Logger, isUnittest countries := geonames.ValidGeonameCountries(ctx) if isUnittest { - + countries = []string{"US"} } ncol := 12 @@ -287,7 +287,6 @@ func migrationList(ctx context.Context, db *sqlx.DB, log *log.Logger, isUnittest } start := time.Now() for _, country := range countries { - //fmt.Println("LoadGeonames: start country: ", country) v, err := geonames.GetGeonameCountry(context.Background(), country) if err != nil { return errors.WithStack(err) @@ -316,7 +315,7 @@ func migrationList(ctx context.Context, db *sqlx.DB, log *log.Logger, isUnittest } } if len(v)%batch > 0 { - fmt.Println("Remain part: ", len(v)-n*batch) + log.Println("Remain part: ", len(v)-n*batch) vn := v[n*batch:] err := fn(vn) if err != nil { @@ -324,11 +323,8 @@ func migrationList(ctx context.Context, db *sqlx.DB, log *log.Logger, isUnittest } } } - - //fmt.Println("Insert Geoname took: ", time.Since(start)) - //fmt.Println("LoadGeonames: end country: ", country) } - fmt.Println("Total Geonames population took: ", time.Since(start)) + log.Println("Total Geonames population took: ", time.Since(start)) queries := []string{ `create index idx_geonames_country_code on geonames (country_code)`,