1
0
mirror of https://github.com/raseels-repos/golang-saas-starter-kit.git synced 2025-08-08 22:36:41 +02:00

WIP: not sure how to solve user_account calling account.CanModifyAccount

This commit is contained in:
Lee Brown
2019-08-13 23:41:06 -08:00
parent 4be0454421
commit 3bc814a01e
8 changed files with 198 additions and 132 deletions

View File

@ -2,6 +2,7 @@ package handlers
import (
"context"
"geeks-accelerator/oss/saas-starter-kit/internal/user"
"log"
"net/http"
"os"
@ -20,33 +21,52 @@ import (
"gopkg.in/DataDog/dd-trace-go.v1/contrib/go-redis/redis"
)
type AppContext struct {
Log *log.Logger
Env webcontext.Env
Repo *user.Repository
MasterDB *sqlx.DB
Redis *redis.Client
Authenticator *auth.Authenticator
PreAppMiddleware []web.Middleware
PostAppMiddleware []web.Middleware
}
// API returns a handler for a set of routes.
func API(shutdown chan os.Signal, log *log.Logger, env webcontext.Env, masterDB *sqlx.DB, redis *redis.Client, authenticator *auth.Authenticator, globalMids ...web.Middleware) http.Handler {
func API(shutdown chan os.Signal, appContext *AppContext ) http.Handler {
// Define base middlewares applied to all requests.
middlewares := []web.Middleware{
mid.Trace(), mid.Logger(log), mid.Errors(log, nil), mid.Metrics(), mid.Panics(),
}
// Include the pre middlewares first.
middlewares := appContext.PreAppMiddleware
// Append any global middlewares if they were included.
if len(globalMids) > 0 {
middlewares = append(middlewares, globalMids...)
// Define app middlewares applied to all requests.
middlewares = append(middlewares,
mid.Trace(),
mid.Logger(appContext.Log),
mid.Errors(appContext.Log, nil),
mid.Metrics(),
mid.Panics())
// Append any global middlewares that should be included after the app middlewares.
if len(appContext.PostAppMiddleware) > 0 {
middlewares = append(middlewares, appContext.PostAppMiddleware...)
}
// Construct the web.App which holds all routes as well as common Middleware.
app := web.NewApp(shutdown, log, env, middlewares...)
app := web.NewApp(shutdown, appContext.Log, appContext.Env, middlewares...)
// Register health check endpoint. This route is not authenticated.
check := Check{
MasterDB: masterDB,
Redis: redis,
MasterDB: appContext.MasterDB,
Redis: appContext.Redis,
}
app.Handle("GET", "/v1/health", check.Health)
app.Handle("GET", "/ping", check.Ping)
// Register user management and authentication endpoints.
u := User{
MasterDB: masterDB,
MasterDB: appContext.MasterDB,
TokenGenerator: authenticator,
}
app.Handle("GET", "/v1/users", u.Find, mid.AuthenticateHeader(authenticator))

View File

@ -4,8 +4,10 @@ import (
"context"
"database/sql"
"encoding/json"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/notify"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
"github.com/sudo-suhas/symcrypto"
"strconv"
@ -15,6 +17,24 @@ import (
"github.com/lib/pq"
)
// Repository defines the required dependencies for User.
type Repository struct {
DbConn *sqlx.DB
ResetUrl func(string) string
Notify notify.Email
SecretKey string
}
// NewRepository creates a new Repository that defines dependencies for User.
func NewRepository(db *sqlx.DB, resetUrl func(string) string, notify notify.Email, secretKey string) *Repository {
return &Repository{
DbConn: db,
ResetUrl: resetUrl,
Notify: notify,
SecretKey: secretKey,
}
}
// User represents someone with access to our system.
type User struct {
ID string `json:"id" validate:"required,uuid" example:"d69bdef7-173f-4d29-b52c-3edc60baf6a2"`

View File

@ -6,7 +6,6 @@ import (
"time"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/notify"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"github.com/huandu/go-sqlbuilder"
"github.com/jmoiron/sqlx"
@ -55,7 +54,7 @@ func mapRowsToUser(rows *sql.Rows) (*User, error) {
}
// CanReadUser determines if claims has the authority to access the specified user ID.
func CanReadUser(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string) error {
func (repo *Repository) CanReadUser(ctx context.Context, claims auth.Claims, userID string) error {
// If the request has claims from a specific user, ensure that the user
// has the correct access to the user.
if claims.Subject != "" && claims.Subject != userID {
@ -68,10 +67,10 @@ func CanReadUser(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userI
query.Equal("user_id", userID),
))
queryStr, args := query.Build()
queryStr = dbConn.Rebind(queryStr)
queryStr = repo.DbConn.Rebind(queryStr)
var userAccountId string
err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&userAccountId)
err := repo.DbConn.QueryRowContext(ctx, queryStr, args...).Scan(&userAccountId)
if err != nil && err != sql.ErrNoRows {
err = errors.Wrapf(err, "query - %s", query.String())
return err
@ -88,7 +87,7 @@ func CanReadUser(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userI
}
// CanModifyUser determines if claims has the authority to modify the specified user ID.
func CanModifyUser(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string) error {
func (repo *Repository) CanModifyUser(ctx context.Context, claims auth.Claims, userID string) error {
// If the request has claims from a specific user, ensure that the user
// has the correct role for creating a new user.
if claims.Subject != "" && claims.Subject != userID {
@ -99,7 +98,7 @@ func CanModifyUser(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, use
}
}
if err := CanReadUser(ctx, claims, dbConn, userID); err != nil {
if err := repo.CanReadUser(ctx, claims, userID); err != nil {
return err
}
@ -118,10 +117,10 @@ func CanModifyUser(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, use
"'"+auth.RoleAdmin+"' = ANY (roles)",
))
queryStr, args := query.Build()
queryStr = dbConn.Rebind(queryStr)
queryStr = repo.DbConn.Rebind(queryStr)
var userAccountId string
err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&userAccountId)
err := repo.DbConn.QueryRowContext(ctx, queryStr, args...).Scan(&userAccountId)
if err != nil && err != sql.ErrNoRows {
err = errors.Wrapf(err, "query - %s", query.String())
return err
@ -199,13 +198,13 @@ func findRequestQuery(req UserFindRequest) (*sqlbuilder.SelectBuilder, []interfa
}
// Find gets all the users from the database based on the request params.
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindRequest) (Users, error) {
func (repo *Repository) Find(ctx context.Context, claims auth.Claims, req UserFindRequest) (Users, error) {
query, args := findRequestQuery(req)
return find(ctx, claims, dbConn, query, args, req.IncludeArchived)
return repo.find(ctx, claims, query, args, req.IncludeArchived)
}
// find internal method for getting all the users from the database using a select query.
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (Users, error) {
func (repo *Repository) find(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (Users, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Find")
defer span.Finish()
@ -222,11 +221,11 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu
return nil, err
}
queryStr, queryArgs := query.Build()
queryStr = dbConn.Rebind(queryStr)
queryStr = repo.DbConn.Rebind(queryStr)
args = append(args, queryArgs...)
// fetch all places from the db
rows, err := dbConn.QueryContext(ctx, queryStr, args...)
rows, err := repo.DbConn.QueryContext(ctx, queryStr, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "find users failed")
@ -248,17 +247,17 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu
}
// Validation an email address is unique excluding the current user ID.
func UniqueEmail(ctx context.Context, dbConn *sqlx.DB, email, userId string) (bool, error) {
func (repo *Repository) UniqueEmail(ctx context.Context, email, userId string) (bool, error) {
query := sqlbuilder.NewSelectBuilder().Select("id").From(userTableName)
query.Where(query.And(
query.Equal("email", email),
query.NotEqual("id", userId),
))
queryStr, args := query.Build()
queryStr = dbConn.Rebind(queryStr)
queryStr = repo.DbConn.Rebind(queryStr)
var existingId string
err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&existingId)
err := repo.DbConn.QueryRowContext(ctx, queryStr, args...).Scan(&existingId)
if err != nil && err != sql.ErrNoRows {
err = errors.Wrapf(err, "query - %s", query.String())
return false, err
@ -273,7 +272,7 @@ func UniqueEmail(ctx context.Context, dbConn *sqlx.DB, email, userId string) (bo
}
// Create inserts a new user into the database.
func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserCreateRequest, now time.Time) (*User, error) {
func (repo *Repository) Create(ctx context.Context, claims auth.Claims, req UserCreateRequest, now time.Time) (*User, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Create")
defer span.Finish()
@ -284,7 +283,7 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserCr
v := webcontext.Validator()
// Validation email address is unique in the database.
uniq, err := UniqueEmail(ctx, dbConn, req.Email, "")
uniq, err := repo.UniqueEmail(ctx, req.Email, "")
if err != nil {
return nil, err
}
@ -346,8 +345,8 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserCr
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
sql = repo.DbConn.Rebind(sql)
_, err = repo.DbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "create user failed")
@ -358,14 +357,14 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserCr
}
// Create invite inserts a new user into the database.
func CreateInvite(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserCreateInviteRequest, now time.Time) (*User, error) {
func (repo *Repository) CreateInvite(ctx context.Context, claims auth.Claims, req UserCreateInviteRequest, now time.Time) (*User, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.CreateInvite")
defer span.Finish()
v := webcontext.Validator()
// Validation email address is unique in the database.
uniq, err := UniqueEmail(ctx, dbConn, req.Email, "")
uniq, err := repo.UniqueEmail(ctx, req.Email, "")
if err != nil {
return nil, err
}
@ -414,8 +413,8 @@ func CreateInvite(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
sql = repo.DbConn.Rebind(sql)
_, err = repo.DbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "create user failed")
@ -426,15 +425,15 @@ func CreateInvite(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req
}
// ReadByID gets the specified user by ID from the database.
func ReadByID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string) (*User, error) {
return Read(ctx, claims, dbConn, UserReadRequest{
func (repo *Repository) ReadByID(ctx context.Context, claims auth.Claims, id string) (*User, error) {
return repo.Read(ctx, claims, UserReadRequest{
ID: id,
IncludeArchived: false,
})
}
// Read gets the specified user from the database.
func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserReadRequest) (*User, error) {
func (repo *Repository) Read(ctx context.Context, claims auth.Claims, req UserReadRequest) (*User, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Read")
defer span.Finish()
@ -449,7 +448,7 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserRead
query := selectQuery()
query.Where(query.Equal("id", req.ID))
res, err := find(ctx, claims, dbConn, query, []interface{}{}, req.IncludeArchived)
res, err := repo.find(ctx, claims, query, []interface{}{}, req.IncludeArchived)
if err != nil {
return nil, err
} else if res == nil || len(res) == 0 {
@ -462,7 +461,7 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserRead
}
// ReadByEmail gets the specified user from the database.
func ReadByEmail(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, email string, includedArchived bool) (*User, error) {
func (repo *Repository) ReadByEmail(ctx context.Context, claims auth.Claims, email string, includedArchived bool) (*User, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.ReadByEmail")
defer span.Finish()
@ -470,7 +469,7 @@ func ReadByEmail(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, email
query := selectQuery()
query.Where(query.Equal("email", email))
res, err := find(ctx, claims, dbConn, query, []interface{}{}, includedArchived)
res, err := repo.find(ctx, claims, query, []interface{}{}, includedArchived)
if err != nil {
return nil, err
} else if res == nil || len(res) == 0 {
@ -483,14 +482,14 @@ func ReadByEmail(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, email
}
// Update replaces a user in the database.
func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserUpdateRequest, now time.Time) error {
func (repo *Repository) Update(ctx context.Context, claims auth.Claims, req UserUpdateRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Update")
defer span.Finish()
// Validation email address is unique in the database.
if req.Email != nil {
// Validation email address is unique in the database.
uniq, err := UniqueEmail(ctx, dbConn, *req.Email, req.ID)
uniq, err := repo.UniqueEmail(ctx, *req.Email, req.ID)
if err != nil {
return err
}
@ -507,7 +506,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserUp
}
// Ensure the claims can modify the user specified in the request.
err = CanModifyUser(ctx, claims, dbConn, req.ID)
err = repo.CanModifyUser(ctx, claims, req.ID)
if err != nil {
return err
}
@ -555,8 +554,8 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserUp
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
sql = repo.DbConn.Rebind(sql)
_, err = repo.DbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "update user %s failed", req.ID)
@ -567,7 +566,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserUp
}
// Update changes the password for a user in the database.
func UpdatePassword(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserUpdatePasswordRequest, now time.Time) error {
func (repo *Repository) UpdatePassword(ctx context.Context, claims auth.Claims, req UserUpdatePasswordRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.UpdatePassword")
defer span.Finish()
@ -579,7 +578,7 @@ func UpdatePassword(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, re
}
// Ensure the claims can modify the user specified in the request.
err = CanModifyUser(ctx, claims, dbConn, req.ID)
err = repo.CanModifyUser(ctx, claims, req.ID)
if err != nil {
return err
}
@ -616,8 +615,8 @@ func UpdatePassword(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, re
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
sql = repo.DbConn.Rebind(sql)
_, err = repo.DbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "update password for user %s failed", req.ID)
@ -628,7 +627,7 @@ func UpdatePassword(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, re
}
// Archive soft deleted the user from the database.
func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserArchiveRequest, now time.Time) error {
func (repo *Repository) Archive(ctx context.Context, claims auth.Claims, req UserArchiveRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Archive")
defer span.Finish()
@ -640,7 +639,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserA
}
// Ensure the claims can modify the user specified in the request.
err = CanModifyUser(ctx, claims, dbConn, req.ID)
err = repo.CanModifyUser(ctx, claims, req.ID)
if err != nil {
return err
} else if claims.Subject != "" && claims.Subject == req.ID && !req.force {
@ -669,8 +668,8 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserA
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
sql = repo.DbConn.Rebind(sql)
_, err = repo.DbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "archive user %s failed", req.ID)
@ -689,8 +688,8 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserA
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
sql = repo.DbConn.Rebind(sql)
_, err = repo.DbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "archive accounts for user %s failed", req.ID)
@ -702,7 +701,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserA
}
// Restore undeletes the user from the database.
func Restore(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserRestoreRequest, now time.Time) error {
func (repo *Repository) Restore(ctx context.Context, claims auth.Claims, req UserRestoreRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Restore")
defer span.Finish()
@ -714,7 +713,7 @@ func Restore(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserR
}
// Ensure the claims can modify the user specified in the request.
err = CanModifyUser(ctx, claims, dbConn, req.ID)
err = repo.CanModifyUser(ctx, claims, req.ID)
if err != nil {
return err
}
@ -741,8 +740,8 @@ func Restore(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserR
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
sql = repo.DbConn.Rebind(sql)
_, err = repo.DbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "unarchive user %s failed", req.ID)
@ -753,7 +752,7 @@ func Restore(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserR
}
// Delete removes a user from the database.
func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserDeleteRequest) error {
func (repo *Repository) Delete(ctx context.Context, claims auth.Claims, req UserDeleteRequest) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Delete")
defer span.Finish()
@ -765,7 +764,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserDe
}
// Ensure the claims can modify the user specified in the request.
err = CanModifyUser(ctx, claims, dbConn, req.ID)
err = repo.CanModifyUser(ctx, claims, req.ID)
if err != nil {
return err
} else if claims.Subject != "" && claims.Subject == req.ID && !req.force {
@ -773,7 +772,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserDe
}
// Start a new transaction to handle rollbacks on error.
tx, err := dbConn.Begin()
tx, err := repo.DbConn.Begin()
if err != nil {
return errors.WithStack(err)
}
@ -790,7 +789,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserDe
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
sql = repo.DbConn.Rebind(sql)
_, err = tx.ExecContext(ctx, sql, args...)
if err != nil {
tx.Rollback()
@ -808,7 +807,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserDe
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
sql = repo.DbConn.Rebind(sql)
_, err = tx.ExecContext(ctx, sql, args...)
if err != nil {
tx.Rollback()
@ -827,7 +826,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserDe
}
// ResetPassword sends en email to the user to allow them to reset their password.
func ResetPassword(ctx context.Context, dbConn *sqlx.DB, resetUrl func(string) string, notify notify.Email, req UserResetPasswordRequest, secretKey string, now time.Time) (string, error) {
func (repo *Repository) ResetPassword(ctx context.Context, req UserResetPasswordRequest, now time.Time) (string, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.ResetPassword")
defer span.Finish()
@ -845,7 +844,7 @@ func ResetPassword(ctx context.Context, dbConn *sqlx.DB, resetUrl func(string) s
query := selectQuery()
query.Where(query.Equal("email", req.Email))
res, err := find(ctx, auth.Claims{}, dbConn, query, []interface{}{}, false)
res, err := repo.find(ctx, auth.Claims{}, query, []interface{}{}, false)
if err != nil {
return "", err
} else if res == nil || len(res) == 0 {
@ -876,8 +875,8 @@ func ResetPassword(ctx context.Context, dbConn *sqlx.DB, resetUrl func(string) s
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
sql = repo.DbConn.Rebind(sql)
_, err = repo.DbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "Update user %s failed.", u.ID)
@ -895,18 +894,18 @@ func ResetPassword(ctx context.Context, dbConn *sqlx.DB, resetUrl func(string) s
requestIp = vals.RequestIP
}
encrypted, err := NewResetHash(ctx, secretKey, resetId, requestIp, req.TTL, now)
encrypted, err := NewResetHash(ctx, repo.SecretKey, resetId, requestIp, req.TTL, now)
if err != nil {
return "", err
}
data := map[string]interface{}{
"Name": u.FirstName,
"Url": resetUrl(encrypted),
"Url": repo.ResetUrl(encrypted),
"Minutes": req.TTL.Minutes(),
}
err = notify.Send(ctx, u.Email, "Reset your Password", "user_reset_password", data)
err = repo.Notify.Send(ctx, u.Email, "Reset your Password", "user_reset_password", data)
if err != nil {
err = errors.WithMessagef(err, "Send password reset email to %s failed.", u.Email)
return "", err
@ -916,7 +915,7 @@ func ResetPassword(ctx context.Context, dbConn *sqlx.DB, resetUrl func(string) s
}
// ResetConfirm updates the password for a user using the provided reset password ID.
func ResetConfirm(ctx context.Context, dbConn *sqlx.DB, req UserResetConfirmRequest, secretKey string, now time.Time) (*User, error) {
func (repo *Repository) ResetConfirm(ctx context.Context, req UserResetConfirmRequest, now time.Time) (*User, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.ResetConfirm")
defer span.Finish()
@ -928,7 +927,7 @@ func ResetConfirm(ctx context.Context, dbConn *sqlx.DB, req UserResetConfirmRequ
return nil, err
}
hash, err := ParseResetHash(ctx, secretKey, req.ResetHash, now)
hash, err := ParseResetHash(ctx, repo.SecretKey, req.ResetHash, now)
if err != nil {
return nil, err
}
@ -939,7 +938,7 @@ func ResetConfirm(ctx context.Context, dbConn *sqlx.DB, req UserResetConfirmRequ
query := selectQuery()
query.Where(query.Equal("password_reset", hash.ResetID))
res, err := find(ctx, auth.Claims{}, dbConn, query, []interface{}{}, false)
res, err := repo.find(ctx, auth.Claims{}, query, []interface{}{}, false)
if err != nil {
return nil, err
} else if res == nil || len(res) == 0 {
@ -979,8 +978,8 @@ func ResetConfirm(ctx context.Context, dbConn *sqlx.DB, req UserResetConfirmRequ
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
sql = repo.DbConn.Rebind(sql)
_, err = repo.DbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "update password for user %s failed", u.ID)
@ -1000,6 +999,10 @@ type MockUserResponse struct {
func MockUser(ctx context.Context, dbConn *sqlx.DB, now time.Time) (*MockUserResponse, error) {
pass := uuid.NewRandom().String()
repo := &Repository{
DbConn: dbConn,
}
req := UserCreateRequest{
FirstName: "Lee",
LastName: "Brown",
@ -1007,7 +1010,7 @@ func MockUser(ctx context.Context, dbConn *sqlx.DB, now time.Time) (*MockUserRes
Password: pass,
PasswordConfirm: pass,
}
u, err := Create(ctx, auth.Claims{}, dbConn, req, now)
u, err := repo.Create(ctx, auth.Claims{}, req, now)
if err != nil {
return nil, err
}

View File

@ -18,7 +18,10 @@ import (
"github.com/pkg/errors"
)
var test *tests.Test
var (
test *tests.Test
repo *Repository
)
// TestMain is the entry point for testing.
func TestMain(m *testing.M) {
@ -28,6 +31,16 @@ func TestMain(m *testing.M) {
func testMain(m *testing.M) int {
test = tests.New()
defer test.TearDown()
// Mock the methods needed to make a password reset.
resetUrl := func(string) string {
return ""
}
notify := &notify.MockEmail{}
secretKey := "6368616e676520746869732070617373"
repo = NewRepository(test.MasterDB, resetUrl, notify, secretKey)
return m.Run()
}
@ -219,7 +232,7 @@ func TestCreateValidation(t *testing.T) {
{
ctx := tests.Context()
res, err := Create(ctx, auth.Claims{}, test.MasterDB, tt.req, now)
res, err := repo.Create(ctx, auth.Claims{}, tt.req, now)
if err != tt.error {
// TODO: need a better way to handle validation errors as they are
// of type interface validator.ValidationErrorsTranslations
@ -272,7 +285,7 @@ func TestCreateValidationEmailUnique(t *testing.T) {
Password: "akTechFr0n!ier",
PasswordConfirm: "akTechFr0n!ier",
}
user1, err := Create(ctx, auth.Claims{}, test.MasterDB, req1, now)
user1, err := repo.Create(ctx, auth.Claims{}, req1, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate failed.", tests.Failed)
@ -286,7 +299,7 @@ func TestCreateValidationEmailUnique(t *testing.T) {
PasswordConfirm: "W0rkL1fe#",
}
expectedErr := errors.New("Key: 'UserCreateRequest.email' Error:Field validation for 'email' failed on the 'unique' tag")
_, err = Create(ctx, auth.Claims{}, test.MasterDB, req2, now)
_, err = repo.Create(ctx, auth.Claims{}, req2, now)
if err == nil {
t.Logf("\t\tWant: %+v", expectedErr)
t.Fatalf("\t%s\tCreate failed.", tests.Failed)
@ -374,7 +387,7 @@ func TestCreateClaims(t *testing.T) {
{
ctx := tests.Context()
_, err := Create(ctx, tt.claims, test.MasterDB, tt.req, now)
_, err := repo.Create(ctx, tt.claims, tt.req, now)
if errors.Cause(err) != tt.error {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.error)
@ -421,7 +434,7 @@ func TestUpdateValidation(t *testing.T) {
{
ctx := tests.Context()
err := Update(ctx, auth.Claims{}, test.MasterDB, tt.req, now)
err := repo.Update(ctx, auth.Claims{}, tt.req, now)
if err != tt.error {
// TODO: need a better way to handle validation errors as they are
// of type interface validator.ValidationErrorsTranslations
@ -463,7 +476,7 @@ func TestUpdateValidationEmailUnique(t *testing.T) {
Password: "akTechFr0n!ier",
PasswordConfirm: "akTechFr0n!ier",
}
user1, err := Create(ctx, auth.Claims{}, test.MasterDB, req1, now)
user1, err := repo.Create(ctx, auth.Claims{}, req1, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate failed.", tests.Failed)
@ -476,7 +489,7 @@ func TestUpdateValidationEmailUnique(t *testing.T) {
Password: "W0rkL1fe#",
PasswordConfirm: "W0rkL1fe#",
}
user2, err := Create(ctx, auth.Claims{}, test.MasterDB, req2, now)
user2, err := repo.Create(ctx, auth.Claims{}, req2, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate failed.", tests.Failed)
@ -488,7 +501,7 @@ func TestUpdateValidationEmailUnique(t *testing.T) {
Email: &user1.Email,
}
expectedErr := errors.New("Key: 'UserUpdateRequest.email' Error:Field validation for 'email' failed on the 'unique' tag")
err = Update(ctx, auth.Claims{}, test.MasterDB, updateReq, now)
err = repo.Update(ctx, auth.Claims{}, updateReq, now)
if err == nil {
t.Logf("\t\tWant: %+v", expectedErr)
t.Fatalf("\t%s\tUpdate failed.", tests.Failed)
@ -518,7 +531,7 @@ func TestUpdatePassword(t *testing.T) {
// Create a new user for testing.
initPass := uuid.NewRandom().String()
user, err := Create(ctx, auth.Claims{}, test.MasterDB, UserCreateRequest{
user, err := repo.Create(ctx, auth.Claims{}, UserCreateRequest{
FirstName: "Lee",
LastName: "Brown",
Email: uuid.NewRandom().String() + "@geeksinthewoods.com",
@ -549,7 +562,7 @@ func TestUpdatePassword(t *testing.T) {
expectedErr := errors.New("Key: 'UserUpdatePasswordRequest.id' Error:Field validation for 'id' failed on the 'required' tag\n" +
"Key: 'UserUpdatePasswordRequest.password' Error:Field validation for 'password' failed on the 'required' tag\n" +
"Key: 'UserUpdatePasswordRequest.password_confirm' Error:Field validation for 'password_confirm' failed on the 'required' tag")
err = UpdatePassword(ctx, auth.Claims{}, test.MasterDB, UserUpdatePasswordRequest{}, now)
err = repo.UpdatePassword(ctx, auth.Claims{}, UserUpdatePasswordRequest{}, now)
if err == nil {
t.Logf("\t\tWant: %+v", expectedErr)
t.Fatalf("\t%s\tUpdate failed.", tests.Failed)
@ -567,7 +580,7 @@ func TestUpdatePassword(t *testing.T) {
// Update the users password.
newPass := uuid.NewRandom().String()
err = UpdatePassword(ctx, auth.Claims{}, test.MasterDB, UserUpdatePasswordRequest{
err = repo.UpdatePassword(ctx, auth.Claims{}, UserUpdatePasswordRequest{
ID: user.ID,
Password: newPass,
PasswordConfirm: newPass,
@ -800,7 +813,7 @@ func TestCrud(t *testing.T) {
// Always create the new user with empty claims, testing claims for create user
// will be handled separately.
user, err := Create(tests.Context(), auth.Claims{}, test.MasterDB, tt.create, now)
user, err := repo.Create(tests.Context(), auth.Claims{}, tt.create, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate user failed.", tests.Failed)
@ -823,7 +836,7 @@ func TestCrud(t *testing.T) {
// Update the user.
updateReq := tt.update(user)
err = Update(ctx, tt.claims(user, accountId), test.MasterDB, updateReq, now)
err = repo.Update(ctx, tt.claims(user, accountId), updateReq, now)
if err != nil && errors.Cause(err) != tt.updateErr {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.updateErr)
@ -832,7 +845,7 @@ func TestCrud(t *testing.T) {
t.Logf("\t%s\tUpdate ok.", tests.Success)
// Find the user and make sure the updates where made.
findRes, err := ReadByID(ctx, tt.claims(user, accountId), test.MasterDB, user.ID)
findRes, err := repo.ReadByID(ctx, tt.claims(user, accountId), user.ID)
if err != nil && errors.Cause(err) != tt.findErr {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.findErr)
@ -846,14 +859,14 @@ func TestCrud(t *testing.T) {
}
// Archive (soft-delete) the user.
err = Archive(ctx, tt.claims(user, accountId), test.MasterDB, UserArchiveRequest{ID: user.ID, force: true}, now)
err = repo.Archive(ctx, tt.claims(user, accountId), UserArchiveRequest{ID: user.ID, force: true}, now)
if err != nil && errors.Cause(err) != tt.updateErr {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.updateErr)
t.Fatalf("\t%s\tArchive failed.", tests.Failed)
} else if tt.updateErr == nil {
// Trying to find the archived user with the includeArchived false should result in not found.
_, err = ReadByID(ctx, tt.claims(user, accountId), test.MasterDB, user.ID)
_, err = repo.ReadByID(ctx, tt.claims(user, accountId), user.ID)
if err != nil && errors.Cause(err) != ErrNotFound {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", ErrNotFound)
@ -861,7 +874,7 @@ func TestCrud(t *testing.T) {
}
// Trying to find the archived user with the includeArchived true should result no error.
_, err = Read(ctx, tt.claims(user, accountId), test.MasterDB,
_, err = repo.Read(ctx, tt.claims(user, accountId),
UserReadRequest{ID: user.ID, IncludeArchived: true})
if err != nil {
t.Log("\t\tGot :", err)
@ -871,14 +884,14 @@ func TestCrud(t *testing.T) {
t.Logf("\t%s\tArchive ok.", tests.Success)
// Restore (un-delete) the user.
err = Restore(ctx, tt.claims(user, accountId), test.MasterDB, UserRestoreRequest{ID: user.ID}, now)
err = repo.Restore(ctx, tt.claims(user, accountId), UserRestoreRequest{ID: user.ID}, now)
if err != nil && errors.Cause(err) != tt.updateErr {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.updateErr)
t.Fatalf("\t%s\tUnarchive failed.", tests.Failed)
} else if tt.updateErr == nil {
// Trying to find the archived user with the includeArchived false should result no error.
_, err = ReadByID(ctx, tt.claims(user, accountId), test.MasterDB, user.ID)
_, err = repo.ReadByID(ctx, tt.claims(user, accountId), user.ID)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tUnarchive Read failed.", tests.Failed)
@ -887,14 +900,14 @@ func TestCrud(t *testing.T) {
t.Logf("\t%s\tUnarchive ok.", tests.Success)
// Delete (hard-delete) the user.
err = Delete(ctx, tt.claims(user, accountId), test.MasterDB, UserDeleteRequest{ID: user.ID, force: true})
err = repo.Delete(ctx, tt.claims(user, accountId), UserDeleteRequest{ID: user.ID, force: true})
if err != nil && errors.Cause(err) != tt.updateErr {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.updateErr)
t.Fatalf("\t%s\tUpdate failed.", tests.Failed)
} else if tt.updateErr == nil {
// Trying to find the deleted user with the includeArchived true should result in not found.
_, err = ReadByID(ctx, tt.claims(user, accountId), test.MasterDB, user.ID)
_, err = repo.ReadByID(ctx, tt.claims(user, accountId), user.ID)
if errors.Cause(err) != ErrNotFound {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", ErrNotFound)
@ -917,7 +930,7 @@ func TestFind(t *testing.T) {
var users []*User
for i := 0; i <= 4; i++ {
user, err := Create(tests.Context(), auth.Claims{}, test.MasterDB, UserCreateRequest{
user, err := repo.Create(tests.Context(), auth.Claims{}, UserCreateRequest{
FirstName: "Lee",
LastName: "Brown",
Email: uuid.NewRandom().String() + "@geeksinthewoods.com",
@ -1029,7 +1042,7 @@ func TestFind(t *testing.T) {
{
ctx := tests.Context()
res, err := Find(ctx, auth.Claims{}, test.MasterDB, tt.req)
res, err := repo.Find(ctx, auth.Claims{}, tt.req)
if errors.Cause(err) != tt.error {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.error)
@ -1064,7 +1077,7 @@ func TestResetPassword(t *testing.T) {
// Create a new user for testing.
initPass := uuid.NewRandom().String()
user, err := Create(ctx, auth.Claims{}, test.MasterDB, UserCreateRequest{
user, err := repo.Create(ctx, auth.Claims{}, UserCreateRequest{
FirstName: "Lee",
LastName: "Brown",
Email: uuid.NewRandom().String() + "@geeksinthewoods.com",
@ -1091,18 +1104,10 @@ func TestResetPassword(t *testing.T) {
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
}
// Mock the methods needed to make a password reset.
resetUrl := func(string) string {
return ""
}
notify := &notify.MockEmail{}
secretKey := "6368616e676520746869732070617373"
// Ensure validation is working by trying ResetPassword with an empty request.
{
expectedErr := errors.New("Key: 'UserResetPasswordRequest.email' Error:Field validation for 'email' failed on the 'required' tag")
_, err = ResetPassword(ctx, test.MasterDB, resetUrl, notify, UserResetPasswordRequest{}, secretKey, now)
_, err = repo.ResetPassword(ctx, UserResetPasswordRequest{}, now)
if err == nil {
t.Logf("\t\tWant: %+v", expectedErr)
t.Fatalf("\t%s\tResetPassword failed.", tests.Failed)
@ -1122,10 +1127,10 @@ func TestResetPassword(t *testing.T) {
ttl := time.Hour
// Make the reset password request.
resetHash, err := ResetPassword(ctx, test.MasterDB, resetUrl, notify, UserResetPasswordRequest{
resetHash, err := repo.ResetPassword(ctx, UserResetPasswordRequest{
Email: user.Email,
TTL: ttl,
}, secretKey, now)
}, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tResetPassword failed.", tests.Failed)
@ -1133,7 +1138,7 @@ func TestResetPassword(t *testing.T) {
t.Logf("\t%s\tResetPassword ok.", tests.Success)
// Read the user to ensure the password_reset field was set.
user, err = ReadByID(ctx, auth.Claims{}, test.MasterDB, user.ID)
user, err = repo.ReadByID(ctx, auth.Claims{}, user.ID)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tRead failed.", tests.Failed)
@ -1146,7 +1151,7 @@ func TestResetPassword(t *testing.T) {
expectedErr := errors.New("Key: 'UserResetConfirmRequest.reset_hash' Error:Field validation for 'reset_hash' failed on the 'required' tag\n" +
"Key: 'UserResetConfirmRequest.password' Error:Field validation for 'password' failed on the 'required' tag\n" +
"Key: 'UserResetConfirmRequest.password_confirm' Error:Field validation for 'password_confirm' failed on the 'required' tag")
_, err = ResetConfirm(ctx, test.MasterDB, UserResetConfirmRequest{}, secretKey, now)
_, err = repo.ResetConfirm(ctx, UserResetConfirmRequest{}, now)
if err == nil {
t.Logf("\t\tWant: %+v", expectedErr)
t.Fatalf("\t%s\tResetConfirm failed.", tests.Failed)
@ -1166,11 +1171,11 @@ func TestResetPassword(t *testing.T) {
// Ensure the TTL is enforced.
{
newPass := uuid.NewRandom().String()
_, err = ResetConfirm(ctx, test.MasterDB, UserResetConfirmRequest{
_, err = repo.ResetConfirm(ctx, UserResetConfirmRequest{
ResetHash: resetHash,
Password: newPass,
PasswordConfirm: newPass,
}, secretKey, now.UTC().Add(ttl*2))
}, now.UTC().Add(ttl*2))
if errors.Cause(err) != ErrResetExpired {
t.Logf("\t\tGot : %+v", errors.Cause(err))
t.Logf("\t\tWant: %+v", ErrResetExpired)
@ -1181,11 +1186,11 @@ func TestResetPassword(t *testing.T) {
// Assuming we have received the email and clicked the link, we now can ensure confirm works.
newPass := uuid.NewRandom().String()
reset, err := ResetConfirm(ctx, test.MasterDB, UserResetConfirmRequest{
reset, err := repo.ResetConfirm(ctx, UserResetConfirmRequest{
ResetHash: resetHash,
Password: newPass,
PasswordConfirm: newPass,
}, secretKey, now)
}, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tResetConfirm failed.", tests.Failed)
@ -1199,11 +1204,11 @@ func TestResetPassword(t *testing.T) {
// Ensure the reset hash does not work after its used.
{
newPass := uuid.NewRandom().String()
_, err = ResetConfirm(ctx, test.MasterDB, UserResetConfirmRequest{
_, err = repo.ResetConfirm(ctx, UserResetConfirmRequest{
ResetHash: resetHash,
Password: newPass,
PasswordConfirm: newPass,
}, secretKey, now)
}, now)
if errors.Cause(err) != ErrNotFound {
t.Logf("\t\tGot : %+v", errors.Cause(err))
t.Logf("\t\tWant: %+v", ErrNotFound)

View File

@ -3,6 +3,7 @@ package user_account
import (
"context"
"database/sql/driver"
"github.com/jmoiron/sqlx"
"strings"
"time"
@ -13,6 +14,18 @@ import (
"gopkg.in/go-playground/validator.v9"
)
// Repository defines the required dependencies for UserAccount.
type Repository struct {
DbConn *sqlx.DB
}
// NewRepository creates a new Repository that defines dependencies for UserAccount.
func NewRepository(db *sqlx.DB) *Repository {
return &Repository{
DbConn: db,
}
}
// UserAccount defines the one to many relationship of an user to an account. This
// will enable a single user access to multiple accounts without having duplicate
// users. Each association of a user to an account has a set of roles and a status

View File

@ -6,13 +6,12 @@ import (
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
"github.com/huandu/go-sqlbuilder"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
)
// UserFindByAccount lists all the users for a given account ID.
func UserFindByAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindByAccountRequest) (Users, error) {
func (repo *Repository) UserFindByAccount(ctx context.Context, claims auth.Claims, req UserFindByAccountRequest) (Users, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.UserFindByAccount")
defer span.Finish()
@ -113,12 +112,12 @@ func UserFindByAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB,
}
queryStr, moreQueryArgs := query.Build()
queryStr = dbConn.Rebind(queryStr)
queryStr = repo.DbConn.Rebind(queryStr)
queryArgs = append(queryArgs, moreQueryArgs...)
// fetch all places from the db
rows, err := dbConn.QueryContext(ctx, queryStr, queryArgs...)
rows, err := repo.DbConn.QueryContext(ctx, queryStr, queryArgs...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "find users failed")

View File

@ -49,14 +49,14 @@ func mapRowsToUserAccount(rows *sql.Rows) (*UserAccount, error) {
}
// CanReadAccount determines if claims has the authority to access the specified user account by user ID.
func CanReadAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, accountID string) error {
err := account.CanReadAccount(ctx, claims, dbConn, accountID)
func (repo *Repository) CanReadAccount(ctx context.Context, claims auth.Claims, accountID string) error {
err := account.CanReadAccount(ctx, claims, accountID)
return mapAccountError(err)
}
// CanModifyAccount determines if claims has the authority to modify the specified user ID.
func CanModifyAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, accountID string) error {
err := account.CanModifyAccount(ctx, claims, dbConn, accountID)
func (repo *Repository) CanModifyAccount(ctx context.Context, claims auth.Claims, accountID string) error {
err := account.CanModifyAccount(ctx, claims, accountID)
return mapAccountError(err)
}

View File

@ -17,7 +17,10 @@ import (
"github.com/pkg/errors"
)
var test *tests.Test
var (
test *tests.Test
repo *Repository
)
// TestMain is the entry point for testing.
func TestMain(m *testing.M) {
@ -27,6 +30,9 @@ func TestMain(m *testing.M) {
func testMain(m *testing.M) int {
test = tests.New()
defer test.TearDown()
repo = NewRepository(test.MasterDB)
return m.Run()
}