1
0
mirror of https://github.com/raseels-repos/golang-saas-starter-kit.git synced 2025-06-17 00:17:59 +02:00

Completed user and account packages, extracted user_account code to independent package

This commit is contained in:
Lee Brown
2019-06-22 17:48:44 -08:00
parent 1886bf1570
commit efaeeb7103
13 changed files with 3897 additions and 322 deletions

2
.gitignore vendored
View File

@ -1,2 +1,4 @@
.idea
go.mod
aws.lee aws.lee
aws.* aws.*

35
CONTRIBUTORS Normal file
View File

@ -0,0 +1,35 @@
# This is the official list of people who can contribute
# (and typically have contributed) code to the gotraining repository.
#
# Names should be added to this file only after verifying that
# the individual or the individual's organization has agreed to
# the appropriate Contributor License Agreement, found here:
#
# http://code.google.com/legal/individual-cla-v1.0.html
# http://code.google.com/legal/corporate-cla-v1.0.html
#
# The agreement for individuals can be filled out on the web.
# Names should be added to this file like so:
# Name <email address>
#
# An entry with two email addresses specifies that the
# first address should be used in the submit logs and
# that the second address should be recognized as the
# same person when interacting with Rietveld.
# Please keep the list sorted.
Arash Bina <arash@arash.io>
Askar Sagyndyk <superwhykz@gmail.com>
Bob Cao <3308031+bobintornado@users.noreply.github.com>
Ed Gonzo <Ed@ardanstudios.com>
Farrukh Kurbanov <farrukhkurbanov@Administrators-MacBook-Pro.local>
Jacob Walker <jacob@ardanlabs.com>
Jeremy Stone <slycrel@gmail.com>
Nick Stogner <nstogner@users.noreply.github.com>
William Kennedy <bill@ardanlabs.com>
Wyatt Johnson <wyattjoh@gmail.com>
Zachary Johnson <zachjohnsondev@gmail.com>
Lee Brown <lee@geeksinthewoods.com>
Lucas Brown <lucas@geeksinthewoods.com>

View File

@ -0,0 +1,622 @@
package account
import (
"context"
"database/sql"
"time"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
"github.com/huandu/go-sqlbuilder"
"github.com/jmoiron/sqlx"
"github.com/pborman/uuid"
"github.com/pkg/errors"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
"gopkg.in/go-playground/validator.v9"
)
const (
// The database table for Account
accountTableName = "accounts"
// The database table for User Account
userAccountTableName = "users_accounts"
)
var (
// ErrNotFound abstracts the mgo not found error.
ErrNotFound = errors.New("Entity not found")
// ErrInvalidID occurs when an ID is not in a valid form.
ErrInvalidID = errors.New("ID is not in its proper form")
// ErrForbidden occurs when a user tries to do something that is forbidden to them according to our access control policies.
ErrForbidden = errors.New("Attempted action is not allowed")
)
// accountMapColumns is the list of columns needed for mapRowsToAccount
var accountMapColumns = "id,name,address1,address2,city,region,country,zipcode,status,timezone,signup_user_id,billing_user_id,created_at,updated_at,archived_at"
// mapRowsToAccount takes the SQL rows and maps it to the Account struct
// with the columns defined by accountMapColumns
func mapRowsToAccount(rows *sql.Rows) (*Account, error) {
var (
a Account
err error
)
err = rows.Scan(&a.ID, &a.Name, &a.Address1, &a.Address2, &a.City, &a.Region, &a.Country, &a.Zipcode, &a.Status, &a.Timezone, &a.SignupUserID, &a.BillingUserID, &a.CreatedAt, &a.UpdatedAt, &a.ArchivedAt)
if err != nil {
return nil, errors.WithStack(err)
}
return &a, nil
}
// CanReadAccount determines if claims has the authority to access the specified account ID.
func CanReadAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, accountID string) error {
// If the request has claims from a specific account, ensure that the claims
// has the correct access to the account.
if claims.Audience != "" && claims.Audience != accountID {
// When the claims Audience/AccountID does not match the requested account, the
// claims Audience/AccountID - should have a record for the claims user.
// select id from users_accounts where account_id = [accountID] and user_id = [claims.Subject]
query := sqlbuilder.NewSelectBuilder().Select("id").From(userAccountTableName)
query.Where(query.And(
query.Equal("account_id", accountID),
query.Equal("user_id", claims.Subject),
))
queryStr, args := query.Build()
queryStr = dbConn.Rebind(queryStr)
var userAccountId string
err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&userAccountId)
if err != nil && err != sql.ErrNoRows {
err = errors.Wrapf(err, "query - %s", query.String())
return err
}
// When there is no userAccount ID returned, then the current claim user does not have access
// to the specified account.
if userAccountId == "" {
return errors.WithStack(ErrForbidden)
}
}
return nil
}
// CanModifyAccount determines if claims has the authority to modify the specified account ID.
func CanModifyAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, accountID string) error {
// If the request has claims from a specific account, ensure that the claims
// has the correct access to the account.
if claims.Audience != "" {
if claims.Audience == accountID {
// Admin users can update accounts they have access to.
if !claims.HasRole(auth.RoleAdmin) {
return errors.WithStack(ErrForbidden)
}
} else {
// When the claims Audience/AccountID does not match the requested account, the
// claims Audience/AccountID should have a record with an admin role.
// select id from users_accounts where account_id = [accountID] and user_id = [claims.Subject] and any (roles) = 'admin'
query := sqlbuilder.NewSelectBuilder().Select("id").From(userAccountTableName)
query.Where(query.And(
query.Equal("account_id", accountID),
query.Equal("user_id", claims.Subject),
"'"+auth.RoleAdmin+"' = ANY (roles)",
))
queryStr, args := query.Build()
queryStr = dbConn.Rebind(queryStr)
var userAccountId string
err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&userAccountId)
if err != nil && err != sql.ErrNoRows {
err = errors.Wrapf(err, "query - %s", query.String())
return err
}
// When there is no userAccount ID returned, then the current claim user does not have access
// to the specified account.
if userAccountId == "" {
return errors.WithStack(ErrForbidden)
}
}
}
return nil
}
// applyClaimsSelect applies a sub-query to the provided query to enforce ACL based on
// the claims provided.
// 1. All role types can access their user ID
// 2. Any user with the same account ID
// 3. No claims, request is internal, no ACL applied
func applyClaimsSelect(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder) error {
// Claims are empty, don't apply any ACL
if claims.Audience == "" && claims.Subject == "" {
return nil
}
// Build select statement for users_accounts table
subQuery := sqlbuilder.NewSelectBuilder().Select("account_id").From(userAccountTableName)
var or []string
if claims.Audience != "" {
or = append(or, subQuery.Equal("account_id", claims.Audience))
}
if claims.Subject != "" {
or = append(or, subQuery.Equal("user_id", claims.Subject))
}
subQuery.Where(subQuery.Or(or...))
// Append sub query
query.Where(query.In("id", subQuery))
return nil
}
// selectQuery constructs a base select query for Account
func selectQuery() *sqlbuilder.SelectBuilder {
query := sqlbuilder.NewSelectBuilder()
query.Select(accountMapColumns)
query.From(accountTableName)
return query
}
// findRequestQuery generates the select query for the given find request.
// TODO: Need to figure out why can't parse the args when appending the where
// to the query.
func findRequestQuery(req AccountFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) {
query := selectQuery()
if req.Where != nil {
query.Where(query.And(*req.Where))
}
if len(req.Order) > 0 {
query.OrderBy(req.Order...)
}
if req.Limit != nil {
query.Limit(int(*req.Limit))
}
if req.Offset != nil {
query.Offset(int(*req.Offset))
}
return query, req.Args
}
// Find gets all the accounts from the database based on the request params.
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AccountFindRequest) ([]*Account, error) {
query, args := findRequestQuery(req)
return find(ctx, claims, dbConn, query, args, req.IncludedArchived)
}
// find internal method for getting all the accounts from the database using a select query.
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*Account, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Find")
defer span.Finish()
query.Select(accountMapColumns)
query.From(accountTableName)
if !includedArchived {
query.Where(query.IsNull("archived_at"))
}
// Check to see if a sub query needs to be applied for the claims
err := applyClaimsSelect(ctx, claims, query)
if err != nil {
return nil, err
}
queryStr, queryArgs := query.Build()
queryStr = dbConn.Rebind(queryStr)
args = append(args, queryArgs...)
// fetch all places from the db
rows, err := dbConn.QueryContext(ctx, queryStr, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "find accounts failed")
return nil, err
}
// iterate over each row
resp := []*Account{}
for rows.Next() {
u, err := mapRowsToAccount(rows)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
return nil, err
}
resp = append(resp, u)
}
return resp, nil
}
// Validation an name is unique excluding the current account ID.
func uniqueName(ctx context.Context, dbConn *sqlx.DB, name, accountId string) (bool, error) {
query := sqlbuilder.NewSelectBuilder().Select("id").From(accountTableName)
query.Where(query.And(
query.Equal("name", name),
query.NotEqual("id", accountId),
))
queryStr, args := query.Build()
queryStr = dbConn.Rebind(queryStr)
var existingId string
err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&existingId)
if err != nil && err != sql.ErrNoRows {
err = errors.Wrapf(err, "query - %s", query.String())
return false, err
}
// When an ID was found in the db, the name is not unique.
if existingId != "" {
return false, nil
}
return true, nil
}
// Create inserts a new account into the database.
func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req CreateAccountRequest, now time.Time) (*Account, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Create")
defer span.Finish()
v := validator.New()
// Validation email address is unique in the database.
uniq, err := uniqueName(ctx, dbConn, req.Name, "")
if err != nil {
return nil, err
}
f := func(fl validator.FieldLevel) bool {
if fl.Field().String() == "invalid" {
return false
}
return uniq
}
v.RegisterValidation("unique", f)
// Validate the request.
err = v.Struct(req)
if err != nil {
return nil, err
}
// If now empty set it to the current time.
if now.IsZero() {
now = time.Now()
}
// Always store the time as UTC.
now = now.UTC()
// Postgres truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
a := Account{
ID: uuid.NewRandom().String(),
Name: req.Name,
Address1: req.Address1,
Address2: req.Address2,
City: req.City,
Region: req.Region,
Country: req.Country,
Zipcode: req.Zipcode,
Status: AccountStatus_Pending,
Timezone: "America/Anchorage",
CreatedAt: now,
UpdatedAt: now,
}
if req.Status != nil {
a.Status = *req.Status
}
if req.Timezone != nil {
a.Timezone = *req.Timezone
}
if req.SignupUserID != nil {
a.SignupUserID = sql.NullString{String: *req.SignupUserID, Valid: true}
}
if req.BillingUserID != nil {
a.BillingUserID = sql.NullString{String: *req.BillingUserID, Valid: true}
}
// Build the insert SQL statement.
query := sqlbuilder.NewInsertBuilder()
query.InsertInto(accountTableName)
query.Cols("id", "name", "address1", "address2", "city", "region", "country", "zipcode", "status", "timezone", "signup_user_id", "billing_user_id", "created_at", "updated_at")
query.Values(a.ID, a.Name, a.Address1, a.Address2, a.City, a.Region, a.Country, a.Zipcode, a.Status.String(), a.Timezone, a.SignupUserID, a.BillingUserID, a.CreatedAt, a.UpdatedAt)
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "create account failed")
return nil, err
}
return &a, nil
}
// Read gets the specified account from the database.
func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string, includedArchived bool) (*Account, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Read")
defer span.Finish()
// Filter base select query by ID
query := selectQuery()
query.Where(query.Equal("id", id))
res, err := find(ctx, claims, dbConn, query, []interface{}{}, includedArchived)
if err != nil {
return nil, err
} else if res == nil || len(res) == 0 {
err = errors.WithMessagef(ErrNotFound, "account %s not found", id)
return nil, err
}
u := res[0]
return u, nil
}
// Update replaces an account in the database.
func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UpdateAccountRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Update")
defer span.Finish()
v := validator.New()
// Validation name is unique in the database.
if req.Name != nil {
uniq, err := uniqueName(ctx, dbConn, *req.Name, req.ID)
if err != nil {
return err
}
f := func(fl validator.FieldLevel) bool {
if fl.Field().String() == "invalid" {
return false
}
return uniq
}
v.RegisterValidation("unique", f)
}
// Validate the request.
err := v.Struct(req)
if err != nil {
return err
}
// Ensure the claims can modify the account specified in the request.
err = CanModifyAccount(ctx, claims, dbConn, req.ID)
if err != nil {
return err
}
// If now empty set it to the current time.
if now.IsZero() {
now = time.Now()
}
// Always store the time as UTC.
now = now.UTC()
// Postgres truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
// Build the update SQL statement.
query := sqlbuilder.NewUpdateBuilder()
query.Update(accountTableName)
var fields []string
if req.Name != nil {
fields = append(fields, query.Assign("name", req.Name))
}
if req.Address1 != nil {
fields = append(fields, query.Assign("address1", req.Address1))
}
if req.Address2 != nil {
fields = append(fields, query.Assign("address2", req.Address2))
}
if req.City != nil {
fields = append(fields, query.Assign("city", req.City))
}
if req.Region != nil {
fields = append(fields, query.Assign("region", req.Region))
}
if req.Country != nil {
fields = append(fields, query.Assign("country", req.Country))
}
if req.Zipcode != nil {
fields = append(fields, query.Assign("zipcode", req.Zipcode))
}
if req.Status != nil {
fields = append(fields, query.Assign("status", req.Status))
}
if req.Timezone != nil {
fields = append(fields, query.Assign("timezone", req.Timezone))
}
if req.SignupUserID != nil {
if *req.SignupUserID != "" {
fields = append(fields, query.Assign("signup_user_id", req.SignupUserID))
} else {
fields = append(fields, query.Assign("signup_user_id", nil))
}
}
if req.BillingUserID != nil {
if *req.BillingUserID != "" {
fields = append(fields, query.Assign("billing_user_id", req.BillingUserID))
} else {
fields = append(fields, query.Assign("billing_user_id", nil))
}
}
// If there's nothing to update we can quit early.
if len(fields) == 0 {
return nil
}
// Append the updated_at field
fields = append(fields, query.Assign("updated_at", now))
query.Set(fields...)
query.Where(query.Equal("id", req.ID))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "update account %s failed", req.ID)
return err
}
return nil
}
// Archive soft deleted the account from the database.
func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, accountID string, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Archive")
defer span.Finish()
// Defines the struct to apply validation
req := struct {
ID string `validate:"required,uuid"`
}{
ID: accountID,
}
// Validate the request.
err := validator.New().Struct(req)
if err != nil {
return err
}
// Ensure the claims can modify the account specified in the request.
err = CanModifyAccount(ctx, claims, dbConn, req.ID)
if err != nil {
return err
}
// If now empty set it to the current time.
if now.IsZero() {
now = time.Now()
}
// Always store the time as UTC.
now = now.UTC()
// Postgres truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
// Build the update SQL statement.
query := sqlbuilder.NewUpdateBuilder()
query.Update(accountTableName)
query.Set(
query.Assign("archived_at", now),
)
query.Where(query.Equal("id", req.ID))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "archive account %s failed", req.ID)
return err
}
// Archive all the associated user accounts
{
// Build the update SQL statement.
query := sqlbuilder.NewUpdateBuilder()
query.Update(userAccountTableName)
query.Set(query.Assign("archived_at", now))
query.Where(query.And(
query.Equal("account_id", req.ID),
))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "archive users for account %s failed", req.ID)
return err
}
}
return nil
}
// Delete removes an account from the database.
func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, accountID string) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.account.Delete")
defer span.Finish()
// Defines the struct to apply validation
req := struct {
ID string `validate:"required,uuid"`
}{
ID: accountID,
}
// Validate the request.
err := validator.New().Struct(req)
if err != nil {
return err
}
// Ensure the claims can modify the account specified in the request.
err = CanModifyAccount(ctx, claims, dbConn, req.ID)
if err != nil {
return err
}
// Build the delete SQL statement.
query := sqlbuilder.NewDeleteBuilder()
query.DeleteFrom(accountTableName)
query.Where(query.Equal("id", req.ID))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "delete account %s failed", req.ID)
return err
}
// Delete all the associated user accounts
{
// Build the delete SQL statement.
query := sqlbuilder.NewDeleteBuilder()
query.DeleteFrom(userAccountTableName)
query.Where(query.And(
query.Equal("account_id", req.ID),
))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "delete users for account %s failed", req.ID)
return err
}
}
return nil
}

File diff suppressed because it is too large Load Diff

View File

@ -6,8 +6,8 @@ import (
"time" "time"
"github.com/lib/pq" "github.com/lib/pq"
"gopkg.in/go-playground/validator.v9"
"github.com/pkg/errors" "github.com/pkg/errors"
"gopkg.in/go-playground/validator.v9"
) )
// Account represents someone with access to our system. // Account represents someone with access to our system.
@ -124,4 +124,3 @@ func (s AccountStatus) Value() (driver.Value, error) {
func (s AccountStatus) String() string { func (s AccountStatus) String() string {
return string(s) return string(s)
} }

View File

@ -50,7 +50,6 @@ type Authenticator struct {
parser *jwt.Parser parser *jwt.Parser
} }
// NewAuthenticator creates an *Authenticator for use. // NewAuthenticator creates an *Authenticator for use.
// key expiration is optional to filter out old keys // key expiration is optional to filter out old keys
// It will error if: // It will error if:

View File

@ -2,15 +2,16 @@ package user
import ( import (
"context" "context"
"gopkg.in/go-playground/validator.v9"
"time" "time"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
"github.com/huandu/go-sqlbuilder" "github.com/huandu/go-sqlbuilder"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/lib/pq"
"github.com/pkg/errors" "github.com/pkg/errors"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
"gopkg.in/go-playground/validator.v9"
) )
// TokenGenerator is the behavior we need in our Authenticate to generate tokens for // TokenGenerator is the behavior we need in our Authenticate to generate tokens for
@ -89,19 +90,84 @@ func SwitchAccount(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator,
// generateToken generates claims for the supplied user ID and account ID and then // generateToken generates claims for the supplied user ID and account ID and then
// returns the token for the generated claims used for authentication. // returns the token for the generated claims used for authentication.
func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, userID, accountID string, expires time.Duration, now time.Time) (Token, error) { func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, userID, accountID string, expires time.Duration, now time.Time) (Token, error) {
// Get a list of all the accounts associated with the user.
accounts, err := FindAccountsByUserID(ctx, auth.Claims{}, dbConn, userID, false) type userAccount struct {
AccountID string
Roles pq.StringArray
UserStatus string
UserArchived pq.NullTime
AccountStatus string
AccountArchived pq.NullTime
}
// Build select statement for users_accounts table to find all the user accounts for the user
f := func() ([]userAccount, error) {
query := sqlbuilder.NewSelectBuilder().Select("ua.account_id, ua.roles, ua.status as userStatus, ua.archived_at userArchived, a.status as accountStatus, a.archived_at as accountArchived").
From(userAccountTableName+" ua").
Join(accountTableName+" a", "a.id = ua.account_id")
query.Where(query.And(
query.Equal("ua.user_id", userID),
))
query.OrderBy("ua.status, a.status, ua.created_at")
// fetch all places from the db
queryStr, queryArgs := query.Build()
queryStr = dbConn.Rebind(queryStr)
rows, err := dbConn.QueryContext(ctx, queryStr, queryArgs...)
if err != nil { if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
return nil, err
}
// iterate over each row
var resp []userAccount
for rows.Next() {
var ua userAccount
err = rows.Scan(&ua.AccountID, &ua.Roles, &ua.UserStatus, &ua.UserArchived, &ua.AccountStatus, &ua.AccountArchived)
if err != nil {
return nil, errors.WithStack(err)
}
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
return nil, err
}
resp = append(resp, ua)
}
return resp, nil
}
accounts, err := f()
if err != nil {
err = errors.WithStack(ErrAuthenticationFailure)
return Token{}, err return Token{}, err
} }
// Load the user account entry for the specifed account ID. If none provided, // Load the user account entry for the specified account ID. If none provided,
// choose the first. // choose the first.
var account *UserAccount var account userAccount
if accountID == "" { if accountID == "" {
// Try to choose the first active user account that has not been archived.
for _, a := range accounts {
if a.AccountArchived.Valid && !a.AccountArchived.Time.IsZero() {
continue
} else if a.UserArchived.Valid && !a.UserArchived.Time.IsZero() {
continue
} else if a.AccountStatus != "active" {
continue
} else if a.UserStatus != "active" {
continue
}
account = accounts[0]
accountID = account.AccountID
break
}
// Select the first account associated with the user. For the login flow, // Select the first account associated with the user. For the login flow,
// users could be forced to select a specific account to override this. // users could be forced to select a specific account to override this.
if len(accounts) > 0 { if accountID == "" && len(accounts) > 0 {
account = accounts[0] account = accounts[0]
accountID = account.AccountID accountID = account.AccountID
} }
@ -116,18 +182,25 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator,
} }
// If no matching entry was found for the specified account ID throw an error. // If no matching entry was found for the specified account ID throw an error.
if account == nil { if account.AccountID == "" {
err = errors.WithStack(ErrAuthenticationFailure) err = errors.WithStack(ErrAuthenticationFailure)
return Token{}, err return Token{}, err
} }
} }
// Generate list of user defined roles for accessing the account. // Validate the user account is completely active.
var roles []string if account.AccountArchived.Valid && !account.AccountArchived.Time.IsZero() {
if account != nil { err = errors.WithMessage(ErrAuthenticationFailure, "account is archived")
for _, r := range account.Roles { return Token{}, err
roles = append(roles, r.String()) } else if account.UserArchived.Valid && !account.UserArchived.Time.IsZero() {
} err = errors.WithMessage(ErrAuthenticationFailure, "user account is archived")
return Token{}, err
} else if account.AccountStatus != "active" {
err = errors.WithMessagef(ErrAuthenticationFailure, "account is not active with status of %s", account.AccountStatus)
return Token{}, err
} else if account.UserStatus != "active" {
err = errors.WithMessagef(ErrAuthenticationFailure, "user account is not active with status of %s", account.UserStatus)
return Token{}, err
} }
// Generate a list of all the account IDs associated with the user so the use // Generate a list of all the account IDs associated with the user so the use
@ -141,7 +214,7 @@ func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator,
// Subject: The ID of the user authenticated. // Subject: The ID of the user authenticated.
// Audience: The ID of the account the user is accessing. A list of account IDs // Audience: The ID of the account the user is accessing. A list of account IDs
// will also be included to support the user switching between them. // will also be included to support the user switching between them.
claims = auth.NewClaims(userID, accountID, accountIds, roles, now, expires) claims = auth.NewClaims(userID, accountID, accountIds, account.Roles, now, expires)
// Generate a token for the user with the defined claims. // Generate a token for the user with the defined claims.
tkn, err := tknGen.GenerateToken(claims) tkn, err := tknGen.GenerateToken(claims)

View File

@ -116,31 +116,36 @@ func TestAuthenticate(t *testing.T) {
} }
t.Logf("\t%s\tCreate user ok.", tests.Success) t.Logf("\t%s\tCreate user ok.", tests.Success)
// Create a new random account and associate that with the user. // Create a new random account.
// This defined role should be the claims.
account1Id := uuid.NewRandom().String() account1Id := uuid.NewRandom().String()
account1Role := UserAccountRole_Admin err = mockAccount(account1Id, user.CreatedAt)
_, err = AddAccount(tests.Context(), auth.Claims{}, test.MasterDB, AddAccountRequest{
UserID: user.ID,
AccountID: account1Id,
Roles: []UserAccountRole{account1Role},
}, now)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tAddAccount failed.", tests.Failed) t.Fatalf("\t%s\tCreate account failed.", tests.Failed)
} }
// Create a second new random account and associate that with the user. // Associate new account with user user. This defined role should be the claims.
account2Id := uuid.NewRandom().String() account1Role := auth.RoleAdmin
account2Role := UserAccountRole_User err = mockUserAccount(user.ID, account1Id, user.CreatedAt, account1Role)
_, err = AddAccount(tests.Context(), auth.Claims{}, test.MasterDB, AddAccountRequest{
UserID: user.ID,
AccountID: account2Id,
Roles: []UserAccountRole{account2Role},
}, now.Add(time.Second))
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tAddAccount failed.", tests.Failed) t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
}
// Create a second new random account.
account2Id := uuid.NewRandom().String()
err = mockAccount(account2Id, user.CreatedAt)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate account failed.", tests.Failed)
}
// Associate secoend new account with user user.
account2Role := auth.RoleUser
err = mockUserAccount(user.ID, account2Id, user.CreatedAt, account2Role)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
} }
// Add 30 minutes to now to simulate time passing. // Add 30 minutes to now to simulate time passing.
@ -170,7 +175,7 @@ func TestAuthenticate(t *testing.T) {
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
} else if diff := cmp.Diff(claims1, tkn1.claims); diff != "" { } else if diff := cmp.Diff(claims1, tkn1.claims); diff != "" {
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff) t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
} else if diff := cmp.Diff(claims1.Roles, []string{account1Role.String()}); diff != "" { } else if diff := cmp.Diff(claims1.Roles, []string{account1Role}); diff != "" {
t.Fatalf("\t%s\tExpected parsed claims roles to match user account. Diff:\n%s", tests.Failed, diff) t.Fatalf("\t%s\tExpected parsed claims roles to match user account. Diff:\n%s", tests.Failed, diff)
} else if diff := cmp.Diff(claims1.AccountIds, []string{account1Id, account2Id}); diff != "" { } else if diff := cmp.Diff(claims1.AccountIds, []string{account1Id, account2Id}); diff != "" {
t.Fatalf("\t%s\tExpected parsed claims account IDs to match the single user account. Diff:\n%s", tests.Failed, diff) t.Fatalf("\t%s\tExpected parsed claims account IDs to match the single user account. Diff:\n%s", tests.Failed, diff)
@ -192,7 +197,7 @@ func TestAuthenticate(t *testing.T) {
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed) t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
} else if diff := cmp.Diff(claims2, tkn2.claims); diff != "" { } else if diff := cmp.Diff(claims2, tkn2.claims); diff != "" {
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff) t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
} else if diff := cmp.Diff(claims2.Roles, []string{account2Role.String()}); diff != "" { } else if diff := cmp.Diff(claims2.Roles, []string{account2Role}); diff != "" {
t.Fatalf("\t%s\tExpected parsed claims roles to match user account. Diff:\n%s", tests.Failed, diff) t.Fatalf("\t%s\tExpected parsed claims roles to match user account. Diff:\n%s", tests.Failed, diff)
} else if diff := cmp.Diff(claims2.AccountIds, []string{account1Id, account2Id}); diff != "" { } else if diff := cmp.Diff(claims2.AccountIds, []string{account1Id, account2Id}); diff != "" {
t.Fatalf("\t%s\tExpected parsed claims account IDs to match the single user account. Diff:\n%s", tests.Failed, diff) t.Fatalf("\t%s\tExpected parsed claims account IDs to match the single user account. Diff:\n%s", tests.Failed, diff)

View File

@ -0,0 +1,72 @@
package user
import (
"database/sql"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
"time"
"github.com/lib/pq"
)
// User represents someone with access to our system.
type User struct {
ID string `json:"id"`
Name string `json:"name"`
Email string `json:"email"`
PasswordSalt string `json:"-"`
PasswordHash []byte `json:"-"`
PasswordReset sql.NullString `json:"-"`
Timezone string `json:"timezone"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ArchivedAt pq.NullTime `json:"archived_at"`
}
// CreateUserRequest contains information needed to create a new User.
type CreateUserRequest struct {
Name string `json:"name" validate:"required"`
Email string `json:"email" validate:"required,email,unique"`
Password string `json:"password" validate:"required"`
PasswordConfirm string `json:"password_confirm" validate:"eqfield=Password"`
Timezone *string `json:"timezone" validate:"omitempty"`
}
// UpdateUserRequest defines what information may be provided to modify an existing
// User. All fields are optional so clients can send just the fields they want
// changed. It uses pointer fields so we can differentiate between a field that
// was not provided and a field that was provided as explicitly blank. Normally
// we do not want to use pointers to basic types but we make exceptions around
// marshalling/unmarshalling.
type UpdateUserRequest struct {
ID string `validate:"required,uuid"`
Name *string `json:"name" validate:"omitempty"`
Email *string `json:"email" validate:"omitempty,email,unique"`
Timezone *string `json:"timezone" validate:"omitempty"`
}
// UpdatePassword defines what information is required to update a user password.
type UpdatePasswordRequest struct {
ID string `validate:"required,uuid"`
Password string `json:"password" validate:"required"`
PasswordConfirm string `json:"password_confirm" validate:"omitempty,eqfield=Password"`
}
// UserFindRequest defines the possible options to search for users. By default
// archived users will be excluded from response.
type UserFindRequest struct {
Where *string
Args []interface{}
Order []string
Limit *uint
Offset *uint
IncludedArchived bool
}
// Token is the payload we deliver to users when they authenticate.
type Token struct {
Token string `json:"token"`
claims auth.Claims `json:"-"`
}

View File

@ -0,0 +1,653 @@
package user
import (
"context"
"database/sql"
"time"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
"github.com/huandu/go-sqlbuilder"
"github.com/jmoiron/sqlx"
"github.com/pborman/uuid"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
"gopkg.in/go-playground/validator.v9"
)
const (
// The database table for User
userTableName = "users"
// The database table for Account
accountTableName = "accounts"
// The database table for User Account
userAccountTableName = "users_accounts"
)
var (
// ErrNotFound abstracts the mgo not found error.
ErrNotFound = errors.New("Entity not found")
// ErrInvalidID occurs when an ID is not in a valid form.
ErrInvalidID = errors.New("ID is not in its proper form")
// ErrForbidden occurs when a user tries to do something that is forbidden to them according to our access control policies.
ErrForbidden = errors.New("Attempted action is not allowed")
// ErrAuthenticationFailure occurs when a user attempts to authenticate but
// anything goes wrong.
ErrAuthenticationFailure = errors.New("Authentication failed")
)
// userMapColumns is the list of columns needed for mapRowsToUser
var userMapColumns = "id,name,email,password_salt,password_hash,password_reset,timezone,created_at,updated_at,archived_at"
// mapRowsToUser takes the SQL rows and maps it to the UserAccount struct
// with the columns defined by userMapColumns
func mapRowsToUser(rows *sql.Rows) (*User, error) {
var (
u User
err error
)
err = rows.Scan(&u.ID, &u.Name, &u.Email, &u.PasswordSalt, &u.PasswordHash, &u.PasswordReset, &u.Timezone, &u.CreatedAt, &u.UpdatedAt, &u.ArchivedAt)
if err != nil {
return nil, errors.WithStack(err)
}
return &u, nil
}
// CanReadUser determines if claims has the authority to access the specified user ID.
func CanReadUser(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string) error {
// If the request has claims from a specific user, ensure that the user
// has the correct access to the user.
if claims.Subject != "" && claims.Subject != userID {
// When the claims Subject/UserId - does not match the requested user, the
// claims audience - AccountId - should have a record.
// select id from users_accounts where account_id = [claims.Audience] and user_id = [userID]
query := sqlbuilder.NewSelectBuilder().Select("id").From(userAccountTableName)
query.Where(query.And(
query.Equal("account_id", claims.Audience),
query.Equal("user_id", userID),
))
queryStr, args := query.Build()
queryStr = dbConn.Rebind(queryStr)
var userAccountId string
err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&userAccountId)
if err != nil && err != sql.ErrNoRows {
err = errors.Wrapf(err, "query - %s", query.String())
return err
}
// When there is no userAccount ID returned, then the current user does not have access
// to the specified user.
if userAccountId == "" {
return errors.WithStack(ErrForbidden)
}
}
return nil
}
// CanModifyUser determines if claims has the authority to modify the specified user ID.
func CanModifyUser(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string) error {
// If the request has claims from a specific account, ensure that the user
// has the correct access to the account.
if claims.Subject != "" && claims.Subject != userID {
// When the claims Audience - AccountID - does not match the requested account, the
// claims Audience - AccountID - should have a record with an admin role.
// select id from users_accounts where account_id = [claims.Audience] and user_id = [userID] and any (roles) = 'admin'
query := sqlbuilder.NewSelectBuilder().Select("id").From(userAccountTableName)
query.Where(query.And(
query.Equal("account_id", claims.Audience),
query.Equal("user_id", userID),
"'"+auth.RoleAdmin+"' = ANY (roles)",
))
queryStr, args := query.Build()
queryStr = dbConn.Rebind(queryStr)
var userAccountId string
err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&userAccountId)
if err != nil && err != sql.ErrNoRows {
err = errors.Wrapf(err, "query - %s", query.String())
return err
}
// When there is no userAccount ID returned, then the current user does not have access
// to the specified account.
if userAccountId == "" {
return errors.WithStack(ErrForbidden)
}
}
return nil
}
// applyClaimsSelect applies a sub-query to the provided query to enforce ACL based on
// the claims provided.
// 1. All role types can access their user ID
// 2. Any user with the same account ID
// 3. No claims, request is internal, no ACL applied
func applyClaimsSelect(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder) error {
// Claims are empty, don't apply any ACL
if claims.Audience == "" && claims.Subject == "" {
return nil
}
// Build select statement for users_accounts table
subQuery := sqlbuilder.NewSelectBuilder().Select("user_id").From(userAccountTableName)
var or []string
if claims.Audience != "" {
or = append(or, subQuery.Equal("account_id", claims.Audience))
}
if claims.Subject != "" {
or = append(or, subQuery.Equal("user_id", claims.Subject))
}
subQuery.Where(subQuery.Or(or...))
// Append sub query
query.Where(query.In("id", subQuery))
return nil
}
// selectQuery constructs a base select query for User
func selectQuery() *sqlbuilder.SelectBuilder {
query := sqlbuilder.NewSelectBuilder()
query.Select(userMapColumns)
query.From(userTableName)
return query
}
// findRequestQuery generates the select query for the given find request.
// TODO: Need to figure out why can't parse the args when appending the where
// to the query.
func findRequestQuery(req UserFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) {
query := selectQuery()
if req.Where != nil {
query.Where(query.And(*req.Where))
}
if len(req.Order) > 0 {
query.OrderBy(req.Order...)
}
if req.Limit != nil {
query.Limit(int(*req.Limit))
}
if req.Offset != nil {
query.Offset(int(*req.Offset))
}
return query, req.Args
}
// Find gets all the users from the database based on the request params.
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindRequest) ([]*User, error) {
query, args := findRequestQuery(req)
return find(ctx, claims, dbConn, query, args, req.IncludedArchived)
}
// find internal method for getting all the users from the database using a select query.
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*User, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Find")
defer span.Finish()
query.Select(userMapColumns)
query.From(userTableName)
if !includedArchived {
query.Where(query.IsNull("archived_at"))
}
// Check to see if a sub query needs to be applied for the claims
err := applyClaimsSelect(ctx, claims, query)
if err != nil {
return nil, err
}
queryStr, queryArgs := query.Build()
queryStr = dbConn.Rebind(queryStr)
args = append(args, queryArgs...)
// fetch all places from the db
rows, err := dbConn.QueryContext(ctx, queryStr, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "find users failed")
return nil, err
}
// iterate over each row
resp := []*User{}
for rows.Next() {
u, err := mapRowsToUser(rows)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
return nil, err
}
resp = append(resp, u)
}
return resp, nil
}
// Validation an email address is unique excluding the current user ID.
func uniqueEmail(ctx context.Context, dbConn *sqlx.DB, email, userId string) (bool, error) {
query := sqlbuilder.NewSelectBuilder().Select("id").From(userTableName)
query.Where(query.And(
query.Equal("email", email),
query.NotEqual("id", userId),
))
queryStr, args := query.Build()
queryStr = dbConn.Rebind(queryStr)
var existingId string
err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&existingId)
if err != nil && err != sql.ErrNoRows {
err = errors.Wrapf(err, "query - %s", query.String())
return false, err
}
// When an ID was found in the db, the email is not unique.
if existingId != "" {
return false, nil
}
return true, nil
}
// Create inserts a new user into the database.
func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req CreateUserRequest, now time.Time) (*User, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Create")
defer span.Finish()
v := validator.New()
// Validation email address is unique in the database.
uniq, err := uniqueEmail(ctx, dbConn, req.Email, "")
if err != nil {
return nil, err
}
f := func(fl validator.FieldLevel) bool {
if fl.Field().String() == "invalid" {
return false
}
return uniq
}
v.RegisterValidation("unique", f)
// Validate the request.
err = v.Struct(req)
if err != nil {
return nil, err
}
// If the request has claims from a specific user, ensure that the user
// has the correct role for creating a new user.
if claims.Subject != "" {
// Users with the role of admin are ony allows to create users.
if !claims.HasRole(auth.RoleAdmin) {
err = errors.WithStack(ErrForbidden)
return nil, err
}
}
// If now empty set it to the current time.
if now.IsZero() {
now = time.Now()
}
// Always store the time as UTC.
now = now.UTC()
// Postgres truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
passwordSalt := uuid.NewRandom().String()
saltedPassword := req.Password + passwordSalt
passwordHash, err := bcrypt.GenerateFromPassword([]byte(saltedPassword), bcrypt.DefaultCost)
if err != nil {
return nil, errors.Wrap(err, "generating password hash")
}
u := User{
ID: uuid.NewRandom().String(),
Name: req.Name,
Email: req.Email,
PasswordHash: passwordHash,
PasswordSalt: passwordSalt,
Timezone: "America/Anchorage",
CreatedAt: now,
UpdatedAt: now,
}
if req.Timezone != nil {
u.Timezone = *req.Timezone
}
// Build the insert SQL statement.
query := sqlbuilder.NewInsertBuilder()
query.InsertInto(userTableName)
query.Cols("id", "name", "email", "password_hash", "password_salt", "timezone", "created_at", "updated_at")
query.Values(u.ID, u.Name, u.Email, u.PasswordHash, u.PasswordSalt, u.Timezone, u.CreatedAt, u.UpdatedAt)
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "create user failed")
return nil, err
}
return &u, nil
}
// Read gets the specified user from the database.
func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string, includedArchived bool) (*User, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Read")
defer span.Finish()
// Filter base select query by ID
query := selectQuery()
query.Where(query.Equal("id", id))
res, err := find(ctx, claims, dbConn, query, []interface{}{}, includedArchived)
if err != nil {
return nil, err
} else if res == nil || len(res) == 0 {
err = errors.WithMessagef(ErrNotFound, "user %s not found", id)
return nil, err
}
u := res[0]
return u, nil
}
// Update replaces a user in the database.
func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UpdateUserRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Update")
defer span.Finish()
v := validator.New()
// Validation email address is unique in the database.
if req.Email != nil {
uniq, err := uniqueEmail(ctx, dbConn, *req.Email, req.ID)
if err != nil {
return err
}
f := func(fl validator.FieldLevel) bool {
if fl.Field().String() == "invalid" {
return false
}
return uniq
}
v.RegisterValidation("unique", f)
}
// Validate the request.
err := v.Struct(req)
if err != nil {
return err
}
// Ensure the claims can modify the user specified in the request.
err = CanModifyUser(ctx, claims, dbConn, req.ID)
if err != nil {
return err
}
// If now empty set it to the current time.
if now.IsZero() {
now = time.Now()
}
// Always store the time as UTC.
now = now.UTC()
// Postgres truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
// Build the update SQL statement.
query := sqlbuilder.NewUpdateBuilder()
query.Update(userTableName)
var fields []string
if req.Name != nil {
fields = append(fields, query.Assign("name", req.Name))
}
if req.Email != nil {
fields = append(fields, query.Assign("email", req.Email))
}
if req.Timezone != nil {
fields = append(fields, query.Assign("timezone", req.Timezone))
}
// If there's nothing to update we can quit early.
if len(fields) == 0 {
return nil
}
// Append the updated_at field
fields = append(fields, query.Assign("updated_at", now))
query.Set(fields...)
query.Where(query.Equal("id", req.ID))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "update user %s failed", req.ID)
return err
}
return nil
}
// Update replaces a user in the database.
func UpdatePassword(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UpdatePasswordRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Update")
defer span.Finish()
// Validate the request.
err := validator.New().Struct(req)
if err != nil {
return err
}
// Ensure the claims can modify the user specified in the request.
err = CanModifyUser(ctx, claims, dbConn, req.ID)
if err != nil {
return err
}
// If now empty set it to the current time.
if now.IsZero() {
now = time.Now()
}
// Always store the time as UTC.
now = now.UTC()
// Postgres truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
// Generate new password hash for the provided password.
passwordSalt := uuid.NewRandom()
saltedPassword := req.Password + passwordSalt.String()
passwordHash, err := bcrypt.GenerateFromPassword([]byte(saltedPassword), bcrypt.DefaultCost)
if err != nil {
return errors.Wrap(err, "generating password hash")
}
// Build the update SQL statement.
query := sqlbuilder.NewUpdateBuilder()
query.Update(userTableName)
query.Set(
query.Assign("password_hash", passwordHash),
query.Assign("password_salt", passwordSalt),
query.Assign("updated_at", now),
)
query.Where(query.Equal("id", req.ID))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "update password for user %s failed", req.ID)
return err
}
return nil
}
// Archive soft deleted the user from the database.
func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Archive")
defer span.Finish()
// Defines the struct to apply validation
req := struct {
ID string `validate:"required,uuid"`
}{
ID: userID,
}
// Validate the request.
err := validator.New().Struct(req)
if err != nil {
return err
}
// Ensure the claims can modify the user specified in the request.
err = CanModifyUser(ctx, claims, dbConn, req.ID)
if err != nil {
return err
}
// If now empty set it to the current time.
if now.IsZero() {
now = time.Now()
}
// Always store the time as UTC.
now = now.UTC()
// Postgres truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
// Build the update SQL statement.
query := sqlbuilder.NewUpdateBuilder()
query.Update(userTableName)
query.Set(
query.Assign("archived_at", now),
)
query.Where(query.Equal("id", req.ID))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "archive user %s failed", req.ID)
return err
}
// Archive all the associated user accounts
{
// Build the update SQL statement.
query := sqlbuilder.NewUpdateBuilder()
query.Update(userAccountTableName)
query.Set(query.Assign("archived_at", now))
query.Where(query.And(
query.Equal("user_id", req.ID),
))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "archive accounts for user %s failed", req.ID)
return err
}
}
return nil
}
// Delete removes a user from the database.
func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Delete")
defer span.Finish()
// Defines the struct to apply validation
req := struct {
ID string `validate:"required,uuid"`
}{
ID: userID,
}
// Validate the request.
err := validator.New().Struct(req)
if err != nil {
return err
}
// Ensure the claims can modify the user specified in the request.
err = CanModifyUser(ctx, claims, dbConn, req.ID)
if err != nil {
return err
}
// Build the delete SQL statement.
query := sqlbuilder.NewDeleteBuilder()
query.DeleteFrom(userTableName)
query.Where(query.Equal("id", req.ID))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "delete user %s failed", req.ID)
return err
}
// Delete all the associated user accounts
{
// Build the delete SQL statement.
query := sqlbuilder.NewDeleteBuilder()
query.DeleteFrom(userAccountTableName)
query.Where(query.And(
query.Equal("user_id", req.ID),
))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "delete accounts for user %s failed", req.ID)
return err
}
}
return nil
}

File diff suppressed because it is too large Load Diff

View File

@ -4,7 +4,6 @@ import (
"context" "context"
"database/sql" "database/sql"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/account" "geeks-accelerator/oss/saas-starter-kit/example-project/internal/account"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/user"
"github.com/lib/pq" "github.com/lib/pq"
"time" "time"
@ -24,10 +23,6 @@ var (
// ErrInvalidID occurs when an ID is not in a valid form. // ErrInvalidID occurs when an ID is not in a valid form.
ErrInvalidID = errors.New("ID is not in its proper form") ErrInvalidID = errors.New("ID is not in its proper form")
// ErrAuthenticationFailure occurs when a user attempts to authenticate but
// anything goes wrong.
ErrAuthenticationFailure = errors.New("Authentication failed")
// ErrForbidden occurs when a user tries to do something that is forbidden to them according to our access control policies. // ErrForbidden occurs when a user tries to do something that is forbidden to them according to our access control policies.
ErrForbidden = errors.New("Attempted action is not allowed") ErrForbidden = errors.New("Attempted action is not allowed")
) )
@ -53,63 +48,43 @@ func mapRowsToUserAccount(rows *sql.Rows) (*UserAccount, error) {
return &ua, nil return &ua, nil
} }
// CanReadUserAccount determines if claims has the authority to access the specified user account by user ID. // CanReadAccount determines if claims has the authority to access the specified user account by user ID.
func CanReadUserAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID, accountID string) error { func CanReadAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, accountID string) error {
// First check to see if claims can read the user ID err := account.CanReadAccount(ctx, claims, dbConn, accountID)
err := user.CanReadUser(ctx, claims, dbConn, userID) return mapAccountError(err)
if err != nil { }
if claims.Audience != accountID {
// CanModifyAccount determines if claims has the authority to modify the specified user ID.
func CanModifyAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, accountID string) error {
err := account.CanModifyAccount(ctx, claims, dbConn, accountID)
return mapAccountError(err)
}
// mapAccountError maps account errors to local defined errors.
func mapAccountError(err error) error {
switch errors.Cause(err) {
case account.ErrNotFound:
err = ErrNotFound
case account.ErrInvalidID:
err = ErrInvalidID
case account.ErrForbidden:
err = ErrForbidden
}
return err return err
} }
}
// Second check to see if claims can read the account ID // applyClaimsSelect applies a sub-query to the provided query
err = account.CanReadAccount(ctx, claims, dbConn, accountID) // to enforce ACL based on the claims provided.
if err != nil { // 1. All role types can access their user ID
if claims.Audience != accountID { // 2. Any user with the same account ID
return err // 3. No claims, request is internal, no ACL applied
} func applyClaimsSelect(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder) error {
}
return nil
}
// CanModifyUserAccount determines if claims has the authority to modify the specified user ID.
func CanModifyUserAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID, accountID string) error {
// First check to see if claims can read the user ID
err := CanReadUserAccount(ctx, claims, dbConn, userID, accountID)
if err != nil {
if claims.Audience != accountID {
return err
}
}
// If the request has claims from a specific user, ensure that the user
// has the correct role for updating an existing user.
if claims.Subject != "" {
if claims.Subject == userID {
// All users are allowed to update their own record
} else if claims.HasRole(auth.RoleAdmin) {
// Admin users can update users they have access to.
} else {
return errors.WithStack(ErrForbidden)
}
}
return nil
}
// applyClaimsUserAccountSelect applies a sub query to enforce ACL for
// the supplied claims. If claims is empty then request must be internal and
// no sub-query is applied. Else a list of user IDs is found all associated
// user accounts.
func applyClaimsUserAccountSelect(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder) error {
if claims.Audience == "" && claims.Subject == "" { if claims.Audience == "" && claims.Subject == "" {
return nil return nil
} }
// Build select statement for users_accounts table // Build select statement for users_accounts table
subQuery := sqlbuilder.NewSelectBuilder().Select("user_id").From(userAccountTableName) subQuery := sqlbuilder.NewSelectBuilder().Select("id").From(userAccountTableName)
var or []string var or []string
if claims.Audience != "" { if claims.Audience != "" {
@ -121,24 +96,24 @@ func applyClaimsUserAccountSelect(ctx context.Context, claims auth.Claims, query
subQuery.Where(subQuery.Or(or...)) subQuery.Where(subQuery.Or(or...))
// Append sub query // Append sub query
query.Where(query.In("user_id", subQuery)) query.Where(query.In("id", subQuery))
return nil return nil
} }
// AccountSelectQuery // selectQuery constructs a base select query for User Account
func userAccountSelectQuery() *sqlbuilder.SelectBuilder { func selectQuery() *sqlbuilder.SelectBuilder {
query := sqlbuilder.NewSelectBuilder() query := sqlbuilder.NewSelectBuilder()
query.Select(userAccountMapColumns) query.Select(userAccountMapColumns)
query.From(userAccountTableName) query.From(userAccountTableName)
return query return query
} }
// userFindRequestQuery generates the select query for the given find request. // findRequestQuery generates the select query for the given find request.
// TODO: Need to figure out why can't parse the args when appending the where // TODO: Need to figure out why can't parse the args when appending the where
// to the query. // to the query.
func userAccountFindRequestQuery(req UserAccountFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) { func findRequestQuery(req UserAccountFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) {
query := userAccountSelectQuery() query := selectQuery()
if req.Where != nil { if req.Where != nil {
query.Where(query.And(*req.Where)) query.Where(query.And(*req.Where))
} }
@ -155,9 +130,9 @@ func userAccountFindRequestQuery(req UserAccountFindRequest) (*sqlbuilder.Select
return query, req.Args return query, req.Args
} }
// Find gets all the user accounts from the database based on the request params // Find gets all the user accounts from the database based on the request params.
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountFindRequest) ([]*UserAccount, error) { func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountFindRequest) ([]*UserAccount, error) {
query, args := userAccountFindRequestQuery(req) query, args := findRequestQuery(req)
return find(ctx, claims, dbConn, query, args, req.IncludedArchived) return find(ctx, claims, dbConn, query, args, req.IncludedArchived)
} }
@ -174,7 +149,7 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu
} }
// Check to see if a sub query needs to be applied for the claims // Check to see if a sub query needs to be applied for the claims
err := applyClaimsUserAccountSelect(ctx, claims, query) err := applyClaimsSelect(ctx, claims, query)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -226,7 +201,7 @@ func FindByUserID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, user
return res, nil return res, nil
} }
// AddAccount an account for a given user with specified roles. // Create a user account for a given user with specified roles.
func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req CreateUserAccountRequest, now time.Time) (*UserAccount, error) { func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req CreateUserAccountRequest, now time.Time) (*UserAccount, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Create") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Create")
defer span.Finish() defer span.Finish()
@ -237,8 +212,8 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Create
return nil, err return nil, err
} }
// Ensure the claims can modify the user specified in the request. // Ensure the claims can modify the account specified in the request.
err = CanModifyUserAccount(ctx, claims, dbConn, req.UserID, req.AccountID) err = CanModifyAccount(ctx, claims, dbConn, req.AccountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -256,7 +231,7 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Create
now = now.Truncate(time.Millisecond) now = now.Truncate(time.Millisecond)
// Check to see if there is an existing user account, including archived. // Check to see if there is an existing user account, including archived.
existQuery := userAccountSelectQuery() existQuery := selectQuery()
existQuery.Where(existQuery.And( existQuery.Where(existQuery.And(
existQuery.Equal("account_id", req.AccountID), existQuery.Equal("account_id", req.AccountID),
existQuery.Equal("user_id", req.UserID), existQuery.Equal("user_id", req.UserID),
@ -267,6 +242,7 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Create
} }
// If there is an existing entry, then update instead of insert. // If there is an existing entry, then update instead of insert.
var ua UserAccount
if len(existing) > 0 { if len(existing) > 0 {
upReq := UpdateUserAccountRequest{ upReq := UpdateUserAccountRequest{
UserID: req.UserID, UserID: req.UserID,
@ -279,15 +255,12 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Create
return nil, err return nil, err
} }
ua := existing[0] ua = *existing[0]
ua.Roles = req.Roles ua.Roles = req.Roles
ua.UpdatedAt = now ua.UpdatedAt = now
ua.ArchivedAt = pq.NullTime{} ua.ArchivedAt = pq.NullTime{}
} else {
return ua, nil ua = UserAccount{
}
ua := UserAccount{
ID: uuid.NewRandom().String(), ID: uuid.NewRandom().String(),
UserID: req.UserID, UserID: req.UserID,
AccountID: req.AccountID, AccountID: req.AccountID,
@ -316,11 +289,33 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Create
err = errors.WithMessagef(err, "add account %s to user %s failed", req.AccountID, req.UserID) err = errors.WithMessagef(err, "add account %s to user %s failed", req.AccountID, req.UserID)
return nil, err return nil, err
} }
}
return &ua, nil return &ua, nil
} }
// UpdateAccount... // Read gets the specified user account from the database.
func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string, includedArchived bool) (*UserAccount, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Read")
defer span.Finish()
// Filter base select query by ID
query := selectQuery()
query.Where(query.Equal("id", id))
res, err := find(ctx, claims, dbConn, query, []interface{}{}, includedArchived)
if err != nil {
return nil, err
} else if res == nil || len(res) == 0 {
err = errors.WithMessagef(ErrNotFound, "user account %s not found", id)
return nil, err
}
u := res[0]
return u, nil
}
// Update replaces a user account in the database.
func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UpdateUserAccountRequest, now time.Time) error { func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UpdateUserAccountRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Update") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Update")
defer span.Finish() defer span.Finish()
@ -332,7 +327,7 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Update
} }
// Ensure the claims can modify the user specified in the request. // Ensure the claims can modify the user specified in the request.
err = CanModifyUserAccount(ctx, claims, dbConn, req.UserID, req.AccountID) err = CanModifyAccount(ctx, claims, dbConn, req.AccountID)
if err != nil { if err != nil {
return err return err
} }
@ -360,6 +355,9 @@ func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Update
if req.Status != nil { if req.Status != nil {
fields = append(fields, query.Assign("status", req.Status)) fields = append(fields, query.Assign("status", req.Status))
} }
if req.unArchive {
fields = append(fields, query.Assign("archived_at", nil))
}
// If there's nothing to update we can quit early. // If there's nothing to update we can quit early.
if len(fields) == 0 { if len(fields) == 0 {
@ -401,7 +399,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Archi
} }
// Ensure the claims can modify the user specified in the request. // Ensure the claims can modify the user specified in the request.
err = CanModifyUserAccount(ctx, claims, dbConn, req.UserID, req.AccountID) err = CanModifyAccount(ctx, claims, dbConn, req.AccountID)
if err != nil { if err != nil {
return err return err
} }
@ -433,7 +431,7 @@ func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Archi
_, err = dbConn.ExecContext(ctx, sql, args...) _, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil { if err != nil {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "remove account %s from user %s failed", req.AccountID, req.UserID) err = errors.WithMessagef(err, "archive account %s from user %s failed", req.AccountID, req.UserID)
return err return err
} }
@ -452,7 +450,7 @@ func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Delete
} }
// Ensure the claims can modify the user specified in the request. // Ensure the claims can modify the user specified in the request.
err = CanModifyUserAccount(ctx, claims, dbConn, req.UserID, req.AccountID) err = CanModifyAccount(ctx, claims, dbConn, req.AccountID)
if err != nil { if err != nil {
return err return err
} }

View File

@ -3,6 +3,7 @@ package user
import ( import (
"github.com/lib/pq" "github.com/lib/pq"
"math/rand" "math/rand"
"os"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -16,8 +17,21 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
// TestAccountFindRequestQuery validates accountFindRequestQuery var test *tests.Test
func TestAccountFindRequestQuery(t *testing.T) {
// TestMain is the entry point for testing.
func TestMain(m *testing.M) {
os.Exit(testMain(m))
}
func testMain(m *testing.M) int {
test = tests.New()
defer test.TearDown()
return m.Run()
}
// TestFindRequestQuery validates findRequestQuery
func TestFindRequestQuery(t *testing.T) {
where := "account_id = ? or user_id = ?" where := "account_id = ? or user_id = ?"
var ( var (
limit uint = 12 limit uint = 12
@ -37,9 +51,9 @@ func TestAccountFindRequestQuery(t *testing.T) {
Limit: &limit, Limit: &limit,
Offset: &offset, Offset: &offset,
} }
expected := "SELECT " + usersAccountsMapColumns + " FROM " + userAccountTableName + " WHERE (account_id = ? or user_id = ?) ORDER BY id asc, created_at desc LIMIT 12 OFFSET 34" expected := "SELECT " + userAccountMapColumns + " FROM " + userAccountTableName + " WHERE (account_id = ? or user_id = ?) ORDER BY id asc, created_at desc LIMIT 12 OFFSET 34"
res, args := userAccountFindRequestQuery(req) res, args := findRequestQuery(req)
if diff := cmp.Diff(res.String(), expected); diff != "" { if diff := cmp.Diff(res.String(), expected); diff != "" {
t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff) t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff)
@ -49,8 +63,8 @@ func TestAccountFindRequestQuery(t *testing.T) {
} }
} }
// TestApplyClaimsUserAccountSelect validates applyClaimsUserAccountSelect // TestApplyClaimsSelectvalidates applyClaimsSelect
func TestApplyClaimsUserAccountSelect(t *testing.T) { func TestApplyClaimsSelectvalidates(t *testing.T) {
var claimTests = []struct { var claimTests = []struct {
name string name string
claims auth.Claims claims auth.Claims
@ -59,7 +73,7 @@ func TestApplyClaimsUserAccountSelect(t *testing.T) {
}{ }{
{"EmptyClaims", {"EmptyClaims",
auth.Claims{}, auth.Claims{},
"SELECT " + usersAccountsMapColumns + " FROM " + userAccountTableName, "SELECT " + userAccountMapColumns + " FROM " + userAccountTableName,
nil, nil,
}, },
{"RoleUser", {"RoleUser",
@ -70,7 +84,7 @@ func TestApplyClaimsUserAccountSelect(t *testing.T) {
Audience: "acc1", Audience: "acc1",
}, },
}, },
"SELECT " + usersAccountsMapColumns + " FROM " + userAccountTableName + " WHERE user_id IN (SELECT user_id FROM " + userAccountTableName + " WHERE (account_id = 'acc1' OR user_id = 'user1'))", "SELECT " + userAccountMapColumns + " FROM " + userAccountTableName + " WHERE id IN (SELECT id FROM " + userAccountTableName + " WHERE (account_id = 'acc1' OR user_id = 'user1'))",
nil, nil,
}, },
{"RoleAdmin", {"RoleAdmin",
@ -81,7 +95,7 @@ func TestApplyClaimsUserAccountSelect(t *testing.T) {
Audience: "acc1", Audience: "acc1",
}, },
}, },
"SELECT " + usersAccountsMapColumns + " FROM " + userAccountTableName + " WHERE user_id IN (SELECT user_id FROM " + userAccountTableName + " WHERE (account_id = 'acc1' OR user_id = 'user1'))", "SELECT " + userAccountMapColumns + " FROM " + userAccountTableName + " WHERE id IN (SELECT id FROM " + userAccountTableName + " WHERE (account_id = 'acc1' OR user_id = 'user1'))",
nil, nil,
}, },
} }
@ -93,13 +107,13 @@ func TestApplyClaimsUserAccountSelect(t *testing.T) {
{ {
ctx := tests.Context() ctx := tests.Context()
query := userAccountSelectQuery() query := selectQuery()
err := applyClaimsUserAccountSelect(ctx, tt.claims, query) err := applyClaimsSelect(ctx, tt.claims, query)
if err != tt.error { if err != tt.error {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.error) t.Logf("\t\tWant: %+v", tt.error)
t.Fatalf("\t%s\tapplyClaimsUserAccountSelect failed.", tests.Failed) t.Fatalf("\t%s\tapplyClaimsSelect failed.", tests.Failed)
} }
sql, args := query.Build() sql, args := query.Build()
@ -108,70 +122,70 @@ func TestApplyClaimsUserAccountSelect(t *testing.T) {
sql, err = sqlbuilder.MySQL.Interpolate(sql, args) sql, err = sqlbuilder.MySQL.Interpolate(sql, args)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tapplyClaimsUserAccountSelect failed.", tests.Failed) t.Fatalf("\t%s\tapplyClaimsSelect failed.", tests.Failed)
} }
if diff := cmp.Diff(sql, tt.expectedSql); diff != "" { if diff := cmp.Diff(sql, tt.expectedSql); diff != "" {
t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff) t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff)
} }
t.Logf("\t%s\tapplyClaimsUserAccountSelect ok.", tests.Success) t.Logf("\t%s\tapplyClaimsSelect ok.", tests.Success)
} }
} }
} }
} }
// TestAddAccountValidation ensures all the validation tags work on account add. // TestCreateValidation ensures all the validation tags work on user account create.
func TestAddAccountValidation(t *testing.T) { func TestCreateValidation(t *testing.T) {
invalidRole := UserAccountRole("moon") invalidRole := UserAccountRole("moon")
invalidStatus := UserAccountStatus("moon") invalidStatus := UserAccountStatus("moon")
var accountTests = []struct { var accountTests = []struct {
name string name string
req AddAccountRequest req CreateUserAccountRequest
expected func(req AddAccountRequest, res *UserAccount) *UserAccount expected func(req CreateUserAccountRequest, res *UserAccount) *UserAccount
error error error error
}{ }{
{"Required Fields", {"Required Fields",
AddAccountRequest{}, CreateUserAccountRequest{},
func(req AddAccountRequest, res *UserAccount) *UserAccount { func(req CreateUserAccountRequest, res *UserAccount) *UserAccount {
return nil return nil
}, },
errors.New("Key: 'AddAccountRequest.UserID' Error:Field validation for 'UserID' failed on the 'required' tag\n" + errors.New("Key: 'CreateUserAccountRequest.UserID' Error:Field validation for 'UserID' failed on the 'required' tag\n" +
"Key: 'AddAccountRequest.AccountID' Error:Field validation for 'AccountID' failed on the 'required' tag\n" + "Key: 'CreateUserAccountRequest.AccountID' Error:Field validation for 'AccountID' failed on the 'required' tag\n" +
"Key: 'AddAccountRequest.Roles' Error:Field validation for 'Roles' failed on the 'required' tag"), "Key: 'CreateUserAccountRequest.Roles' Error:Field validation for 'Roles' failed on the 'required' tag"),
}, },
{"Valid Role", {"Valid Role",
AddAccountRequest{ CreateUserAccountRequest{
UserID: uuid.NewRandom().String(), UserID: uuid.NewRandom().String(),
AccountID: uuid.NewRandom().String(), AccountID: uuid.NewRandom().String(),
Roles: []UserAccountRole{invalidRole}, Roles: []UserAccountRole{invalidRole},
}, },
func(req AddAccountRequest, res *UserAccount) *UserAccount { func(req CreateUserAccountRequest, res *UserAccount) *UserAccount {
return nil return nil
}, },
errors.New("Key: 'AddAccountRequest.Roles[0]' Error:Field validation for 'Roles[0]' failed on the 'oneof' tag"), errors.New("Key: 'CreateUserAccountRequest.Roles[0]' Error:Field validation for 'Roles[0]' failed on the 'oneof' tag"),
}, },
{"Valid Status", {"Valid Status",
AddAccountRequest{ CreateUserAccountRequest{
UserID: uuid.NewRandom().String(), UserID: uuid.NewRandom().String(),
AccountID: uuid.NewRandom().String(), AccountID: uuid.NewRandom().String(),
Roles: []UserAccountRole{UserAccountRole_User}, Roles: []UserAccountRole{UserAccountRole_User},
Status: &invalidStatus, Status: &invalidStatus,
}, },
func(req AddAccountRequest, res *UserAccount) *UserAccount { func(req CreateUserAccountRequest, res *UserAccount) *UserAccount {
return nil return nil
}, },
errors.New("Key: 'AddAccountRequest.Status' Error:Field validation for 'Status' failed on the 'oneof' tag"), errors.New("Key: 'CreateUserAccountRequest.Status' Error:Field validation for 'Status' failed on the 'oneof' tag"),
}, },
{"Default Status", {"Default Status",
AddAccountRequest{ CreateUserAccountRequest{
UserID: uuid.NewRandom().String(), UserID: uuid.NewRandom().String(),
AccountID: uuid.NewRandom().String(), AccountID: uuid.NewRandom().String(),
Roles: []UserAccountRole{UserAccountRole_User}, Roles: []UserAccountRole{UserAccountRole_User},
}, },
func(req AddAccountRequest, res *UserAccount) *UserAccount { func(req CreateUserAccountRequest, res *UserAccount) *UserAccount {
return &UserAccount{ return &UserAccount{
UserID: req.UserID, UserID: req.UserID,
AccountID: req.AccountID, AccountID: req.AccountID,
@ -191,14 +205,14 @@ func TestAddAccountValidation(t *testing.T) {
now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC) now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC)
t.Log("Given the need ensure all validation tags are working for add account.") t.Log("Given the need ensure all validation tags are working for create user account.")
{ {
for i, tt := range accountTests { for i, tt := range accountTests {
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name) t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
{ {
ctx := tests.Context() ctx := tests.Context()
res, err := AddAccount(ctx, auth.Claims{}, test.MasterDB, tt.req, now) res, err := Create(ctx, auth.Claims{}, test.MasterDB, tt.req, now)
if err != tt.error { if err != tt.error {
// TODO: need a better way to handle validation errors as they are // TODO: need a better way to handle validation errors as they are
// of type interface validator.ValidationErrorsTranslations // of type interface validator.ValidationErrorsTranslations
@ -213,29 +227,30 @@ func TestAddAccountValidation(t *testing.T) {
if errStr != expectStr { if errStr != expectStr {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.error) t.Logf("\t\tWant: %+v", tt.error)
t.Fatalf("\t%s\tAddAccount failed.", tests.Failed) t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
} }
} }
// If there was an error that was expected, then don't go any further // If there was an error that was expected, then don't go any further
if tt.error != nil { if tt.error != nil {
t.Logf("\t%s\tAddAccount ok.", tests.Success) t.Logf("\t%s\tCreate user account ok.", tests.Success)
continue continue
} }
expected := tt.expected(tt.req, res) expected := tt.expected(tt.req, res)
if diff := cmp.Diff(res, expected); diff != "" { if diff := cmp.Diff(res, expected); diff != "" {
t.Fatalf("\t%s\tAddAccount result should match. Diff:\n%s", tests.Failed, diff) t.Fatalf("\t%s\tCreate user account result should match. Diff:\n%s", tests.Failed, diff)
} }
t.Logf("\t%s\tAddAccount ok.", tests.Success) t.Logf("\t%s\tCreate user account ok.", tests.Success)
} }
} }
} }
} }
// TestAddAccountExistingEntry validates emails must be unique on add account. // TestCreateExistingEntry ensures that if an archived user account exist,
func TestAddAccountExistingEntry(t *testing.T) { // the entry is updated rather than erroring on duplicate constraint.
func TestCreateExistingEntry(t *testing.T) {
now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC) now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC)
@ -243,87 +258,125 @@ func TestAddAccountExistingEntry(t *testing.T) {
{ {
ctx := tests.Context() ctx := tests.Context()
req1 := AddAccountRequest{ req1 := CreateUserAccountRequest{
UserID: uuid.NewRandom().String(), UserID: uuid.NewRandom().String(),
AccountID: uuid.NewRandom().String(), AccountID: uuid.NewRandom().String(),
Roles: []UserAccountRole{UserAccountRole_User}, Roles: []UserAccountRole{UserAccountRole_User},
} }
ua1, err := AddAccount(ctx, auth.Claims{}, test.MasterDB, req1, now) ua1, err := Create(ctx, auth.Claims{}, test.MasterDB, req1, now)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tAddAccount failed.", tests.Failed) t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
} else if diff := cmp.Diff(ua1.Roles, req1.Roles); diff != "" {
t.Fatalf("\t%s\tCreate user account roles should match request. Diff:\n%s", tests.Failed, diff)
} }
if diff := cmp.Diff(ua1.Roles, req1.Roles); diff != "" { req2 := CreateUserAccountRequest{
t.Fatalf("\t%s\tAddAccount roles should match request. Diff:\n%s", tests.Failed, diff)
}
req2 := AddAccountRequest{
UserID: req1.UserID, UserID: req1.UserID,
AccountID: req1.AccountID, AccountID: req1.AccountID,
Roles: []UserAccountRole{UserAccountRole_Admin}, Roles: []UserAccountRole{UserAccountRole_Admin},
} }
ua2, err := AddAccount(ctx, auth.Claims{}, test.MasterDB, req2, now) ua2, err := Create(ctx, auth.Claims{}, test.MasterDB, req2, now)
if err != nil { if err != nil {
t.Log("\t\tGot :", err) t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tAddAccount failed.", tests.Failed) t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
} else if diff := cmp.Diff(ua2.Roles, req2.Roles); diff != "" {
t.Fatalf("\t%s\tCreate user account roles should match request. Diff:\n%s", tests.Failed, diff)
} }
if diff := cmp.Diff(ua2.Roles, req2.Roles); diff != "" { // Now archive the user account to test trying to create a new entry for an archived entry
t.Fatalf("\t%s\tAddAccount roles should match request. Diff:\n%s", tests.Failed, diff) err = Archive(tests.Context(), auth.Claims{}, test.MasterDB, ArchiveUserAccountRequest{
UserID: req1.UserID,
AccountID: req1.AccountID,
}, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tArchive user account failed.", tests.Failed)
} }
t.Logf("\t%s\tAddAccount ok.", tests.Success) // Find the archived user account
arcRes, err := Read(tests.Context(), auth.Claims{}, test.MasterDB, ua2.ID, true)
if err != nil || arcRes == nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tFind user account failed.", tests.Failed)
} else if !arcRes.ArchivedAt.Valid || arcRes.ArchivedAt.Time.IsZero() {
t.Fatalf("\t%s\tExpected user account to have archived_at set", tests.Failed)
}
// Attempt to create the duplicate user account which should set archived_at back to nil
req3 := CreateUserAccountRequest{
UserID: req1.UserID,
AccountID: req1.AccountID,
Roles: []UserAccountRole{UserAccountRole_User},
}
ua3, err := Create(ctx, auth.Claims{}, test.MasterDB, req3, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
} else if diff := cmp.Diff(ua3.Roles, req3.Roles); diff != "" {
t.Fatalf("\t%s\tCreate user account roles should match request. Diff:\n%s", tests.Failed, diff)
}
// Ensure the user account has archived_at empty
findRes, err := Read(tests.Context(), auth.Claims{}, test.MasterDB, ua3.ID, false)
if err != nil || arcRes == nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tFind user account failed.", tests.Failed)
} else if findRes.ArchivedAt.Valid && !findRes.ArchivedAt.Time.IsZero() {
t.Fatalf("\t%s\tExpected user account to have archived_at empty", tests.Failed)
}
t.Logf("\t%s\tCreate user account ok.", tests.Success)
} }
} }
// TestUpdateAccountValidation ensures all the validation tags work on account update. // TestUpdateValidation ensures all the validation tags work on user account update.
func TestUpdateAccountValidation(t *testing.T) { func TestUpdateValidation(t *testing.T) {
invalidRole := UserAccountRole("moon") invalidRole := UserAccountRole("moon")
invalidStatus := UserAccountStatus("xxxxxxxxx") invalidStatus := UserAccountStatus("xxxxxxxxx")
var accountTests = []struct { var accountTests = []struct {
name string name string
req UpdateAccountRequest req UpdateUserAccountRequest
error error error error
}{ }{
{"Required Fields", {"Required Fields",
UpdateAccountRequest{}, UpdateUserAccountRequest{},
errors.New("Key: 'UpdateAccountRequest.UserID' Error:Field validation for 'UserID' failed on the 'required' tag\n" + errors.New("Key: 'UpdateUserAccountRequest.UserID' Error:Field validation for 'UserID' failed on the 'required' tag\n" +
"Key: 'UpdateAccountRequest.AccountID' Error:Field validation for 'AccountID' failed on the 'required' tag\n" + "Key: 'UpdateUserAccountRequest.AccountID' Error:Field validation for 'AccountID' failed on the 'required' tag\n" +
"Key: 'UpdateAccountRequest.Roles' Error:Field validation for 'Roles' failed on the 'required' tag"), "Key: 'UpdateUserAccountRequest.Roles' Error:Field validation for 'Roles' failed on the 'required' tag"),
}, },
{"Valid Role", {"Valid Role",
UpdateAccountRequest{ UpdateUserAccountRequest{
UserID: uuid.NewRandom().String(), UserID: uuid.NewRandom().String(),
AccountID: uuid.NewRandom().String(), AccountID: uuid.NewRandom().String(),
Roles: &UserAccountRoles{invalidRole}, Roles: &UserAccountRoles{invalidRole},
}, },
errors.New("Key: 'UpdateAccountRequest.Roles[0]' Error:Field validation for 'Roles[0]' failed on the 'oneof' tag"), errors.New("Key: 'UpdateUserAccountRequest.Roles[0]' Error:Field validation for 'Roles[0]' failed on the 'oneof' tag"),
}, },
{"Valid Status", {"Valid Status",
UpdateAccountRequest{ UpdateUserAccountRequest{
UserID: uuid.NewRandom().String(), UserID: uuid.NewRandom().String(),
AccountID: uuid.NewRandom().String(), AccountID: uuid.NewRandom().String(),
Roles: &UserAccountRoles{UserAccountRole_User}, Roles: &UserAccountRoles{UserAccountRole_User},
Status: &invalidStatus, Status: &invalidStatus,
}, },
errors.New("Key: 'UpdateAccountRequest.Status' Error:Field validation for 'Status' failed on the 'oneof' tag"), errors.New("Key: 'UpdateUserAccountRequest.Status' Error:Field validation for 'Status' failed on the 'oneof' tag"),
}, },
} }
now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC) now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC)
t.Log("Given the need ensure all validation tags are working for update account.") t.Log("Given the need ensure all validation tags are working for update user account.")
{ {
for i, tt := range accountTests { for i, tt := range accountTests {
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name) t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
{ {
ctx := tests.Context() ctx := tests.Context()
err := UpdateAccount(ctx, auth.Claims{}, test.MasterDB, tt.req, now) err := Update(ctx, auth.Claims{}, test.MasterDB, tt.req, now)
if err != tt.error { if err != tt.error {
// TODO: need a better way to handle validation errors as they are // TODO: need a better way to handle validation errors as they are
// of type interface validator.ValidationErrorsTranslations // of type interface validator.ValidationErrorsTranslations
@ -338,30 +391,31 @@ func TestUpdateAccountValidation(t *testing.T) {
if errStr != expectStr { if errStr != expectStr {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.error) t.Logf("\t\tWant: %+v", tt.error)
t.Fatalf("\t%s\tUpdateAccount failed.", tests.Failed) t.Fatalf("\t%s\tUpdate user account failed.", tests.Failed)
} }
} }
// If there was an error that was expected, then don't go any further // If there was an error that was expected, then don't go any further
if tt.error != nil { if tt.error != nil {
t.Logf("\t%s\tUpdateAccount ok.", tests.Success) t.Logf("\t%s\tUpdate user account ok.", tests.Success)
continue continue
} }
t.Logf("\t%s\tUpdateAccount ok.", tests.Success) t.Logf("\t%s\tUpdate user account ok.", tests.Success)
} }
} }
} }
} }
// TestAccountCrud validates the full set of CRUD operations for user accounts and // TestCrud validates the full set of CRUD operations for user accounts and
// ensures ACLs are correctly applied by claims. // ensures ACLs are correctly applied by claims.
func TestAccountCrud(t *testing.T) { func TestCrud(t *testing.T) {
defer tests.Recover(t) defer tests.Recover(t)
type accountTest struct { type accountTest struct {
name string name string
claims func(string, string) auth.Claims claims func(string, string) auth.Claims
createErr error
updateErr error updateErr error
findErr error findErr error
} }
@ -375,9 +429,10 @@ func TestAccountCrud(t *testing.T) {
}, },
nil, nil,
nil, nil,
nil,
}) })
// Role of user but claim user does not match update user so forbidden. // Role of user but claim user does not match update user so forbidden for update.
accountTests = append(accountTests, accountTest{"RoleUserDiffUser", accountTests = append(accountTests, accountTest{"RoleUserDiffUser",
func(userID, accountId string) auth.Claims { func(userID, accountId string) auth.Claims {
return auth.Claims{ return auth.Claims{
@ -389,7 +444,8 @@ func TestAccountCrud(t *testing.T) {
} }
}, },
ErrForbidden, ErrForbidden,
ErrNotFound, ErrForbidden,
ErrForbidden,
}) })
// Role of user AND claim user matches update user so OK. // Role of user AND claim user matches update user so OK.
@ -403,7 +459,8 @@ func TestAccountCrud(t *testing.T) {
}, },
} }
}, },
nil, ErrForbidden,
ErrForbidden,
nil, nil,
}) })
@ -419,6 +476,7 @@ func TestAccountCrud(t *testing.T) {
} }
}, },
ErrForbidden, ErrForbidden,
ErrForbidden,
ErrNotFound, ErrNotFound,
}) })
@ -435,6 +493,7 @@ func TestAccountCrud(t *testing.T) {
}, },
nil, nil,
nil, nil,
nil,
}) })
t.Log("Given the need to validate CRUD functionality for user accounts and ensure claims are applied as ACL.") t.Log("Given the need to validate CRUD functionality for user accounts and ensure claims are applied as ACL.")
@ -444,100 +503,105 @@ func TestAccountCrud(t *testing.T) {
for i, tt := range accountTests { for i, tt := range accountTests {
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name) t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
{ {
// Always create the new user with empty claims, testing claims for create user
// will be handled separately.
user, err := Create(tests.Context(), auth.Claims{}, test.MasterDB, CreateUserRequest{
Name: "Lee Brown",
Email: uuid.NewRandom().String() + "@geeksinthewoods.com",
Password: "akTechFr0n!ier",
PasswordConfirm: "akTechFr0n!ier",
}, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate user failed.", tests.Failed)
}
// Create a new random account and associate that with the user. // Create a new random account and associate that with the user.
userID := uuid.NewRandom().String()
accountID := uuid.NewRandom().String() accountID := uuid.NewRandom().String()
createReq := AddAccountRequest{ createReq := CreateUserAccountRequest{
UserID: user.ID, UserID: userID,
AccountID: accountID, AccountID: accountID,
Roles: []UserAccountRole{UserAccountRole_User}, Roles: []UserAccountRole{UserAccountRole_User},
} }
ua, err := AddAccount(tests.Context(), tt.claims(user.ID, accountID), test.MasterDB, createReq, now) ua, err := Create(tests.Context(), tt.claims(userID, accountID), test.MasterDB, createReq, now)
if err != nil && errors.Cause(err) != tt.updateErr { if err != nil && errors.Cause(err) != tt.createErr {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.updateErr) t.Logf("\t\tWant: %+v", tt.createErr)
t.Fatalf("\t%s\tUpdateAccount failed.", tests.Failed) t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
} else if tt.updateErr == nil { } else if tt.createErr == nil {
if diff := cmp.Diff(ua.Roles, createReq.Roles); diff != "" { if diff := cmp.Diff(ua.Roles, createReq.Roles); diff != "" {
t.Fatalf("\t%s\tExpected find result to match update. Diff:\n%s", tests.Failed, diff) t.Fatalf("\t%s\tExpected user account roles result to match for create. Diff:\n%s", tests.Failed, diff)
}
t.Logf("\t%s\tCreate user account ok.", tests.Success)
}
if tt.createErr == ErrForbidden {
ua, err = Create(tests.Context(), auth.Claims{}, test.MasterDB, createReq, now)
if err != nil && errors.Cause(err) != tt.createErr {
t.Logf("\t\tGot : %+v", err)
t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
} }
t.Logf("\t%s\tAddAccount ok.", tests.Success)
} }
// Update the account. // Update the account.
updateReq := UpdateAccountRequest{ updateReq := UpdateUserAccountRequest{
UserID: user.ID, UserID: userID,
AccountID: accountID, AccountID: accountID,
Roles: &UserAccountRoles{UserAccountRole_Admin}, Roles: &UserAccountRoles{UserAccountRole_Admin},
} }
err = UpdateAccount(tests.Context(), tt.claims(user.ID, accountID), test.MasterDB, updateReq, now) err = Update(tests.Context(), tt.claims(userID, accountID), test.MasterDB, updateReq, now)
if err != nil && errors.Cause(err) != tt.updateErr { if err != nil {
if errors.Cause(err) != tt.updateErr {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.updateErr) t.Logf("\t\tWant: %+v", tt.updateErr)
t.Fatalf("\t%s\tUpdateAccount failed.", tests.Failed) t.Fatalf("\t%s\tUpdate user account failed.", tests.Failed)
} }
t.Logf("\t%s\tUpdateAccount ok.", tests.Success) } else {
ua.Roles = *updateReq.Roles
}
t.Logf("\t%s\tUpdate user account ok.", tests.Success)
// Find the account for the user to verify the updates where made. There should only // Find the account for the user to verify the updates where made. There should only
// be one account associated with the user for this test. // be one account associated with the user for this test.
findRes, err := FindAccountsByUserID(tests.Context(), tt.claims(user.ID, accountID), test.MasterDB, user.ID, false) ff := "user_id = ? or account_id = ?"
findRes, err := Find(tests.Context(), tt.claims(userID, accountID), test.MasterDB, UserAccountFindRequest{
Where: &ff,
Args: []interface{}{userID, accountID},
Order: []string{"created_at"},
})
if err != nil && errors.Cause(err) != tt.findErr { if err != nil && errors.Cause(err) != tt.findErr {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.findErr) t.Logf("\t\tWant: %+v", tt.findErr)
t.Fatalf("\t%s\tVerify UpdateAccount failed.", tests.Failed) t.Fatalf("\t%s\tVerify update user account failed.", tests.Failed)
} else if tt.findErr == nil { } else if tt.findErr == nil {
expected := []*UserAccount{ expected := []*UserAccount{
&UserAccount{ &UserAccount{
ID: ua.ID, ID: ua.ID,
UserID: ua.UserID, UserID: ua.UserID,
AccountID: ua.AccountID, AccountID: ua.AccountID,
Roles: *updateReq.Roles, Roles: ua.Roles,
Status: ua.Status, Status: ua.Status,
CreatedAt: ua.CreatedAt, CreatedAt: ua.CreatedAt,
UpdatedAt: now, UpdatedAt: now,
}, },
} }
if diff := cmp.Diff(findRes, expected); diff != "" { if diff := cmp.Diff(findRes, expected); diff != "" {
t.Fatalf("\t%s\tExpected find result to match update. Diff:\n%s", tests.Failed, diff) t.Fatalf("\t%s\tExpected user account find result to match update. Diff:\n%s", tests.Failed, diff)
} }
t.Logf("\t%s\tVerify UpdateAccount ok.", tests.Success) t.Logf("\t%s\tVerify update user account ok.", tests.Success)
} }
// Archive (soft-delete) the user account. // Archive (soft-delete) the user account.
err = RemoveAccount(tests.Context(), tt.claims(user.ID, accountID), test.MasterDB, RemoveAccountRequest{ err = Archive(tests.Context(), tt.claims(userID, accountID), test.MasterDB, ArchiveUserAccountRequest{
UserID: user.ID, UserID: userID,
AccountID: accountID, AccountID: accountID,
}, now) }, now)
if err != nil && errors.Cause(err) != tt.updateErr { if err != nil && errors.Cause(err) != tt.updateErr {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.updateErr) t.Logf("\t\tWant: %+v", tt.updateErr)
t.Fatalf("\t%s\tRemoveAccount failed.", tests.Failed) t.Fatalf("\t%s\tArchive user account failed.", tests.Failed)
} else if tt.updateErr == nil { } else if tt.updateErr == nil {
// Trying to find the archived user with the includeArchived false should result in not found. // Trying to find the archived user with the includeArchived false should result in not found.
_, err = FindAccountsByUserID(tests.Context(), tt.claims(user.ID, accountID), test.MasterDB, user.ID, false) _, err = FindByUserID(tests.Context(), tt.claims(userID, accountID), test.MasterDB, userID, false)
if errors.Cause(err) != ErrNotFound { if errors.Cause(err) != ErrNotFound {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", ErrNotFound) t.Logf("\t\tWant: %+v", ErrNotFound)
t.Fatalf("\t%s\tVerify RemoveAccount failed when excluding archived.", tests.Failed) t.Fatalf("\t%s\tVerify archive user account failed when excluding archived.", tests.Failed)
} }
// Trying to find the archived user with the includeArchived true should result no error. // Trying to find the archived user with the includeArchived true should result no error.
findRes, err = FindAccountsByUserID(tests.Context(), tt.claims(user.ID, accountID), test.MasterDB, user.ID, true) findRes, err = FindByUserID(tests.Context(), tt.claims(userID, accountID), test.MasterDB, userID, true)
if err != nil { if err != nil {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Fatalf("\t%s\tVerify RemoveAccount failed when including archived.", tests.Failed) t.Fatalf("\t%s\tVerify archive user account failed when including archived.", tests.Failed)
} }
expected := []*UserAccount{ expected := []*UserAccount{
@ -553,37 +617,37 @@ func TestAccountCrud(t *testing.T) {
}, },
} }
if diff := cmp.Diff(findRes, expected); diff != "" { if diff := cmp.Diff(findRes, expected); diff != "" {
t.Fatalf("\t%s\tExpected find result to be archived. Diff:\n%s", tests.Failed, diff) t.Fatalf("\t%s\tExpected user account find result to be archived. Diff:\n%s", tests.Failed, diff)
} }
} }
t.Logf("\t%s\tRemoveAccount ok.", tests.Success) t.Logf("\t%s\tArchive user account ok.", tests.Success)
// Delete (hard-delete) the user account. // Delete (hard-delete) the user account.
err = DeleteAccount(tests.Context(), tt.claims(user.ID, accountID), test.MasterDB, DeleteAccountRequest{ err = Delete(tests.Context(), tt.claims(userID, accountID), test.MasterDB, DeleteUserAccountRequest{
UserID: user.ID, UserID: userID,
AccountID: accountID, AccountID: accountID,
}) })
if err != nil && errors.Cause(err) != tt.updateErr { if err != nil && errors.Cause(err) != tt.updateErr {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.updateErr) t.Logf("\t\tWant: %+v", tt.updateErr)
t.Fatalf("\t%s\tDeleteAccount failed.", tests.Failed) t.Fatalf("\t%s\tDelete user account failed.", tests.Failed)
} else if tt.updateErr == nil { } else if tt.updateErr == nil {
// Trying to find the deleted user with the includeArchived true should result in not found. // Trying to find the deleted user with the includeArchived true should result in not found.
_, err = FindAccountsByUserID(tests.Context(), tt.claims(user.ID, accountID), test.MasterDB, user.ID, true) _, err = FindByUserID(tests.Context(), tt.claims(userID, accountID), test.MasterDB, userID, true)
if errors.Cause(err) != ErrNotFound { if errors.Cause(err) != ErrNotFound {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", ErrNotFound) t.Logf("\t\tWant: %+v", ErrNotFound)
t.Fatalf("\t%s\tVerify DeleteAccount failed when including archived.", tests.Failed) t.Fatalf("\t%s\tVerify delete user account failed when including archived.", tests.Failed)
} }
} }
t.Logf("\t%s\tDeleteAccount ok.", tests.Success) t.Logf("\t%s\tDelete user account ok.", tests.Success)
} }
} }
} }
} }
// TestAccountFind validates all the request params are correctly parsed into a select query. // TestFind validates all the request params are correctly parsed into a select query.
func TestAccountFind(t *testing.T) { func TestFind(t *testing.T) {
now := time.Now().Add(time.Hour * -2).UTC() now := time.Now().Add(time.Hour * -2).UTC()
@ -592,31 +656,21 @@ func TestAccountFind(t *testing.T) {
var userAccounts []*UserAccount var userAccounts []*UserAccount
for i := 0; i <= 4; i++ { for i := 0; i <= 4; i++ {
user, err := Create(tests.Context(), auth.Claims{}, test.MasterDB, CreateUserRequest{
Name: "Lee Brown",
Email: uuid.NewRandom().String() + "@geeksinthewoods.com",
Password: "akTechFr0n!ier",
PasswordConfirm: "akTechFr0n!ier",
}, now.Add(time.Second*time.Duration(i)))
if err != nil {
t.Logf("\t\tGot : %+v", err)
t.Fatalf("\t%s\tCreate user failed.", tests.Failed)
}
// Create a new random account and associate that with the user. // Create a new random account and associate that with the user.
userID := uuid.NewRandom().String()
accountID := uuid.NewRandom().String() accountID := uuid.NewRandom().String()
ua, err := AddAccount(tests.Context(), auth.Claims{}, test.MasterDB, AddAccountRequest{ ua, err := Create(tests.Context(), auth.Claims{}, test.MasterDB, CreateUserAccountRequest{
UserID: user.ID, UserID: userID,
AccountID: accountID, AccountID: accountID,
Roles: []UserAccountRole{UserAccountRole_User}, Roles: []UserAccountRole{UserAccountRole_User},
}, now.Add(time.Second*time.Duration(i))) }, now.Add(time.Second*time.Duration(i)))
if err != nil { if err != nil {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Fatalf("\t%s\tAdd account failed.", tests.Failed) t.Fatalf("\t%s\tCreate user account failed.", tests.Failed)
} }
userAccounts = append(userAccounts, ua) userAccounts = append(userAccounts, ua)
endTime = user.CreatedAt endTime = ua.CreatedAt
} }
type accountTest struct { type accountTest struct {
@ -708,24 +762,24 @@ func TestAccountFind(t *testing.T) {
nil, nil,
}) })
t.Log("Given the need to ensure find users returns the expected results.") t.Log("Given the need to ensure find user accounts returns the expected results.")
{ {
for i, tt := range accountTests { for i, tt := range accountTests {
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name) t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
{ {
ctx := tests.Context() ctx := tests.Context()
res, err := FindAccounts(ctx, auth.Claims{}, test.MasterDB, tt.req) res, err := Find(ctx, auth.Claims{}, test.MasterDB, tt.req)
if err != nil && errors.Cause(err) != tt.error { if errors.Cause(err) != tt.error {
t.Logf("\t\tGot : %+v", err) t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.error) t.Logf("\t\tWant: %+v", tt.error)
t.Fatalf("\t%s\tFind failed.", tests.Failed) t.Fatalf("\t%s\tFind user account failed.", tests.Failed)
} else if diff := cmp.Diff(res, tt.expected); diff != "" { } else if diff := cmp.Diff(res, tt.expected); diff != "" {
t.Logf("\t\tGot: %d items", len(res)) t.Logf("\t\tGot: %d items", len(res))
t.Logf("\t\tWant: %d items", len(tt.expected)) t.Logf("\t\tWant: %d items", len(tt.expected))
t.Fatalf("\t%s\tExpected find result to match expected. Diff:\n%s", tests.Failed, diff) t.Fatalf("\t%s\tExpected user account find result to match expected. Diff:\n%s", tests.Failed, diff)
} }
t.Logf("\t%s\tFind ok.", tests.Success) t.Logf("\t%s\tFind user account ok.", tests.Success)
} }
} }
} }