1
0
mirror of https://github.com/raseels-repos/golang-saas-starter-kit.git synced 2025-08-08 22:36:41 +02:00

WIP: not sure how to solve user_account calling account.CanModifyAccount

This commit is contained in:
Lee Brown
2019-08-13 23:41:06 -08:00
parent 4be0454421
commit 3bc814a01e
8 changed files with 198 additions and 132 deletions

View File

@ -2,6 +2,7 @@ package handlers
import ( import (
"context" "context"
"geeks-accelerator/oss/saas-starter-kit/internal/user"
"log" "log"
"net/http" "net/http"
"os" "os"
@ -20,33 +21,52 @@ import (
"gopkg.in/DataDog/dd-trace-go.v1/contrib/go-redis/redis" "gopkg.in/DataDog/dd-trace-go.v1/contrib/go-redis/redis"
) )
// API returns a handler for a set of routes.
func API(shutdown chan os.Signal, log *log.Logger, env webcontext.Env, masterDB *sqlx.DB, redis *redis.Client, authenticator *auth.Authenticator, globalMids ...web.Middleware) http.Handler {
// Define base middlewares applied to all requests. type AppContext struct {
middlewares := []web.Middleware{ Log *log.Logger
mid.Trace(), mid.Logger(log), mid.Errors(log, nil), mid.Metrics(), mid.Panics(), Env webcontext.Env
Repo *user.Repository
MasterDB *sqlx.DB
Redis *redis.Client
Authenticator *auth.Authenticator
PreAppMiddleware []web.Middleware
PostAppMiddleware []web.Middleware
} }
// Append any global middlewares if they were included.
if len(globalMids) > 0 { // API returns a handler for a set of routes.
middlewares = append(middlewares, globalMids...) func API(shutdown chan os.Signal, appContext *AppContext ) http.Handler {
// Include the pre middlewares first.
middlewares := appContext.PreAppMiddleware
// Define app middlewares applied to all requests.
middlewares = append(middlewares,
mid.Trace(),
mid.Logger(appContext.Log),
mid.Errors(appContext.Log, nil),
mid.Metrics(),
mid.Panics())
// Append any global middlewares that should be included after the app middlewares.
if len(appContext.PostAppMiddleware) > 0 {
middlewares = append(middlewares, appContext.PostAppMiddleware...)
} }
// Construct the web.App which holds all routes as well as common Middleware. // Construct the web.App which holds all routes as well as common Middleware.
app := web.NewApp(shutdown, log, env, middlewares...) app := web.NewApp(shutdown, appContext.Log, appContext.Env, middlewares...)
// Register health check endpoint. This route is not authenticated. // Register health check endpoint. This route is not authenticated.
check := Check{ check := Check{
MasterDB: masterDB, MasterDB: appContext.MasterDB,
Redis: redis, Redis: appContext.Redis,
} }
app.Handle("GET", "/v1/health", check.Health) app.Handle("GET", "/v1/health", check.Health)
app.Handle("GET", "/ping", check.Ping) app.Handle("GET", "/ping", check.Ping)
// Register user management and authentication endpoints. // Register user management and authentication endpoints.
u := User{ u := User{
MasterDB: masterDB, MasterDB: appContext.MasterDB,
TokenGenerator: authenticator, TokenGenerator: authenticator,
} }
app.Handle("GET", "/v1/users", u.Find, mid.AuthenticateHeader(authenticator)) app.Handle("GET", "/v1/users", u.Find, mid.AuthenticateHeader(authenticator))

View File

@ -4,8 +4,10 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/notify"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/sudo-suhas/symcrypto" "github.com/sudo-suhas/symcrypto"
"strconv" "strconv"
@ -15,6 +17,24 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
) )
// Repository defines the required dependencies for User.
type Repository struct {
DbConn *sqlx.DB
ResetUrl func(string) string
Notify notify.Email
SecretKey string
}
// NewRepository creates a new Repository that defines dependencies for User.
func NewRepository(db *sqlx.DB, resetUrl func(string) string, notify notify.Email, secretKey string) *Repository {
return &Repository{
DbConn: db,
ResetUrl: resetUrl,
Notify: notify,
SecretKey: secretKey,
}
}
// User represents someone with access to our system. // User represents someone with access to our system.
type User struct { type User struct {
ID string `json:"id" validate:"required,uuid" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"` ID string `json:"id" validate:"required,uuid" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"`

View File

@ -6,7 +6,6 @@ import (
"time" "time"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/notify"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"github.com/huandu/go-sqlbuilder" "github.com/huandu/go-sqlbuilder"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
@ -55,7 +54,7 @@ func mapRowsToUser(rows *sql.Rows) (*User, error) {
} }
// CanReadUser determines if claims has the authority to access the specified user ID. // CanReadUser determines if claims has the authority to access the specified user ID.
func CanReadUser(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string) error { func (repo *Repository) CanReadUser(ctx context.Context, claims auth.Claims, userID string) error {
// If the request has claims from a specific user, ensure that the user // If the request has claims from a specific user, ensure that the user
// has the correct access to the user. // has the correct access to the user.
if claims.Subject != "" && claims.Subject != userID { if claims.Subject != "" && claims.Subject != userID {
@ -68,10 +67,10 @@ func CanReadUser(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userI
query.Equal("user_id", userID), query.Equal("user_id", userID),
)) ))
queryStr, args := query.Build() queryStr, args := query.Build()
queryStr = dbConn.Rebind(queryStr) queryStr = repo.DbConn.Rebind(queryStr)
var userAccountId string var userAccountId string
err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&userAccountId) err := repo.DbConn.QueryRowContext(ctx, queryStr, args...).Scan(&userAccountId)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
return err return err
@ -88,7 +87,7 @@ func CanReadUser(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userI
} }
// CanModifyUser determines if claims has the authority to modify the specified user ID. // CanModifyUser determines if claims has the authority to modify the specified user ID.
func CanModifyUser(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string) error { func (repo *Repository) CanModifyUser(ctx context.Context, claims auth.Claims, userID string) error {
// If the request has claims from a specific user, ensure that the user // If the request has claims from a specific user, ensure that the user
// has the correct role for creating a new user. // has the correct role for creating a new user.
if claims.Subject != "" && claims.Subject != userID { if claims.Subject != "" && claims.Subject != userID {
@ -99,7 +98,7 @@ func CanModifyUser(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, use
} }
} }
if err := CanReadUser(ctx, claims, dbConn, userID); err != nil { if err := repo.CanReadUser(ctx, claims, userID); err != nil {
return err return err
} }
@ -118,10 +117,10 @@ func CanModifyUser(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, use
"'"+auth.RoleAdmin+"' = ANY (roles)", "'"+auth.RoleAdmin+"' = ANY (roles)",
)) ))
queryStr, args := query.Build() queryStr, args := query.Build()
queryStr = dbConn.Rebind(queryStr) queryStr = repo.DbConn.Rebind(queryStr)
var userAccountId string var userAccountId string
err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&userAccountId) err := repo.DbConn.QueryRowContext(ctx, queryStr, args...).Scan(&userAccountId)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
return err return err
@ -199,13 +198,13 @@ func findRequestQuery(req UserFindRequest) (*sqlbuilder.SelectBuilder, []interfa
} }
// Find gets all the users from the database based on the request params. // Find gets all the users from the database based on the request params.
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindRequest) (Users, error) { func (repo *Repository) Find(ctx context.Context, claims auth.Claims, req UserFindRequest) (Users, error) {
query, args := findRequestQuery(req) query, args := findRequestQuery(req)
return find(ctx, claims, dbConn, query, args, req.IncludeArchived) return repo.find(ctx, claims, query, args, req.IncludeArchived)
} }
// find internal method for getting all the users from the database using a select query. // find internal method for getting all the users from the database using a select query.
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (Users, error) { func (repo *Repository) find(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (Users, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Find") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Find")
defer span.Finish() defer span.Finish()
@ -222,11 +221,11 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu
return nil, err return nil, err
} }
queryStr, queryArgs := query.Build() queryStr, queryArgs := query.Build()
queryStr = dbConn.Rebind(queryStr) queryStr = repo.DbConn.Rebind(queryStr)
args = append(args, queryArgs...) args = append(args, queryArgs...)
// fetch all places from the db // fetch all places from the db
rows, err := dbConn.QueryContext(ctx, queryStr, args...) rows, err := repo.DbConn.QueryContext(ctx, queryStr, args...)
if err != nil { if err != nil {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "find users failed") err = errors.WithMessage(err, "find users failed")
@ -248,17 +247,17 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu
} }
// Validation an email address is unique excluding the current user ID. // Validation an email address is unique excluding the current user ID.
func UniqueEmail(ctx context.Context, dbConn *sqlx.DB, email, userId string) (bool, error) { func (repo *Repository) UniqueEmail(ctx context.Context, email, userId string) (bool, error) {
query := sqlbuilder.NewSelectBuilder().Select("id").From(userTableName) query := sqlbuilder.NewSelectBuilder().Select("id").From(userTableName)
query.Where(query.And( query.Where(query.And(
query.Equal("email", email), query.Equal("email", email),
query.NotEqual("id", userId), query.NotEqual("id", userId),
)) ))
queryStr, args := query.Build() queryStr, args := query.Build()
queryStr = dbConn.Rebind(queryStr) queryStr = repo.DbConn.Rebind(queryStr)
var existingId string var existingId string
err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&existingId) err := repo.DbConn.QueryRowContext(ctx, queryStr, args...).Scan(&existingId)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
return false, err return false, err
@ -273,7 +272,7 @@ func UniqueEmail(ctx context.Context, dbConn *sqlx.DB, email, userId string) (bo
} }
// Create inserts a new user into the database. // Create inserts a new user into the database.
func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserCreateRequest, now time.Time) (*User, error) { func (repo *Repository) Create(ctx context.Context, claims auth.Claims, req UserCreateRequest, now time.Time) (*User, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Create") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Create")
defer span.Finish() defer span.Finish()
@ -284,7 +283,7 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserCr
v := webcontext.Validator() v := webcontext.Validator()
// Validation email address is unique in the database. // Validation email address is unique in the database.
uniq, err := UniqueEmail(ctx, dbConn, req.Email, "") uniq, err := repo.UniqueEmail(ctx, req.Email, "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -346,8 +345,8 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserCr
// Execute the query with the provided context. // Execute the query with the provided context.
sql, args := query.Build() sql, args := query.Build()
sql = dbConn.Rebind(sql) sql = repo.DbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...) _, err = repo.DbConn.ExecContext(ctx, sql, args...)
if err != nil { if err != nil {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "create user failed") err = errors.WithMessage(err, "create user failed")
@ -358,14 +357,14 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserCr
} }
// Create invite inserts a new user into the database. // Create invite inserts a new user into the database.
func CreateInvite(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserCreateInviteRequest, now time.Time) (*User, error) { func (repo *Repository) CreateInvite(ctx context.Context, claims auth.Claims, req UserCreateInviteRequest, now time.Time) (*User, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.CreateInvite") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.CreateInvite")
defer span.Finish() defer span.Finish()
v := webcontext.Validator() v := webcontext.Validator()
// Validation email address is unique in the database. // Validation email address is unique in the database.
uniq, err := UniqueEmail(ctx, dbConn, req.Email, "") uniq, err := repo.UniqueEmail(ctx, req.Email, "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -414,8 +413,8 @@ func CreateInvite(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req
// Execute the query with the provided context. // Execute the query with the provided context.
sql, args := query.Build() sql, args := query.Build()
sql = dbConn.Rebind(sql) sql = repo.DbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...) _, err = repo.DbConn.ExecContext(ctx, sql, args...)
if err != nil { if err != nil {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "create user failed") err = errors.WithMessage(err, "create user failed")
@ -426,15 +425,15 @@ func CreateInvite(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req
} }
// ReadByID gets the specified user by ID from the database. // ReadByID gets the specified user by ID from the database.
func ReadByID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string) (*User, error) { func (repo *Repository) ReadByID(ctx context.Context, claims auth.Claims, id string) (*User, error) {
return Read(ctx, claims, dbConn, UserReadRequest{ return repo.Read(ctx, claims, UserReadRequest{
ID: id, ID: id,
IncludeArchived: false, IncludeArchived: false,
}) })
} }
// Read gets the specified user from the database. // Read gets the specified user from the database.
func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserReadRequest) (*User, error) { func (repo *Repository) Read(ctx context.Context, claims auth.Claims, req UserReadRequest) (*User, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Read") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Read")
defer span.Finish() defer span.Finish()
@ -449,7 +448,7 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserRead
query := selectQuery() query := selectQuery()
query.Where(query.Equal("id", req.ID)) query.Where(query.Equal("id", req.ID))
res, err := find(ctx, claims, dbConn, query, []interface{}{}, req.IncludeArchived) res, err := repo.find(ctx, claims, query, []interface{}{}, req.IncludeArchived)
if err != nil { if err != nil {
return nil, err return nil, err
} else if res == nil || len(res) == 0 { } else if res == nil || len(res) == 0 {
@ -462,7 +461,7 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserRead
} }
// ReadByEmail gets the specified user from the database. // ReadByEmail gets the specified user from the database.
func ReadByEmail(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, email string, includedArchived bool) (*User, error) { func (repo *Repository) ReadByEmail(ctx context.Context, claims auth.Claims, email string, includedArchived bool) (*User, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.ReadByEmail") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.ReadByEmail")
defer span.Finish() defer span.Finish()
@ -470,7 +469,7 @@ func ReadByEmail(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, email
query := selectQuery() query := selectQuery()
query.Where(query.Equal("email", email)) query.Where(query.Equal("email", email))
res, err := find(ctx, claims, dbConn, query, []interface{}{}, includedArchived) res, err := repo.find(ctx, claims, query, []interface{}{}, includedArchived)
if err != nil { if err != nil {
return nil, err return nil, err
} else if res == nil || len(res) == 0 { } else if res == nil || len(res) == 0 {
@ -483,14 +482,14 @@ func ReadByEmail(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, email
} }
// Update replaces a user in the database. // Update replaces a user in the database.
func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserUpdateRequest, now time.Time) error { func (repo *Repository) Update(ctx context.Context, claims auth.Claims, req UserUpdateRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Update") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Update")
defer span.Finish() defer span.Finish()
// Validation email address is unique in the database. // Validation email address is unique in the database.
if req.Email != nil { if req.Email != nil {
// Validation email address is unique in the database. // Validation email address is unique in the database.
uniq, err := UniqueEmail(ctx, dbConn, *req.Email, req.ID) uniq, err := repo.UniqueEmail(ctx, *req.Email, req.ID)
if err != nil { if err != nil {
return err return err
} }
@ -507,7 +506,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserUp
} }
// Ensure the claims can modify the user specified in the request. // Ensure the claims can modify the user specified in the request.
err = CanModifyUser(ctx, claims, dbConn, req.ID) err = repo.CanModifyUser(ctx, claims, req.ID)
if err != nil { if err != nil {
return err return err
} }
@ -555,8 +554,8 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserUp
// Execute the query with the provided context. // Execute the query with the provided context.
sql, args := query.Build() sql, args := query.Build()
sql = dbConn.Rebind(sql) sql = repo.DbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...) _, err = repo.DbConn.ExecContext(ctx, sql, args...)
if err != nil { if err != nil {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "update user %s failed", req.ID) err = errors.WithMessagef(err, "update user %s failed", req.ID)
@ -567,7 +566,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserUp
} }
// Update changes the password for a user in the database. // Update changes the password for a user in the database.
func UpdatePassword(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserUpdatePasswordRequest, now time.Time) error { func (repo *Repository) UpdatePassword(ctx context.Context, claims auth.Claims, req UserUpdatePasswordRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.UpdatePassword") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.UpdatePassword")
defer span.Finish() defer span.Finish()
@ -579,7 +578,7 @@ func UpdatePassword(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, re
} }
// Ensure the claims can modify the user specified in the request. // Ensure the claims can modify the user specified in the request.
err = CanModifyUser(ctx, claims, dbConn, req.ID) err = repo.CanModifyUser(ctx, claims, req.ID)
if err != nil { if err != nil {
return err return err
} }
@ -616,8 +615,8 @@ func UpdatePassword(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, re
// Execute the query with the provided context. // Execute the query with the provided context.
sql, args := query.Build() sql, args := query.Build()
sql = dbConn.Rebind(sql) sql = repo.DbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...) _, err = repo.DbConn.ExecContext(ctx, sql, args...)
if err != nil { if err != nil {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "update password for user %s failed", req.ID) err = errors.WithMessagef(err, "update password for user %s failed", req.ID)
@ -628,7 +627,7 @@ func UpdatePassword(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, re
} }
// Archive soft deleted the user from the database. // Archive soft deleted the user from the database.
func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserArchiveRequest, now time.Time) error { func (repo *Repository) Archive(ctx context.Context, claims auth.Claims, req UserArchiveRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Archive") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Archive")
defer span.Finish() defer span.Finish()
@ -640,7 +639,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserA
} }
// Ensure the claims can modify the user specified in the request. // Ensure the claims can modify the user specified in the request.
err = CanModifyUser(ctx, claims, dbConn, req.ID) err = repo.CanModifyUser(ctx, claims, req.ID)
if err != nil { if err != nil {
return err return err
} else if claims.Subject != "" && claims.Subject == req.ID && !req.force { } else if claims.Subject != "" && claims.Subject == req.ID && !req.force {
@ -669,8 +668,8 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserA
// Execute the query with the provided context. // Execute the query with the provided context.
sql, args := query.Build() sql, args := query.Build()
sql = dbConn.Rebind(sql) sql = repo.DbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...) _, err = repo.DbConn.ExecContext(ctx, sql, args...)
if err != nil { if err != nil {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "archive user %s failed", req.ID) err = errors.WithMessagef(err, "archive user %s failed", req.ID)
@ -689,8 +688,8 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserA
// Execute the query with the provided context. // Execute the query with the provided context.
sql, args := query.Build() sql, args := query.Build()
sql = dbConn.Rebind(sql) sql = repo.DbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...) _, err = repo.DbConn.ExecContext(ctx, sql, args...)
if err != nil { if err != nil {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "archive accounts for user %s failed", req.ID) err = errors.WithMessagef(err, "archive accounts for user %s failed", req.ID)
@ -702,7 +701,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserA
} }
// Restore undeletes the user from the database. // Restore undeletes the user from the database.
func Restore(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserRestoreRequest, now time.Time) error { func (repo *Repository) Restore(ctx context.Context, claims auth.Claims, req UserRestoreRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Restore") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Restore")
defer span.Finish() defer span.Finish()
@ -714,7 +713,7 @@ func Restore(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserR
} }
// Ensure the claims can modify the user specified in the request. // Ensure the claims can modify the user specified in the request.
err = CanModifyUser(ctx, claims, dbConn, req.ID) err = repo.CanModifyUser(ctx, claims, req.ID)
if err != nil { if err != nil {
return err return err
} }
@ -741,8 +740,8 @@ func Restore(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserR
// Execute the query with the provided context. // Execute the query with the provided context.
sql, args := query.Build() sql, args := query.Build()
sql = dbConn.Rebind(sql) sql = repo.DbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...) _, err = repo.DbConn.ExecContext(ctx, sql, args...)
if err != nil { if err != nil {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "unarchive user %s failed", req.ID) err = errors.WithMessagef(err, "unarchive user %s failed", req.ID)
@ -753,7 +752,7 @@ func Restore(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserR
} }
// Delete removes a user from the database. // Delete removes a user from the database.
func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserDeleteRequest) error { func (repo *Repository) Delete(ctx context.Context, claims auth.Claims, req UserDeleteRequest) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Delete") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Delete")
defer span.Finish() defer span.Finish()
@ -765,7 +764,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserDe
} }
// Ensure the claims can modify the user specified in the request. // Ensure the claims can modify the user specified in the request.
err = CanModifyUser(ctx, claims, dbConn, req.ID) err = repo.CanModifyUser(ctx, claims, req.ID)
if err != nil { if err != nil {
return err return err
} else if claims.Subject != "" && claims.Subject == req.ID && !req.force { } else if claims.Subject != "" && claims.Subject == req.ID && !req.force {
@ -773,7 +772,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserDe
} }
// Start a new transaction to handle rollbacks on error. // Start a new transaction to handle rollbacks on error.
tx, err := dbConn.Begin() tx, err := repo.DbConn.Begin()
if err != nil { if err != nil {
return errors.WithStack(err) return errors.WithStack(err)
} }
@ -790,7 +789,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserDe
// Execute the query with the provided context. // Execute the query with the provided context.
sql, args := query.Build() sql, args := query.Build()
sql = dbConn.Rebind(sql) sql = repo.DbConn.Rebind(sql)
_, err = tx.ExecContext(ctx, sql, args...) _, err = tx.ExecContext(ctx, sql, args...)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
@ -808,7 +807,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserDe
// Execute the query with the provided context. // Execute the query with the provided context.
sql, args := query.Build() sql, args := query.Build()
sql = dbConn.Rebind(sql) sql = repo.DbConn.Rebind(sql)
_, err = tx.ExecContext(ctx, sql, args...) _, err = tx.ExecContext(ctx, sql, args...)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
@ -827,7 +826,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserDe
} }
// ResetPassword sends en email to the user to allow them to reset their password. // ResetPassword sends en email to the user to allow them to reset their password.
func ResetPassword(ctx context.Context, dbConn *sqlx.DB, resetUrl func(string) string, notify notify.Email, req UserResetPasswordRequest, secretKey string, now time.Time) (string, error) { func (repo *Repository) ResetPassword(ctx context.Context, req UserResetPasswordRequest, now time.Time) (string, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.ResetPassword") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.ResetPassword")
defer span.Finish() defer span.Finish()
@ -845,7 +844,7 @@ func ResetPassword(ctx context.Context, dbConn *sqlx.DB, resetUrl func(string) s
query := selectQuery() query := selectQuery()
query.Where(query.Equal("email", req.Email)) query.Where(query.Equal("email", req.Email))
res, err := find(ctx, auth.Claims{}, dbConn, query, []interface{}{}, false) res, err := repo.find(ctx, auth.Claims{}, query, []interface{}{}, false)
if err != nil { if err != nil {
return "", err return "", err
} else if res == nil || len(res) == 0 { } else if res == nil || len(res) == 0 {
@ -876,8 +875,8 @@ func ResetPassword(ctx context.Context, dbConn *sqlx.DB, resetUrl func(string) s
// Execute the query with the provided context. // Execute the query with the provided context.
sql, args := query.Build() sql, args := query.Build()
sql = dbConn.Rebind(sql) sql = repo.DbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...) _, err = repo.DbConn.ExecContext(ctx, sql, args...)
if err != nil { if err != nil {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "Update user %s failed.", u.ID) err = errors.WithMessagef(err, "Update user %s failed.", u.ID)
@ -895,18 +894,18 @@ func ResetPassword(ctx context.Context, dbConn *sqlx.DB, resetUrl func(string) s
requestIp = vals.RequestIP requestIp = vals.RequestIP
} }
encrypted, err := NewResetHash(ctx, secretKey, resetId, requestIp, req.TTL, now) encrypted, err := NewResetHash(ctx, repo.SecretKey, resetId, requestIp, req.TTL, now)
if err != nil { if err != nil {
return "", err return "", err
} }
data := map[string]interface{}{ data := map[string]interface{}{
"Name": u.FirstName, "Name": u.FirstName,
"Url": resetUrl(encrypted), "Url": repo.ResetUrl(encrypted),
"Minutes": req.TTL.Minutes(), "Minutes": req.TTL.Minutes(),
} }
err = notify.Send(ctx, u.Email, "Reset your Password", "user_reset_password", data) err = repo.Notify.Send(ctx, u.Email, "Reset your Password", "user_reset_password", data)
if err != nil { if err != nil {
err = errors.WithMessagef(err, "Send password reset email to %s failed.", u.Email) err = errors.WithMessagef(err, "Send password reset email to %s failed.", u.Email)
return "", err return "", err
@ -916,7 +915,7 @@ func ResetPassword(ctx context.Context, dbConn *sqlx.DB, resetUrl func(string) s
} }
// ResetConfirm updates the password for a user using the provided reset password ID. // ResetConfirm updates the password for a user using the provided reset password ID.
func ResetConfirm(ctx context.Context, dbConn *sqlx.DB, req UserResetConfirmRequest, secretKey string, now time.Time) (*User, error) { func (repo *Repository) ResetConfirm(ctx context.Context, req UserResetConfirmRequest, now time.Time) (*User, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.ResetConfirm") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.ResetConfirm")
defer span.Finish() defer span.Finish()
@ -928,7 +927,7 @@ func ResetConfirm(ctx context.Context, dbConn *sqlx.DB, req UserResetConfirmRequ
return nil, err return nil, err
} }
hash, err := ParseResetHash(ctx, secretKey, req.ResetHash, now) hash, err := ParseResetHash(ctx, repo.SecretKey, req.ResetHash, now)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -939,7 +938,7 @@ func ResetConfirm(ctx context.Context, dbConn *sqlx.DB, req UserResetConfirmRequ
query := selectQuery() query := selectQuery()
query.Where(query.Equal("password_reset", hash.ResetID)) query.Where(query.Equal("password_reset", hash.ResetID))
res, err := find(ctx, auth.Claims{}, dbConn, query, []interface{}{}, false) res, err := repo.find(ctx, auth.Claims{}, query, []interface{}{}, false)
if err != nil { if err != nil {
return nil, err return nil, err
} else if res == nil || len(res) == 0 { } else if res == nil || len(res) == 0 {
@ -979,8 +978,8 @@ func ResetConfirm(ctx context.Context, dbConn *sqlx.DB, req UserResetConfirmRequ
// Execute the query with the provided context. // Execute the query with the provided context.
sql, args := query.Build() sql, args := query.Build()
sql = dbConn.Rebind(sql) sql = repo.DbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...) _, err = repo.DbConn.ExecContext(ctx, sql, args...)
if err != nil { if err != nil {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "update password for user %s failed", u.ID) err = errors.WithMessagef(err, "update password for user %s failed", u.ID)
@ -1000,6 +999,10 @@ type MockUserResponse struct {
func MockUser(ctx context.Context, dbConn *sqlx.DB, now time.Time) (*MockUserResponse, error) { func MockUser(ctx context.Context, dbConn *sqlx.DB, now time.Time) (*MockUserResponse, error) {
pass := uuid.NewRandom().String() pass := uuid.NewRandom().String()
repo := &Repository{
DbConn: dbConn,
}
req := UserCreateRequest{ req := UserCreateRequest{
FirstName: "Lee", FirstName: "Lee",
LastName: "Brown", LastName: "Brown",
@ -1007,7 +1010,7 @@ func MockUser(ctx context.Context, dbConn *sqlx.DB, now time.Time) (*MockUserRes
Password: pass, Password: pass,
PasswordConfirm: pass, PasswordConfirm: pass,
} }
u, err := Create(ctx, auth.Claims{}, dbConn, req, now) u, err := repo.Create(ctx, auth.Claims{}, req, now)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -18,7 +18,10 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
var test *tests.Test var (
test *tests.Test
repo *Repository
)
// TestMain is the entry point for testing. // TestMain is the entry point for testing.
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
@ -28,6 +31,16 @@ func TestMain(m *testing.M) {
func testMain(m *testing.M) int { func testMain(m *testing.M) int {
test = tests.New() test = tests.New()
defer test.TearDown() defer test.TearDown()
// Mock the methods needed to make a password reset.
resetUrl := func(string) string {
return ""
}
notify := &notify.MockEmail{}
secretKey := "6368616e676520746869732070617373"
repo = NewRepository(test.MasterDB, resetUrl, notify, secretKey)
return m.Run() return m.Run()
} }
@ -219,7 +232,7 @@ func TestCreateValidation(t *testing.T) {
{ {
ctx := tests.Context() ctx := tests.Context()
res, err := Create(ctx, auth.Claims{}, test.MasterDB, tt.req, now) res, err := repo.Create(ctx, auth.Claims{}, tt.req, now)
if err != tt.error { if err != tt.error {
// TODO: need a better way to handle validation errors as they are // TODO: need a better way to handle validation errors as they are
// of type interface validator.ValidationErrorsTranslations // of type interface validator.ValidationErrorsTranslations
@ -272,7 +285,7 @@ func TestCreateValidationEmailUnique(t *testing.T) {
Password: "akTechFr0n!ier", Password: "akTechFr0n!ier",
PasswordConfirm: "akTechFr0n!ier", PasswordConfirm: "akTechFr0n!ier",
} }
user1, err := Create(ctx, auth.Claims{}, test.MasterDB, req1, now) user1, err := repo.Create(ctx, auth.Claims{}, req1, now)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate failed.", tests.Failed) t.Fatalf("\t%s\tCreate failed.", tests.Failed)
@ -286,7 +299,7 @@ func TestCreateValidationEmailUnique(t *testing.T) {
PasswordConfirm: "W0rkL1fe#", PasswordConfirm: "W0rkL1fe#",
} }
expectedErr := errors.New("Key: 'UserCreateRequest.email' Error:Field validation for 'email' failed on the 'unique' tag") expectedErr := errors.New("Key: 'UserCreateRequest.email' Error:Field validation for 'email' failed on the 'unique' tag")
_, err = Create(ctx, auth.Claims{}, test.MasterDB, req2, now) _, err = repo.Create(ctx, auth.Claims{}, req2, now)
if err == nil { if err == nil {
t.Logf("\t\tWant: %+v", expectedErr) t.Logf("\t\tWant: %+v", expectedErr)
t.Fatalf("\t%s\tCreate failed.", tests.Failed) t.Fatalf("\t%s\tCreate failed.", tests.Failed)
@ -374,7 +387,7 @@ func TestCreateClaims(t *testing.T) {
{ {
ctx := tests.Context() ctx := tests.Context()
_, err := Create(ctx, tt.claims, test.MasterDB, tt.req, now) _, err := repo.Create(ctx, tt.claims, tt.req, now)
if errors.Cause(err) != tt.error { if errors.Cause(err) != tt.error {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.error) t.Logf("\t\tWant: %+v", tt.error)
@ -421,7 +434,7 @@ func TestUpdateValidation(t *testing.T) {
{ {
ctx := tests.Context() ctx := tests.Context()
err := Update(ctx, auth.Claims{}, test.MasterDB, tt.req, now) err := repo.Update(ctx, auth.Claims{}, tt.req, now)
if err != tt.error { if err != tt.error {
// TODO: need a better way to handle validation errors as they are // TODO: need a better way to handle validation errors as they are
// of type interface validator.ValidationErrorsTranslations // of type interface validator.ValidationErrorsTranslations
@ -463,7 +476,7 @@ func TestUpdateValidationEmailUnique(t *testing.T) {
Password: "akTechFr0n!ier", Password: "akTechFr0n!ier",
PasswordConfirm: "akTechFr0n!ier", PasswordConfirm: "akTechFr0n!ier",
} }
user1, err := Create(ctx, auth.Claims{}, test.MasterDB, req1, now) user1, err := repo.Create(ctx, auth.Claims{}, req1, now)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate failed.", tests.Failed) t.Fatalf("\t%s\tCreate failed.", tests.Failed)
@ -476,7 +489,7 @@ func TestUpdateValidationEmailUnique(t *testing.T) {
Password: "W0rkL1fe#", Password: "W0rkL1fe#",
PasswordConfirm: "W0rkL1fe#", PasswordConfirm: "W0rkL1fe#",
} }
user2, err := Create(ctx, auth.Claims{}, test.MasterDB, req2, now) user2, err := repo.Create(ctx, auth.Claims{}, req2, now)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate failed.", tests.Failed) t.Fatalf("\t%s\tCreate failed.", tests.Failed)
@ -488,7 +501,7 @@ func TestUpdateValidationEmailUnique(t *testing.T) {
Email: &user1.Email, Email: &user1.Email,
} }
expectedErr := errors.New("Key: 'UserUpdateRequest.email' Error:Field validation for 'email' failed on the 'unique' tag") expectedErr := errors.New("Key: 'UserUpdateRequest.email' Error:Field validation for 'email' failed on the 'unique' tag")
err = Update(ctx, auth.Claims{}, test.MasterDB, updateReq, now) err = repo.Update(ctx, auth.Claims{}, updateReq, now)
if err == nil { if err == nil {
t.Logf("\t\tWant: %+v", expectedErr) t.Logf("\t\tWant: %+v", expectedErr)
t.Fatalf("\t%s\tUpdate failed.", tests.Failed) t.Fatalf("\t%s\tUpdate failed.", tests.Failed)
@ -518,7 +531,7 @@ func TestUpdatePassword(t *testing.T) {
// Create a new user for testing. // Create a new user for testing.
initPass := uuid.NewRandom().String() initPass := uuid.NewRandom().String()
user, err := Create(ctx, auth.Claims{}, test.MasterDB, UserCreateRequest{ user, err := repo.Create(ctx, auth.Claims{}, UserCreateRequest{
FirstName: "Lee", FirstName: "Lee",
LastName: "Brown", LastName: "Brown",
Email: uuid.NewRandom().String() + "@geeksinthewoods.com", Email: uuid.NewRandom().String() + "@geeksinthewoods.com",
@ -549,7 +562,7 @@ func TestUpdatePassword(t *testing.T) {
expectedErr := errors.New("Key: 'UserUpdatePasswordRequest.id' Error:Field validation for 'id' failed on the 'required' tag\n" + expectedErr := errors.New("Key: 'UserUpdatePasswordRequest.id' Error:Field validation for 'id' failed on the 'required' tag\n" +
"Key: 'UserUpdatePasswordRequest.password' Error:Field validation for 'password' failed on the 'required' tag\n" + "Key: 'UserUpdatePasswordRequest.password' Error:Field validation for 'password' failed on the 'required' tag\n" +
"Key: 'UserUpdatePasswordRequest.password_confirm' Error:Field validation for 'password_confirm' failed on the 'required' tag") "Key: 'UserUpdatePasswordRequest.password_confirm' Error:Field validation for 'password_confirm' failed on the 'required' tag")
err = UpdatePassword(ctx, auth.Claims{}, test.MasterDB, UserUpdatePasswordRequest{}, now) err = repo.UpdatePassword(ctx, auth.Claims{}, UserUpdatePasswordRequest{}, now)
if err == nil { if err == nil {
t.Logf("\t\tWant: %+v", expectedErr) t.Logf("\t\tWant: %+v", expectedErr)
t.Fatalf("\t%s\tUpdate failed.", tests.Failed) t.Fatalf("\t%s\tUpdate failed.", tests.Failed)
@ -567,7 +580,7 @@ func TestUpdatePassword(t *testing.T) {
// Update the users password. // Update the users password.
newPass := uuid.NewRandom().String() newPass := uuid.NewRandom().String()
err = UpdatePassword(ctx, auth.Claims{}, test.MasterDB, UserUpdatePasswordRequest{ err = repo.UpdatePassword(ctx, auth.Claims{}, UserUpdatePasswordRequest{
ID: user.ID, ID: user.ID,
Password: newPass, Password: newPass,
PasswordConfirm: newPass, PasswordConfirm: newPass,
@ -800,7 +813,7 @@ func TestCrud(t *testing.T) {
// Always create the new user with empty claims, testing claims for create user // Always create the new user with empty claims, testing claims for create user
// will be handled separately. // will be handled separately.
user, err := Create(tests.Context(), auth.Claims{}, test.MasterDB, tt.create, now) user, err := repo.Create(tests.Context(), auth.Claims{}, tt.create, now)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate user failed.", tests.Failed) t.Fatalf("\t%s\tCreate user failed.", tests.Failed)
@ -823,7 +836,7 @@ func TestCrud(t *testing.T) {
// Update the user. // Update the user.
updateReq := tt.update(user) updateReq := tt.update(user)
err = Update(ctx, tt.claims(user, accountId), test.MasterDB, updateReq, now) err = repo.Update(ctx, tt.claims(user, accountId), updateReq, now)
if err != nil && errors.Cause(err) != tt.updateErr { if err != nil && errors.Cause(err) != tt.updateErr {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.updateErr) t.Logf("\t\tWant: %+v", tt.updateErr)
@ -832,7 +845,7 @@ func TestCrud(t *testing.T) {
t.Logf("\t%s\tUpdate ok.", tests.Success) t.Logf("\t%s\tUpdate ok.", tests.Success)
// Find the user and make sure the updates where made. // Find the user and make sure the updates where made.
findRes, err := ReadByID(ctx, tt.claims(user, accountId), test.MasterDB, user.ID) findRes, err := repo.ReadByID(ctx, tt.claims(user, accountId), user.ID)
if err != nil && errors.Cause(err) != tt.findErr { if err != nil && errors.Cause(err) != tt.findErr {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.findErr) t.Logf("\t\tWant: %+v", tt.findErr)
@ -846,14 +859,14 @@ func TestCrud(t *testing.T) {
} }
// Archive (soft-delete) the user. // Archive (soft-delete) the user.
err = Archive(ctx, tt.claims(user, accountId), test.MasterDB, UserArchiveRequest{ID: user.ID, force: true}, now) err = repo.Archive(ctx, tt.claims(user, accountId), UserArchiveRequest{ID: user.ID, force: true}, now)
if err != nil && errors.Cause(err) != tt.updateErr { if err != nil && errors.Cause(err) != tt.updateErr {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.updateErr) t.Logf("\t\tWant: %+v", tt.updateErr)
t.Fatalf("\t%s\tArchive failed.", tests.Failed) t.Fatalf("\t%s\tArchive failed.", tests.Failed)
} else if tt.updateErr == nil { } else if tt.updateErr == nil {
// Trying to find the archived user with the includeArchived false should result in not found. // Trying to find the archived user with the includeArchived false should result in not found.
_, err = ReadByID(ctx, tt.claims(user, accountId), test.MasterDB, user.ID) _, err = repo.ReadByID(ctx, tt.claims(user, accountId), user.ID)
if err != nil && errors.Cause(err) != ErrNotFound { if err != nil && errors.Cause(err) != ErrNotFound {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", ErrNotFound) t.Logf("\t\tWant: %+v", ErrNotFound)
@ -861,7 +874,7 @@ func TestCrud(t *testing.T) {
} }
// Trying to find the archived user with the includeArchived true should result no error. // Trying to find the archived user with the includeArchived true should result no error.
_, err = Read(ctx, tt.claims(user, accountId), test.MasterDB, _, err = repo.Read(ctx, tt.claims(user, accountId),
UserReadRequest{ID: user.ID, IncludeArchived: true}) UserReadRequest{ID: user.ID, IncludeArchived: true})
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
@ -871,14 +884,14 @@ func TestCrud(t *testing.T) {
t.Logf("\t%s\tArchive ok.", tests.Success) t.Logf("\t%s\tArchive ok.", tests.Success)
// Restore (un-delete) the user. // Restore (un-delete) the user.
err = Restore(ctx, tt.claims(user, accountId), test.MasterDB, UserRestoreRequest{ID: user.ID}, now) err = repo.Restore(ctx, tt.claims(user, accountId), UserRestoreRequest{ID: user.ID}, now)
if err != nil && errors.Cause(err) != tt.updateErr { if err != nil && errors.Cause(err) != tt.updateErr {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.updateErr) t.Logf("\t\tWant: %+v", tt.updateErr)
t.Fatalf("\t%s\tUnarchive failed.", tests.Failed) t.Fatalf("\t%s\tUnarchive failed.", tests.Failed)
} else if tt.updateErr == nil { } else if tt.updateErr == nil {
// Trying to find the archived user with the includeArchived false should result no error. // Trying to find the archived user with the includeArchived false should result no error.
_, err = ReadByID(ctx, tt.claims(user, accountId), test.MasterDB, user.ID) _, err = repo.ReadByID(ctx, tt.claims(user, accountId), user.ID)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tUnarchive Read failed.", tests.Failed) t.Fatalf("\t%s\tUnarchive Read failed.", tests.Failed)
@ -887,14 +900,14 @@ func TestCrud(t *testing.T) {
t.Logf("\t%s\tUnarchive ok.", tests.Success) t.Logf("\t%s\tUnarchive ok.", tests.Success)
// Delete (hard-delete) the user. // Delete (hard-delete) the user.
err = Delete(ctx, tt.claims(user, accountId), test.MasterDB, UserDeleteRequest{ID: user.ID, force: true}) err = repo.Delete(ctx, tt.claims(user, accountId), UserDeleteRequest{ID: user.ID, force: true})
if err != nil && errors.Cause(err) != tt.updateErr { if err != nil && errors.Cause(err) != tt.updateErr {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.updateErr) t.Logf("\t\tWant: %+v", tt.updateErr)
t.Fatalf("\t%s\tUpdate failed.", tests.Failed) t.Fatalf("\t%s\tUpdate failed.", tests.Failed)
} else if tt.updateErr == nil { } else if tt.updateErr == nil {
// Trying to find the deleted user with the includeArchived true should result in not found. // Trying to find the deleted user with the includeArchived true should result in not found.
_, err = ReadByID(ctx, tt.claims(user, accountId), test.MasterDB, user.ID) _, err = repo.ReadByID(ctx, tt.claims(user, accountId), user.ID)
if errors.Cause(err) != ErrNotFound { if errors.Cause(err) != ErrNotFound {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", ErrNotFound) t.Logf("\t\tWant: %+v", ErrNotFound)
@ -917,7 +930,7 @@ func TestFind(t *testing.T) {
var users []*User var users []*User
for i := 0; i <= 4; i++ { for i := 0; i <= 4; i++ {
user, err := Create(tests.Context(), auth.Claims{}, test.MasterDB, UserCreateRequest{ user, err := repo.Create(tests.Context(), auth.Claims{}, UserCreateRequest{
FirstName: "Lee", FirstName: "Lee",
LastName: "Brown", LastName: "Brown",
Email: uuid.NewRandom().String() + "@geeksinthewoods.com", Email: uuid.NewRandom().String() + "@geeksinthewoods.com",
@ -1029,7 +1042,7 @@ func TestFind(t *testing.T) {
{ {
ctx := tests.Context() ctx := tests.Context()
res, err := Find(ctx, auth.Claims{}, test.MasterDB, tt.req) res, err := repo.Find(ctx, auth.Claims{}, tt.req)
if errors.Cause(err) != tt.error { if errors.Cause(err) != tt.error {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.error) t.Logf("\t\tWant: %+v", tt.error)
@ -1064,7 +1077,7 @@ func TestResetPassword(t *testing.T) {
// Create a new user for testing. // Create a new user for testing.
initPass := uuid.NewRandom().String() initPass := uuid.NewRandom().String()
user, err := Create(ctx, auth.Claims{}, test.MasterDB, UserCreateRequest{ user, err := repo.Create(ctx, auth.Claims{}, UserCreateRequest{
FirstName: "Lee", FirstName: "Lee",
LastName: "Brown", LastName: "Brown",
Email: uuid.NewRandom().String() + "@geeksinthewoods.com", Email: uuid.NewRandom().String() + "@geeksinthewoods.com",
@ -1091,18 +1104,10 @@ func TestResetPassword(t *testing.T) {
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed) t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
} }
// Mock the methods needed to make a password reset.
resetUrl := func(string) string {
return ""
}
notify := &notify.MockEmail{}
secretKey := "6368616e676520746869732070617373"
// Ensure validation is working by trying ResetPassword with an empty request. // Ensure validation is working by trying ResetPassword with an empty request.
{ {
expectedErr := errors.New("Key: 'UserResetPasswordRequest.email' Error:Field validation for 'email' failed on the 'required' tag") expectedErr := errors.New("Key: 'UserResetPasswordRequest.email' Error:Field validation for 'email' failed on the 'required' tag")
_, err = ResetPassword(ctx, test.MasterDB, resetUrl, notify, UserResetPasswordRequest{}, secretKey, now) _, err = repo.ResetPassword(ctx, UserResetPasswordRequest{}, now)
if err == nil { if err == nil {
t.Logf("\t\tWant: %+v", expectedErr) t.Logf("\t\tWant: %+v", expectedErr)
t.Fatalf("\t%s\tResetPassword failed.", tests.Failed) t.Fatalf("\t%s\tResetPassword failed.", tests.Failed)
@ -1122,10 +1127,10 @@ func TestResetPassword(t *testing.T) {
ttl := time.Hour ttl := time.Hour
// Make the reset password request. // Make the reset password request.
resetHash, err := ResetPassword(ctx, test.MasterDB, resetUrl, notify, UserResetPasswordRequest{ resetHash, err := repo.ResetPassword(ctx, UserResetPasswordRequest{
Email: user.Email, Email: user.Email,
TTL: ttl, TTL: ttl,
}, secretKey, now) }, now)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tResetPassword failed.", tests.Failed) t.Fatalf("\t%s\tResetPassword failed.", tests.Failed)
@ -1133,7 +1138,7 @@ func TestResetPassword(t *testing.T) {
t.Logf("\t%s\tResetPassword ok.", tests.Success) t.Logf("\t%s\tResetPassword ok.", tests.Success)
// Read the user to ensure the password_reset field was set. // Read the user to ensure the password_reset field was set.
user, err = ReadByID(ctx, auth.Claims{}, test.MasterDB, user.ID) user, err = repo.ReadByID(ctx, auth.Claims{}, user.ID)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tRead failed.", tests.Failed) t.Fatalf("\t%s\tRead failed.", tests.Failed)
@ -1146,7 +1151,7 @@ func TestResetPassword(t *testing.T) {
expectedErr := errors.New("Key: 'UserResetConfirmRequest.reset_hash' Error:Field validation for 'reset_hash' failed on the 'required' tag\n" + expectedErr := errors.New("Key: 'UserResetConfirmRequest.reset_hash' Error:Field validation for 'reset_hash' failed on the 'required' tag\n" +
"Key: 'UserResetConfirmRequest.password' Error:Field validation for 'password' failed on the 'required' tag\n" + "Key: 'UserResetConfirmRequest.password' Error:Field validation for 'password' failed on the 'required' tag\n" +
"Key: 'UserResetConfirmRequest.password_confirm' Error:Field validation for 'password_confirm' failed on the 'required' tag") "Key: 'UserResetConfirmRequest.password_confirm' Error:Field validation for 'password_confirm' failed on the 'required' tag")
_, err = ResetConfirm(ctx, test.MasterDB, UserResetConfirmRequest{}, secretKey, now) _, err = repo.ResetConfirm(ctx, UserResetConfirmRequest{}, now)
if err == nil { if err == nil {
t.Logf("\t\tWant: %+v", expectedErr) t.Logf("\t\tWant: %+v", expectedErr)
t.Fatalf("\t%s\tResetConfirm failed.", tests.Failed) t.Fatalf("\t%s\tResetConfirm failed.", tests.Failed)
@ -1166,11 +1171,11 @@ func TestResetPassword(t *testing.T) {
// Ensure the TTL is enforced. // Ensure the TTL is enforced.
{ {
newPass := uuid.NewRandom().String() newPass := uuid.NewRandom().String()
_, err = ResetConfirm(ctx, test.MasterDB, UserResetConfirmRequest{ _, err = repo.ResetConfirm(ctx, UserResetConfirmRequest{
ResetHash: resetHash, ResetHash: resetHash,
Password: newPass, Password: newPass,
PasswordConfirm: newPass, PasswordConfirm: newPass,
}, secretKey, now.UTC().Add(ttl*2)) }, now.UTC().Add(ttl*2))
if errors.Cause(err) != ErrResetExpired { if errors.Cause(err) != ErrResetExpired {
t.Logf("\t\tGot : %+v", errors.Cause(err)) t.Logf("\t\tGot : %+v", errors.Cause(err))
t.Logf("\t\tWant: %+v", ErrResetExpired) t.Logf("\t\tWant: %+v", ErrResetExpired)
@ -1181,11 +1186,11 @@ func TestResetPassword(t *testing.T) {
// Assuming we have received the email and clicked the link, we now can ensure confirm works. // Assuming we have received the email and clicked the link, we now can ensure confirm works.
newPass := uuid.NewRandom().String() newPass := uuid.NewRandom().String()
reset, err := ResetConfirm(ctx, test.MasterDB, UserResetConfirmRequest{ reset, err := repo.ResetConfirm(ctx, UserResetConfirmRequest{
ResetHash: resetHash, ResetHash: resetHash,
Password: newPass, Password: newPass,
PasswordConfirm: newPass, PasswordConfirm: newPass,
}, secretKey, now) }, now)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tResetConfirm failed.", tests.Failed) t.Fatalf("\t%s\tResetConfirm failed.", tests.Failed)
@ -1199,11 +1204,11 @@ func TestResetPassword(t *testing.T) {
// Ensure the reset hash does not work after its used. // Ensure the reset hash does not work after its used.
{ {
newPass := uuid.NewRandom().String() newPass := uuid.NewRandom().String()
_, err = ResetConfirm(ctx, test.MasterDB, UserResetConfirmRequest{ _, err = repo.ResetConfirm(ctx, UserResetConfirmRequest{
ResetHash: resetHash, ResetHash: resetHash,
Password: newPass, Password: newPass,
PasswordConfirm: newPass, PasswordConfirm: newPass,
}, secretKey, now) }, now)
if errors.Cause(err) != ErrNotFound { if errors.Cause(err) != ErrNotFound {
t.Logf("\t\tGot : %+v", errors.Cause(err)) t.Logf("\t\tGot : %+v", errors.Cause(err))
t.Logf("\t\tWant: %+v", ErrNotFound) t.Logf("\t\tWant: %+v", ErrNotFound)

View File

@ -3,6 +3,7 @@ package user_account
import ( import (
"context" "context"
"database/sql/driver" "database/sql/driver"
"github.com/jmoiron/sqlx"
"strings" "strings"
"time" "time"
@ -13,6 +14,18 @@ import (
"gopkg.in/go-playground/validator.v9" "gopkg.in/go-playground/validator.v9"
) )
// Repository defines the required dependencies for UserAccount.
type Repository struct {
DbConn *sqlx.DB
}
// NewRepository creates a new Repository that defines dependencies for UserAccount.
func NewRepository(db *sqlx.DB) *Repository {
return &Repository{
DbConn: db,
}
}
// UserAccount defines the one to many relationship of an user to an account. This // UserAccount defines the one to many relationship of an user to an account. This
// will enable a single user access to multiple accounts without having duplicate // will enable a single user access to multiple accounts without having duplicate
// users. Each association of a user to an account has a set of roles and a status // users. Each association of a user to an account has a set of roles and a status

View File

@ -6,13 +6,12 @@ import (
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext" "geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"github.com/huandu/go-sqlbuilder" "github.com/huandu/go-sqlbuilder"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors" "github.com/pkg/errors"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
) )
// UserFindByAccount lists all the users for a given account ID. // UserFindByAccount lists all the users for a given account ID.
func UserFindByAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindByAccountRequest) (Users, error) { func (repo *Repository) UserFindByAccount(ctx context.Context, claims auth.Claims, req UserFindByAccountRequest) (Users, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.UserFindByAccount") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.UserFindByAccount")
defer span.Finish() defer span.Finish()
@ -113,12 +112,12 @@ func UserFindByAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB,
} }
queryStr, moreQueryArgs := query.Build() queryStr, moreQueryArgs := query.Build()
queryStr = dbConn.Rebind(queryStr) queryStr = repo.DbConn.Rebind(queryStr)
queryArgs = append(queryArgs, moreQueryArgs...) queryArgs = append(queryArgs, moreQueryArgs...)
// fetch all places from the db // fetch all places from the db
rows, err := dbConn.QueryContext(ctx, queryStr, queryArgs...) rows, err := repo.DbConn.QueryContext(ctx, queryStr, queryArgs...)
if err != nil { if err != nil {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "find users failed") err = errors.WithMessage(err, "find users failed")

View File

@ -49,14 +49,14 @@ func mapRowsToUserAccount(rows *sql.Rows) (*UserAccount, error) {
} }
// CanReadAccount determines if claims has the authority to access the specified user account by user ID. // CanReadAccount determines if claims has the authority to access the specified user account by user ID.
func CanReadAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, accountID string) error { func (repo *Repository) CanReadAccount(ctx context.Context, claims auth.Claims, accountID string) error {
err := account.CanReadAccount(ctx, claims, dbConn, accountID) err := account.CanReadAccount(ctx, claims, accountID)
return mapAccountError(err) return mapAccountError(err)
} }
// CanModifyAccount determines if claims has the authority to modify the specified user ID. // CanModifyAccount determines if claims has the authority to modify the specified user ID.
func CanModifyAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, accountID string) error { func (repo *Repository) CanModifyAccount(ctx context.Context, claims auth.Claims, accountID string) error {
err := account.CanModifyAccount(ctx, claims, dbConn, accountID) err := account.CanModifyAccount(ctx, claims, accountID)
return mapAccountError(err) return mapAccountError(err)
} }

View File

@ -17,7 +17,10 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
var test *tests.Test var (
test *tests.Test
repo *Repository
)
// TestMain is the entry point for testing. // TestMain is the entry point for testing.
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
@ -27,6 +30,9 @@ func TestMain(m *testing.M) {
func testMain(m *testing.M) int { func testMain(m *testing.M) int {
test = tests.New() test = tests.New()
defer test.TearDown() defer test.TearDown()
repo = NewRepository(test.MasterDB)
return m.Run() return m.Run()
} }