2019-06-24 22:41:21 -08:00
|
|
|
package user_account
|
2019-05-27 02:44:40 -05:00
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"database/sql"
|
|
|
|
"time"
|
|
|
|
|
2019-08-02 15:03:32 -08:00
|
|
|
"geeks-accelerator/oss/saas-starter-kit/internal/account"
|
2019-07-13 12:16:28 -08:00
|
|
|
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
|
2019-08-02 15:03:32 -08:00
|
|
|
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
|
2019-08-14 11:40:26 -08:00
|
|
|
"geeks-accelerator/oss/saas-starter-kit/internal/user"
|
2019-05-27 02:44:40 -05:00
|
|
|
"github.com/huandu/go-sqlbuilder"
|
|
|
|
"github.com/jmoiron/sqlx"
|
|
|
|
"github.com/pborman/uuid"
|
|
|
|
"github.com/pkg/errors"
|
|
|
|
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
|
|
|
|
)
|
|
|
|
|
2019-06-22 10:05:03 -08:00
|
|
|
var (
	// ErrNotFound is returned when the requested user account entry does not
	// exist (or is not visible to the caller).
	ErrNotFound = errors.New("Entity not found")

	// ErrForbidden is returned when the caller's claims do not permit the
	// attempted action under the access control policies.
	ErrForbidden = errors.New("Attempted action is not allowed")
)
|
|
|
|
|
2019-05-27 02:44:40 -05:00
|
|
|
const (
	// userAccountTableName is the database table that links users to accounts.
	userAccountTableName = "users_accounts"

	// userTableName is the database table that stores users.
	userTableName = "users"
)

// userAccountMapColumns lists, in scan order, the columns that
// mapRowsToUserAccount reads from each row.
var userAccountMapColumns = "user_id,account_id,roles,status,created_at,updated_at,archived_at"
|
2019-05-27 02:44:40 -05:00
|
|
|
|
|
|
|
// mapRowsToUserAccount takes the SQL rows and maps it to the UserAccount struct
|
2019-06-22 10:05:03 -08:00
|
|
|
// with the columns defined by userAccountMapColumns
|
2019-05-27 02:44:40 -05:00
|
|
|
func mapRowsToUserAccount(rows *sql.Rows) (*UserAccount, error) {
|
|
|
|
var (
|
|
|
|
ua UserAccount
|
|
|
|
err error
|
|
|
|
)
|
2019-08-04 14:48:43 -08:00
|
|
|
err = rows.Scan(&ua.UserID, &ua.AccountID, &ua.Roles, &ua.Status, &ua.CreatedAt, &ua.UpdatedAt, &ua.ArchivedAt)
|
2019-05-27 02:44:40 -05:00
|
|
|
if err != nil {
|
|
|
|
return nil, errors.WithStack(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return &ua, nil
|
|
|
|
}
|
|
|
|
|
2019-06-22 17:48:44 -08:00
|
|
|
// CanReadAccount determines if claims has the authority to access the specified user account by user ID.
|
2019-08-13 23:41:06 -08:00
|
|
|
func (repo *Repository) CanReadAccount(ctx context.Context, claims auth.Claims, accountID string) error {
|
2019-08-14 11:40:26 -08:00
|
|
|
err := account.CanReadAccount(ctx, claims, repo.DbConn, accountID)
|
2019-06-22 17:48:44 -08:00
|
|
|
return mapAccountError(err)
|
2019-06-22 10:05:03 -08:00
|
|
|
}
|
|
|
|
|
2019-06-22 17:48:44 -08:00
|
|
|
// CanModifyAccount determines if claims has the authority to modify the specified user ID.
|
2019-08-13 23:41:06 -08:00
|
|
|
func (repo *Repository) CanModifyAccount(ctx context.Context, claims auth.Claims, accountID string) error {
|
2019-08-14 11:40:26 -08:00
|
|
|
err := account.CanModifyAccount(ctx, claims, repo.DbConn, accountID)
|
2019-06-22 17:48:44 -08:00
|
|
|
return mapAccountError(err)
|
|
|
|
}
|
2019-05-29 03:35:08 -05:00
|
|
|
|
2019-06-22 17:48:44 -08:00
|
|
|
// mapAccountError maps account errors to local defined errors.
|
|
|
|
func mapAccountError(err error) error {
|
|
|
|
switch errors.Cause(err) {
|
|
|
|
case account.ErrNotFound:
|
|
|
|
err = ErrNotFound
|
|
|
|
case account.ErrForbidden:
|
|
|
|
err = ErrForbidden
|
|
|
|
}
|
|
|
|
return err
|
2019-05-29 03:35:08 -05:00
|
|
|
}
|
|
|
|
|
2019-06-22 17:48:44 -08:00
|
|
|
// applyClaimsSelect applies a sub-query to the provided query
|
|
|
|
// to enforce ACL based on the claims provided.
|
|
|
|
// 1. All role types can access their user ID
|
|
|
|
// 2. Any user with the same account ID
|
|
|
|
// 3. No claims, request is internal, no ACL applied
|
|
|
|
func applyClaimsSelect(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder) error {
|
2019-05-27 02:44:40 -05:00
|
|
|
if claims.Audience == "" && claims.Subject == "" {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// Build select statement for users_accounts table
|
2019-06-22 17:48:44 -08:00
|
|
|
subQuery := sqlbuilder.NewSelectBuilder().Select("id").From(userAccountTableName)
|
2019-05-27 02:44:40 -05:00
|
|
|
|
|
|
|
var or []string
|
|
|
|
if claims.Audience != "" {
|
|
|
|
or = append(or, subQuery.Equal("account_id", claims.Audience))
|
|
|
|
}
|
|
|
|
if claims.Subject != "" {
|
|
|
|
or = append(or, subQuery.Equal("user_id", claims.Subject))
|
|
|
|
}
|
|
|
|
|
|
|
|
// Append sub query
|
2019-07-15 16:05:02 -08:00
|
|
|
if len(or) > 0 {
|
|
|
|
subQuery.Where(subQuery.Or(or...))
|
|
|
|
query.Where(query.In("id", subQuery))
|
|
|
|
}
|
2019-05-27 02:44:40 -05:00
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2019-06-22 17:48:44 -08:00
|
|
|
// selectQuery constructs a base select query for User Account
|
|
|
|
func selectQuery() *sqlbuilder.SelectBuilder {
|
2019-05-27 02:44:40 -05:00
|
|
|
query := sqlbuilder.NewSelectBuilder()
|
2019-06-22 10:05:03 -08:00
|
|
|
query.Select(userAccountMapColumns)
|
|
|
|
query.From(userAccountTableName)
|
2019-05-27 02:44:40 -05:00
|
|
|
return query
|
|
|
|
}
|
|
|
|
|
2019-06-22 17:48:44 -08:00
|
|
|
// findRequestQuery generates the select query for the given find request.
|
2019-05-28 04:44:01 -05:00
|
|
|
// TODO: Need to figure out why can't parse the args when appending the where
|
|
|
|
// to the query.
|
2019-06-22 17:48:44 -08:00
|
|
|
func findRequestQuery(req UserAccountFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) {
|
|
|
|
query := selectQuery()
|
2019-08-05 17:12:28 -08:00
|
|
|
if req.Where != "" {
|
|
|
|
query.Where(query.And(req.Where))
|
2019-05-27 02:44:40 -05:00
|
|
|
}
|
|
|
|
if len(req.Order) > 0 {
|
|
|
|
query.OrderBy(req.Order...)
|
|
|
|
}
|
|
|
|
if req.Limit != nil {
|
|
|
|
query.Limit(int(*req.Limit))
|
|
|
|
}
|
|
|
|
if req.Offset != nil {
|
2019-05-29 03:35:08 -05:00
|
|
|
query.Offset(int(*req.Offset))
|
2019-05-27 02:44:40 -05:00
|
|
|
}
|
|
|
|
|
2019-05-28 04:44:01 -05:00
|
|
|
return query, req.Args
|
2019-05-27 02:44:40 -05:00
|
|
|
}
|
|
|
|
|
2019-06-22 17:48:44 -08:00
|
|
|
// Find gets all the user accounts from the database based on the request params.
|
2019-08-14 11:40:26 -08:00
|
|
|
func (repo *Repository) Find(ctx context.Context, claims auth.Claims, req UserAccountFindRequest) (UserAccounts, error) {
|
2019-06-22 17:48:44 -08:00
|
|
|
query, args := findRequestQuery(req)
|
2019-08-14 11:40:26 -08:00
|
|
|
return find(ctx, claims, repo.DbConn, query, args, req.IncludeArchived)
|
2019-05-27 02:44:40 -05:00
|
|
|
}
|
|
|
|
|
2019-06-22 10:05:03 -08:00
|
|
|
// Find gets all the user accounts from the database based on the select query
|
2019-08-04 21:28:02 -08:00
|
|
|
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (UserAccounts, error) {
|
2019-06-22 10:05:03 -08:00
|
|
|
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Find")
|
2019-05-27 02:44:40 -05:00
|
|
|
defer span.Finish()
|
|
|
|
|
2019-06-22 10:05:03 -08:00
|
|
|
query.Select(userAccountMapColumns)
|
|
|
|
query.From(userAccountTableName)
|
2019-05-27 02:44:40 -05:00
|
|
|
|
|
|
|
if !includedArchived {
|
|
|
|
query.Where(query.IsNull("archived_at"))
|
|
|
|
}
|
|
|
|
|
|
|
|
// Check to see if a sub query needs to be applied for the claims
|
2019-06-22 17:48:44 -08:00
|
|
|
err := applyClaimsSelect(ctx, claims, query)
|
2019-05-27 02:44:40 -05:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
2019-05-28 04:44:01 -05:00
|
|
|
queryStr, queryArgs := query.Build()
|
|
|
|
queryStr = dbConn.Rebind(queryStr)
|
|
|
|
args = append(args, queryArgs...)
|
2019-05-27 02:44:40 -05:00
|
|
|
|
|
|
|
// fetch all places from the db
|
2019-05-28 04:44:01 -05:00
|
|
|
rows, err := dbConn.QueryContext(ctx, queryStr, args...)
|
2019-05-27 02:44:40 -05:00
|
|
|
if err != nil {
|
|
|
|
err = errors.Wrapf(err, "query - %s", query.String())
|
2019-06-22 10:05:03 -08:00
|
|
|
err = errors.WithMessage(err, "find user accounts failed")
|
2019-05-27 02:44:40 -05:00
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
// iterate over each row
|
|
|
|
resp := []*UserAccount{}
|
|
|
|
for rows.Next() {
|
|
|
|
ua, err := mapRowsToUserAccount(rows)
|
|
|
|
if err != nil {
|
|
|
|
err = errors.Wrapf(err, "query - %s", query.String())
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
resp = append(resp, ua)
|
|
|
|
}
|
|
|
|
|
|
|
|
return resp, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// Retrieve gets the specified user from the database.
|
2019-08-14 11:40:26 -08:00
|
|
|
func (repo *Repository) FindByUserID(ctx context.Context, claims auth.Claims, userID string, includedArchived bool) (UserAccounts, error) {
|
2019-06-22 10:05:03 -08:00
|
|
|
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.FindByUserID")
|
2019-05-27 02:44:40 -05:00
|
|
|
defer span.Finish()
|
|
|
|
|
|
|
|
// Filter base select query by ID
|
|
|
|
query := sqlbuilder.NewSelectBuilder()
|
|
|
|
query.Where(query.Equal("user_id", userID))
|
2019-05-29 15:05:17 -05:00
|
|
|
query.OrderBy("created_at")
|
2019-05-27 02:44:40 -05:00
|
|
|
|
|
|
|
// Execute the find accounts method.
|
2019-08-14 11:40:26 -08:00
|
|
|
res, err := find(ctx, claims, repo.DbConn, query, []interface{}{}, includedArchived)
|
2019-05-27 02:44:40 -05:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
2019-05-29 03:35:08 -05:00
|
|
|
} else if res == nil || len(res) == 0 {
|
|
|
|
err = errors.WithMessagef(ErrNotFound, "no accounts for user %s found", userID)
|
|
|
|
return nil, err
|
2019-05-27 02:44:40 -05:00
|
|
|
}
|
|
|
|
|
|
|
|
return res, nil
|
|
|
|
}
|
|
|
|
|
2019-06-22 17:48:44 -08:00
|
|
|
// Create a user account for a given user with specified roles.
|
2019-08-14 11:40:26 -08:00
|
|
|
func (repo *Repository) Create(ctx context.Context, claims auth.Claims, req UserAccountCreateRequest, now time.Time) (*UserAccount, error) {
|
2019-06-22 10:05:03 -08:00
|
|
|
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Create")
|
2019-05-27 02:44:40 -05:00
|
|
|
defer span.Finish()
|
|
|
|
|
|
|
|
// Validate the request.
|
2019-07-31 13:47:30 -08:00
|
|
|
v := webcontext.Validator()
|
2019-06-26 20:21:00 -08:00
|
|
|
err := v.Struct(req)
|
2019-05-27 02:44:40 -05:00
|
|
|
if err != nil {
|
2019-05-29 03:35:08 -05:00
|
|
|
return nil, err
|
2019-05-27 02:44:40 -05:00
|
|
|
}
|
|
|
|
|
2019-06-22 17:48:44 -08:00
|
|
|
// Ensure the claims can modify the account specified in the request.
|
2019-08-14 11:40:26 -08:00
|
|
|
err = repo.CanModifyAccount(ctx, claims, req.AccountID)
|
2019-05-27 02:44:40 -05:00
|
|
|
if err != nil {
|
2019-05-29 03:35:08 -05:00
|
|
|
return nil, err
|
2019-05-27 02:44:40 -05:00
|
|
|
}
|
|
|
|
|
|
|
|
// If now empty set it to the current time.
|
|
|
|
if now.IsZero() {
|
|
|
|
now = time.Now()
|
|
|
|
}
|
|
|
|
|
|
|
|
// Always store the time as UTC.
|
|
|
|
now = now.UTC()
|
|
|
|
|
|
|
|
// Postgres truncates times to milliseconds when storing. We and do the same
|
|
|
|
// here so the value we return is consistent with what we store.
|
|
|
|
now = now.Truncate(time.Millisecond)
|
|
|
|
|
|
|
|
// Check to see if there is an existing user account, including archived.
|
2019-06-22 17:48:44 -08:00
|
|
|
existQuery := selectQuery()
|
2019-05-27 02:44:40 -05:00
|
|
|
existQuery.Where(existQuery.And(
|
|
|
|
existQuery.Equal("account_id", req.AccountID),
|
|
|
|
existQuery.Equal("user_id", req.UserID),
|
|
|
|
))
|
2019-08-14 11:40:26 -08:00
|
|
|
existing, err := find(ctx, claims, repo.DbConn, existQuery, []interface{}{}, true)
|
2019-05-27 02:44:40 -05:00
|
|
|
if err != nil {
|
2019-05-29 03:35:08 -05:00
|
|
|
return nil, err
|
2019-05-27 02:44:40 -05:00
|
|
|
}
|
|
|
|
|
|
|
|
// If there is an existing entry, then update instead of insert.
|
2019-06-22 17:48:44 -08:00
|
|
|
var ua UserAccount
|
2019-05-27 02:44:40 -05:00
|
|
|
if len(existing) > 0 {
|
2019-06-26 01:16:57 -08:00
|
|
|
upReq := UserAccountUpdateRequest{
|
2019-05-27 02:44:40 -05:00
|
|
|
UserID: req.UserID,
|
|
|
|
AccountID: req.AccountID,
|
2019-05-29 03:35:08 -05:00
|
|
|
Roles: &req.Roles,
|
2019-05-27 02:44:40 -05:00
|
|
|
unArchive: true,
|
|
|
|
}
|
2019-08-14 11:40:26 -08:00
|
|
|
err = repo.Update(ctx, claims, upReq, now)
|
2019-05-29 03:35:08 -05:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
2019-06-22 17:48:44 -08:00
|
|
|
ua = *existing[0]
|
2019-05-29 03:35:08 -05:00
|
|
|
ua.Roles = req.Roles
|
|
|
|
ua.UpdatedAt = now
|
2019-06-25 02:40:29 -08:00
|
|
|
ua.ArchivedAt = nil
|
2019-06-22 17:48:44 -08:00
|
|
|
} else {
|
2019-08-04 14:48:43 -08:00
|
|
|
uaID := uuid.NewRandom().String()
|
|
|
|
|
2019-06-22 17:48:44 -08:00
|
|
|
ua = UserAccount{
|
2019-08-04 14:48:43 -08:00
|
|
|
//ID: uaID,
|
2019-06-22 17:48:44 -08:00
|
|
|
UserID: req.UserID,
|
|
|
|
AccountID: req.AccountID,
|
|
|
|
Roles: req.Roles,
|
|
|
|
Status: UserAccountStatus_Active,
|
|
|
|
CreatedAt: now,
|
|
|
|
UpdatedAt: now,
|
|
|
|
}
|
2019-05-29 03:35:08 -05:00
|
|
|
|
2019-06-22 17:48:44 -08:00
|
|
|
if req.Status != nil {
|
|
|
|
ua.Status = *req.Status
|
|
|
|
}
|
2019-05-27 02:44:40 -05:00
|
|
|
|
2019-06-22 17:48:44 -08:00
|
|
|
// Build the insert SQL statement.
|
|
|
|
query := sqlbuilder.NewInsertBuilder()
|
|
|
|
query.InsertInto(userAccountTableName)
|
|
|
|
query.Cols("id", "user_id", "account_id", "roles", "status", "created_at", "updated_at")
|
2019-08-04 14:48:43 -08:00
|
|
|
query.Values(uaID, ua.UserID, ua.AccountID, ua.Roles, ua.Status.String(), ua.CreatedAt, ua.UpdatedAt)
|
2019-05-29 03:35:08 -05:00
|
|
|
|
2019-06-22 17:48:44 -08:00
|
|
|
// Execute the query with the provided context.
|
|
|
|
sql, args := query.Build()
|
2019-08-14 11:40:26 -08:00
|
|
|
sql = repo.DbConn.Rebind(sql)
|
|
|
|
_, err = repo.DbConn.ExecContext(ctx, sql, args...)
|
2019-06-22 17:48:44 -08:00
|
|
|
if err != nil {
|
|
|
|
err = errors.Wrapf(err, "query - %s", query.String())
|
|
|
|
err = errors.WithMessagef(err, "add account %s to user %s failed", req.AccountID, req.UserID)
|
|
|
|
return nil, err
|
|
|
|
}
|
2019-05-29 03:35:08 -05:00
|
|
|
}
|
2019-05-27 02:44:40 -05:00
|
|
|
|
2019-06-22 17:48:44 -08:00
|
|
|
return &ua, nil
|
|
|
|
}
|
2019-05-27 02:44:40 -05:00
|
|
|
|
2019-06-22 17:48:44 -08:00
|
|
|
// Read gets the specified user account from the database.
|
2019-08-14 11:40:26 -08:00
|
|
|
func (repo *Repository) Read(ctx context.Context, claims auth.Claims, req UserAccountReadRequest) (*UserAccount, error) {
|
2019-06-22 17:48:44 -08:00
|
|
|
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Read")
|
|
|
|
defer span.Finish()
|
|
|
|
|
2019-08-04 14:48:43 -08:00
|
|
|
// Validate the request.
|
|
|
|
v := webcontext.Validator()
|
|
|
|
err := v.Struct(req)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
2019-06-22 17:48:44 -08:00
|
|
|
// Filter base select query by ID
|
|
|
|
query := selectQuery()
|
2019-08-04 14:48:43 -08:00
|
|
|
query.Where(query.And(
|
|
|
|
query.Equal("user_id", req.UserID),
|
|
|
|
query.Equal("account_id", req.AccountID)))
|
2019-06-22 17:48:44 -08:00
|
|
|
|
2019-08-14 11:40:26 -08:00
|
|
|
res, err := find(ctx, claims, repo.DbConn, query, []interface{}{}, req.IncludeArchived)
|
2019-08-04 14:48:43 -08:00
|
|
|
if err != nil {
|
2019-06-22 17:48:44 -08:00
|
|
|
return nil, err
|
2019-08-04 14:48:43 -08:00
|
|
|
} else if res == nil || len(res) == 0 {
|
|
|
|
err = errors.WithMessagef(ErrNotFound, "entry for user %s account %s not found", req.UserID, req.AccountID)
|
2019-05-29 03:35:08 -05:00
|
|
|
return nil, err
|
2019-05-27 02:44:40 -05:00
|
|
|
}
|
2019-06-22 17:48:44 -08:00
|
|
|
u := res[0]
|
2019-05-27 02:44:40 -05:00
|
|
|
|
2019-06-22 17:48:44 -08:00
|
|
|
return u, nil
|
2019-05-27 02:44:40 -05:00
|
|
|
}
|
|
|
|
|
2019-06-22 17:48:44 -08:00
|
|
|
// Update replaces a user account in the database.
|
2019-08-14 11:40:26 -08:00
|
|
|
func (repo *Repository) Update(ctx context.Context, claims auth.Claims, req UserAccountUpdateRequest, now time.Time) error {
|
2019-06-22 10:05:03 -08:00
|
|
|
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Update")
|
2019-05-27 02:44:40 -05:00
|
|
|
defer span.Finish()
|
|
|
|
|
|
|
|
// Validate the request.
|
2019-07-31 13:47:30 -08:00
|
|
|
v := webcontext.Validator()
|
2019-06-26 20:21:00 -08:00
|
|
|
err := v.Struct(req)
|
2019-05-27 02:44:40 -05:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// Ensure the claims can modify the user specified in the request.
|
2019-08-14 11:40:26 -08:00
|
|
|
err = repo.CanModifyAccount(ctx, claims, req.AccountID)
|
2019-05-27 02:44:40 -05:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// If now empty set it to the current time.
|
|
|
|
if now.IsZero() {
|
|
|
|
now = time.Now()
|
|
|
|
}
|
|
|
|
|
|
|
|
// Always store the time as UTC.
|
|
|
|
now = now.UTC()
|
|
|
|
|
|
|
|
// Postgres truncates times to milliseconds when storing. We and do the same
|
|
|
|
// here so the value we return is consistent with what we store.
|
|
|
|
now = now.Truncate(time.Millisecond)
|
|
|
|
|
|
|
|
// Build the update SQL statement.
|
|
|
|
query := sqlbuilder.NewUpdateBuilder()
|
2019-06-22 10:05:03 -08:00
|
|
|
query.Update(userAccountTableName)
|
2019-05-29 03:35:08 -05:00
|
|
|
|
|
|
|
fields := []string{}
|
|
|
|
if req.Roles != nil {
|
|
|
|
fields = append(fields, query.Assign("roles", req.Roles))
|
|
|
|
}
|
|
|
|
if req.Status != nil {
|
|
|
|
fields = append(fields, query.Assign("status", req.Status))
|
|
|
|
}
|
2019-06-22 17:48:44 -08:00
|
|
|
if req.unArchive {
|
|
|
|
fields = append(fields, query.Assign("archived_at", nil))
|
|
|
|
}
|
2019-05-29 03:35:08 -05:00
|
|
|
|
|
|
|
// If there's nothing to update we can quit early.
|
|
|
|
if len(fields) == 0 {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// Append the updated_at field
|
|
|
|
fields = append(fields, query.Assign("updated_at", now))
|
|
|
|
|
|
|
|
query.Set(fields...)
|
|
|
|
|
2019-05-27 02:44:40 -05:00
|
|
|
query.Where(query.And(
|
|
|
|
query.Equal("user_id", req.UserID),
|
|
|
|
query.Equal("account_id", req.AccountID),
|
|
|
|
))
|
|
|
|
|
|
|
|
// Execute the query with the provided context.
|
|
|
|
sql, args := query.Build()
|
2019-08-14 11:40:26 -08:00
|
|
|
sql = repo.DbConn.Rebind(sql)
|
|
|
|
_, err = repo.DbConn.ExecContext(ctx, sql, args...)
|
2019-05-27 02:44:40 -05:00
|
|
|
if err != nil {
|
|
|
|
err = errors.Wrapf(err, "query - %s", query.String())
|
|
|
|
err = errors.WithMessagef(err, "update account %s for user %s failed", req.AccountID, req.UserID)
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2019-06-22 10:05:03 -08:00
|
|
|
// Archive soft deleted the user account from the database.
|
2019-08-14 11:40:26 -08:00
|
|
|
func (repo *Repository) Archive(ctx context.Context, claims auth.Claims, req UserAccountArchiveRequest, now time.Time) error {
|
2019-06-22 10:05:03 -08:00
|
|
|
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Archive")
|
2019-05-27 02:44:40 -05:00
|
|
|
defer span.Finish()
|
|
|
|
|
|
|
|
// Validate the request.
|
2019-07-31 13:47:30 -08:00
|
|
|
v := webcontext.Validator()
|
2019-06-26 20:21:00 -08:00
|
|
|
err := v.Struct(req)
|
2019-05-27 02:44:40 -05:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// Ensure the claims can modify the user specified in the request.
|
2019-08-14 11:40:26 -08:00
|
|
|
err = repo.CanModifyAccount(ctx, claims, req.AccountID)
|
2019-05-27 02:44:40 -05:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// If now empty set it to the current time.
|
|
|
|
if now.IsZero() {
|
|
|
|
now = time.Now()
|
|
|
|
}
|
|
|
|
|
|
|
|
// Always store the time as UTC.
|
|
|
|
now = now.UTC()
|
|
|
|
|
|
|
|
// Postgres truncates times to milliseconds when storing. We and do the same
|
|
|
|
// here so the value we return is consistent with what we store.
|
|
|
|
now = now.Truncate(time.Millisecond)
|
|
|
|
|
|
|
|
// Build the update SQL statement.
|
|
|
|
query := sqlbuilder.NewUpdateBuilder()
|
2019-06-22 10:05:03 -08:00
|
|
|
query.Update(userAccountTableName)
|
2019-05-27 02:44:40 -05:00
|
|
|
query.Set(query.Assign("archived_at", now))
|
|
|
|
query.Where(query.And(
|
|
|
|
query.Equal("user_id", req.UserID),
|
|
|
|
query.Equal("account_id", req.AccountID),
|
|
|
|
))
|
|
|
|
|
|
|
|
// Execute the query with the provided context.
|
|
|
|
sql, args := query.Build()
|
2019-08-14 11:40:26 -08:00
|
|
|
sql = repo.DbConn.Rebind(sql)
|
|
|
|
_, err = repo.DbConn.ExecContext(ctx, sql, args...)
|
2019-05-27 02:44:40 -05:00
|
|
|
if err != nil {
|
|
|
|
err = errors.Wrapf(err, "query - %s", query.String())
|
2019-06-22 17:48:44 -08:00
|
|
|
err = errors.WithMessagef(err, "archive account %s from user %s failed", req.AccountID, req.UserID)
|
2019-05-27 02:44:40 -05:00
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2019-06-22 10:05:03 -08:00
|
|
|
// Delete removes a user account from the database.
|
2019-08-14 11:40:26 -08:00
|
|
|
func (repo *Repository) Delete(ctx context.Context, claims auth.Claims, req UserAccountDeleteRequest) error {
|
2019-06-22 10:05:03 -08:00
|
|
|
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Delete")
|
2019-05-27 02:44:40 -05:00
|
|
|
defer span.Finish()
|
|
|
|
|
|
|
|
// Validate the request.
|
2019-07-31 13:47:30 -08:00
|
|
|
v := webcontext.Validator()
|
2019-06-26 20:21:00 -08:00
|
|
|
err := v.Struct(req)
|
2019-05-27 02:44:40 -05:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// Ensure the claims can modify the user specified in the request.
|
2019-08-14 11:40:26 -08:00
|
|
|
err = repo.CanModifyAccount(ctx, claims, req.AccountID)
|
2019-05-27 02:44:40 -05:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// Build the delete SQL statement.
|
|
|
|
query := sqlbuilder.NewDeleteBuilder()
|
2019-06-22 10:05:03 -08:00
|
|
|
query.DeleteFrom(userAccountTableName)
|
2019-05-27 02:44:40 -05:00
|
|
|
query.Where(query.And(
|
|
|
|
query.Equal("user_id", req.UserID),
|
|
|
|
query.Equal("account_id", req.AccountID),
|
|
|
|
))
|
|
|
|
|
|
|
|
// Execute the query with the provided context.
|
|
|
|
sql, args := query.Build()
|
2019-08-14 11:40:26 -08:00
|
|
|
sql = repo.DbConn.Rebind(sql)
|
|
|
|
_, err = repo.DbConn.ExecContext(ctx, sql, args...)
|
2019-05-27 02:44:40 -05:00
|
|
|
if err != nil {
|
|
|
|
err = errors.Wrapf(err, "query - %s", query.String())
|
|
|
|
err = errors.WithMessagef(err, "delete account %s for user %s failed", req.AccountID, req.UserID)
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
2019-08-04 14:48:43 -08:00
|
|
|
|
|
|
|
type MockUserAccountResponse struct {
|
|
|
|
*UserAccount
|
|
|
|
User *user.MockUserResponse
|
|
|
|
Account *account.Account
|
|
|
|
}
|
|
|
|
|
|
|
|
// MockUserAccount returns a fake UserAccount for testing.
|
|
|
|
func MockUserAccount(ctx context.Context, dbConn *sqlx.DB, now time.Time, roles ...UserAccountRole) (*MockUserAccountResponse, error) {
|
|
|
|
usr, err := user.MockUser(ctx, dbConn, now)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
acc, err := account.MockAccount(ctx, dbConn, now)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
2019-08-14 11:40:26 -08:00
|
|
|
repo := &Repository{
|
|
|
|
DbConn: dbConn,
|
|
|
|
}
|
|
|
|
|
2019-08-04 14:48:43 -08:00
|
|
|
status := UserAccountStatus_Active
|
|
|
|
|
|
|
|
req := UserAccountCreateRequest{
|
|
|
|
UserID: usr.ID,
|
|
|
|
AccountID: acc.ID,
|
|
|
|
Status: &status,
|
|
|
|
Roles: roles,
|
|
|
|
}
|
2019-08-14 11:40:26 -08:00
|
|
|
ua, err := repo.Create(ctx, auth.Claims{}, req, now)
|
2019-08-04 14:48:43 -08:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
return &MockUserAccountResponse{
|
|
|
|
UserAccount: ua,
|
|
|
|
User: usr,
|
|
|
|
Account: acc,
|
|
|
|
}, nil
|
|
|
|
}
|