mirror of
https://github.com/raseels-repos/golang-saas-starter-kit.git
synced 2025-06-04 23:37:49 +02:00
1035 lines
30 KiB
Go
1035 lines
30 KiB
Go
package user
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"time"
|
|
|
|
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
|
|
"geeks-accelerator/oss/saas-starter-kit/internal/platform/notify"
|
|
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
|
|
"github.com/huandu/go-sqlbuilder"
|
|
"github.com/jmoiron/sqlx"
|
|
"github.com/pborman/uuid"
|
|
"github.com/pkg/errors"
|
|
"golang.org/x/crypto/bcrypt"
|
|
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
|
|
)
|
|
|
|
const (
	// userTableName is the database table that stores User records.
	userTableName = "users"

	// accountTableName is the database table that stores Account records.
	accountTableName = "accounts"

	// userAccountTableName is the table linking users to accounts
	// (rows carry user_id and account_id).
	userAccountTableName = "users_accounts"
)

var (
	// ErrNotFound is returned when a requested user cannot be found, e.g. a
	// lookup by ID, email address or password reset ID that matches no rows.
	ErrNotFound = errors.New("Entity not found")

	// ErrForbidden occurs when a user tries to do something that is forbidden
	// to them according to our access control policies.
	ErrForbidden = errors.New("Attempted action is not allowed")

	// ErrResetExpired occurs when the reset hash exceeds the expiration.
	ErrResetExpired = errors.New("Reset expired")
)

// userMapColumns is the comma-separated list of user columns selected by
// selectQuery, in exactly the order mapRowsToUser scans them.
var userMapColumns = "id,first_name,last_name,email,password_salt,password_hash,password_reset,timezone,created_at,updated_at,archived_at"
|
|
|
|
// mapRowsToUser takes the SQL rows and maps it to the UserAccount struct
|
|
// with the columns defined by userMapColumns
|
|
func mapRowsToUser(rows *sql.Rows) (*User, error) {
|
|
var (
|
|
u User
|
|
err error
|
|
)
|
|
err = rows.Scan(&u.ID, &u.FirstName, &u.LastName, &u.Email, &u.PasswordSalt, &u.PasswordHash, &u.PasswordReset, &u.Timezone, &u.CreatedAt, &u.UpdatedAt, &u.ArchivedAt)
|
|
if err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
|
|
return &u, nil
|
|
}
|
|
|
|
// CanReadUser determines if claims has the authority to access the specified user ID.
|
|
func (repo *Repository) CanReadUser(ctx context.Context, claims auth.Claims, userID string) error {
|
|
// If the request has claims from a specific user, ensure that the user
|
|
// has the correct access to the user.
|
|
if claims.Subject != "" && claims.Subject != userID {
|
|
// When the claims Subject/UserId - does not match the requested user, the
|
|
// claims audience - AccountId - should have a record.
|
|
// select id from users_accounts where account_id = [claims.Audience] and user_id = [userID]
|
|
query := sqlbuilder.NewSelectBuilder().Select("id").From(userAccountTableName)
|
|
query.Where(query.And(
|
|
query.Equal("account_id", claims.Audience),
|
|
query.Equal("user_id", userID),
|
|
))
|
|
queryStr, args := query.Build()
|
|
queryStr = repo.DbConn.Rebind(queryStr)
|
|
|
|
var userAccountId string
|
|
err := repo.DbConn.QueryRowContext(ctx, queryStr, args...).Scan(&userAccountId)
|
|
if err != nil && err != sql.ErrNoRows {
|
|
err = errors.Wrapf(err, "query - %s", query.String())
|
|
return err
|
|
}
|
|
|
|
// When there is no userAccount ID returned, then the current user does not have access
|
|
// to the specified user.
|
|
if userAccountId == "" {
|
|
return errors.WithStack(ErrForbidden)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// CanModifyUser determines if claims has the authority to modify the specified user ID.
|
|
func (repo *Repository) CanModifyUser(ctx context.Context, claims auth.Claims, userID string) error {
|
|
// If the request has claims from a specific user, ensure that the user
|
|
// has the correct role for creating a new user.
|
|
if claims.Subject != "" && claims.Subject != userID {
|
|
// Users with the role of admin are ony allows to create users.
|
|
if !claims.HasRole(auth.RoleAdmin) {
|
|
err := errors.WithStack(ErrForbidden)
|
|
return err
|
|
}
|
|
}
|
|
|
|
if err := repo.CanReadUser(ctx, claims, userID); err != nil {
|
|
return err
|
|
}
|
|
|
|
// TODO: Review, this doesn't seem correct, replaced with above.
|
|
/*
|
|
// If the request has claims from a specific account, ensure that the user
|
|
// has the correct access to the account.
|
|
if claims.Subject != "" && claims.Subject != userID {
|
|
// When the claims Audience - AccountID - does not match the requested account, the
|
|
// claims Audience - AccountID - should have a record with an admin role.
|
|
// select id from users_accounts where account_id = [claims.Audience] and user_id = [userID] and any (roles) = 'admin'
|
|
query := sqlbuilder.NewSelectBuilder().Select("id").From(userAccountTableName)
|
|
query.Where(query.And(
|
|
query.Equal("account_id", claims.Audience),
|
|
query.Equal("user_id", userID),
|
|
"'"+auth.RoleAdmin+"' = ANY (roles)",
|
|
))
|
|
queryStr, args := query.Build()
|
|
queryStr = repo.DbConn.Rebind(queryStr)
|
|
|
|
var userAccountId string
|
|
err := repo.DbConn.QueryRowContext(ctx, queryStr, args...).Scan(&userAccountId)
|
|
if err != nil && err != sql.ErrNoRows {
|
|
err = errors.Wrapf(err, "query - %s", query.String())
|
|
return err
|
|
}
|
|
|
|
// When there is no userAccount ID returned, then the current user does not have access
|
|
// to the specified account.
|
|
if userAccountId == "" {
|
|
return errors.WithStack(ErrForbidden)
|
|
}
|
|
}
|
|
*/
|
|
|
|
return nil
|
|
}
|
|
|
|
// applyClaimsSelect applies a sub-query to the provided query to enforce ACL based on
|
|
// the claims provided.
|
|
// 1. All role types can access their user ID
|
|
// 2. Any user with the same account ID
|
|
// 3. No claims, request is internal, no ACL applied
|
|
func applyClaimsSelect(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder) error {
|
|
// Claims are empty, don't apply any ACL
|
|
if claims.Audience == "" && claims.Subject == "" {
|
|
return nil
|
|
}
|
|
|
|
// Build select statement for users_accounts table
|
|
subQuery := sqlbuilder.NewSelectBuilder().Select("user_id").From(userAccountTableName)
|
|
|
|
var or []string
|
|
if claims.Audience != "" {
|
|
or = append(or, subQuery.Equal("account_id", claims.Audience))
|
|
}
|
|
if claims.Subject != "" {
|
|
or = append(or, subQuery.Equal("user_id", claims.Subject))
|
|
}
|
|
|
|
// Append sub query
|
|
if len(or) > 0 {
|
|
subQuery.Where(subQuery.Or(or...))
|
|
query.Where(query.In("id", subQuery))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// selectQuery constructs a base select query for User
|
|
func selectQuery() *sqlbuilder.SelectBuilder {
|
|
query := sqlbuilder.NewSelectBuilder()
|
|
query.Select(userMapColumns)
|
|
query.From(userTableName)
|
|
return query
|
|
}
|
|
|
|
// findRequestQuery generates the select query for the given find request.
|
|
// TODO: Need to figure out why can't parse the args when appending the where
|
|
// to the query.
|
|
func findRequestQuery(req UserFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) {
|
|
query := selectQuery()
|
|
if req.Where != "" {
|
|
query.Where(query.And(req.Where))
|
|
}
|
|
if len(req.Order) > 0 {
|
|
query.OrderBy(req.Order...)
|
|
}
|
|
if req.Limit != nil {
|
|
query.Limit(int(*req.Limit))
|
|
}
|
|
if req.Offset != nil {
|
|
query.Offset(int(*req.Offset))
|
|
}
|
|
|
|
return query, req.Args
|
|
}
|
|
|
|
// Find gets all the users from the database based on the request params.
|
|
func (repo *Repository) Find(ctx context.Context, claims auth.Claims, req UserFindRequest) (Users, error) {
|
|
query, args := findRequestQuery(req)
|
|
return find(ctx, claims, repo.DbConn, query, args, req.IncludeArchived)
|
|
}
|
|
|
|
// find internal method for getting all the users from the database using a select query.
|
|
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) (Users, error) {
|
|
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Find")
|
|
defer span.Finish()
|
|
|
|
query.Select(userMapColumns)
|
|
query.From(userTableName)
|
|
|
|
if !includedArchived {
|
|
query.Where(query.IsNull("archived_at"))
|
|
}
|
|
|
|
// Check to see if a sub query needs to be applied for the claims
|
|
err := applyClaimsSelect(ctx, claims, query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
queryStr, queryArgs := query.Build()
|
|
queryStr = dbConn.Rebind(queryStr)
|
|
args = append(args, queryArgs...)
|
|
|
|
// fetch all places from the db
|
|
rows, err := dbConn.QueryContext(ctx, queryStr, args...)
|
|
if err != nil {
|
|
err = errors.Wrapf(err, "query - %s", query.String())
|
|
err = errors.WithMessage(err, "find users failed")
|
|
return nil, err
|
|
}
|
|
|
|
// iterate over each row
|
|
resp := []*User{}
|
|
for rows.Next() {
|
|
u, err := mapRowsToUser(rows)
|
|
if err != nil {
|
|
err = errors.Wrapf(err, "query - %s", query.String())
|
|
return nil, err
|
|
}
|
|
resp = append(resp, u)
|
|
}
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
// Validation an email address is unique excluding the current user ID.
|
|
func UniqueEmail(ctx context.Context, dbConn *sqlx.DB, email, userId string) (bool, error) {
|
|
query := sqlbuilder.NewSelectBuilder().Select("id").From(userTableName)
|
|
query.Where(query.And(
|
|
query.Equal("email", email),
|
|
query.NotEqual("id", userId),
|
|
))
|
|
queryStr, args := query.Build()
|
|
queryStr = dbConn.Rebind(queryStr)
|
|
|
|
var existingId string
|
|
err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&existingId)
|
|
if err != nil && err != sql.ErrNoRows {
|
|
err = errors.Wrapf(err, "query - %s", query.String())
|
|
return false, err
|
|
}
|
|
|
|
// When an ID was found in the db, the email is not unique.
|
|
if existingId != "" {
|
|
return false, nil
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
// Create inserts a new user into the database.
|
|
func (repo *Repository) Create(ctx context.Context, claims auth.Claims, req UserCreateRequest, now time.Time) (*User, error) {
|
|
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Create")
|
|
defer span.Finish()
|
|
|
|
if req.Timezone != nil && *req.Timezone == "" {
|
|
req.Timezone = nil
|
|
}
|
|
|
|
v := webcontext.Validator()
|
|
|
|
// Validation email address is unique in the database.
|
|
uniq, err := UniqueEmail(ctx, repo.DbConn, req.Email, "")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ctx = context.WithValue(ctx, webcontext.KeyTagUnique, uniq)
|
|
|
|
// Validate the request.
|
|
err = v.StructCtx(ctx, req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// If the request has claims from a specific user, ensure that the user
|
|
// has the correct role for creating a new user.
|
|
if claims.Subject != "" {
|
|
// Users with the role of admin are ony allows to create users.
|
|
if !claims.HasRole(auth.RoleAdmin) {
|
|
err = errors.WithStack(ErrForbidden)
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
// If now empty set it to the current time.
|
|
if now.IsZero() {
|
|
now = time.Now()
|
|
}
|
|
|
|
// Always store the time as UTC.
|
|
now = now.UTC()
|
|
|
|
// Postgres truncates times to milliseconds when storing. We and do the same
|
|
// here so the value we return is consistent with what we store.
|
|
now = now.Truncate(time.Millisecond)
|
|
|
|
passwordSalt := uuid.NewRandom().String()
|
|
saltedPassword := req.Password + passwordSalt
|
|
|
|
passwordHash, err := bcrypt.GenerateFromPassword([]byte(saltedPassword), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "generating password hash")
|
|
}
|
|
|
|
u := User{
|
|
ID: uuid.NewRandom().String(),
|
|
FirstName: req.FirstName,
|
|
LastName: req.LastName,
|
|
Email: req.Email,
|
|
Timezone: req.Timezone,
|
|
PasswordHash: passwordHash,
|
|
PasswordSalt: passwordSalt,
|
|
CreatedAt: now,
|
|
UpdatedAt: now,
|
|
}
|
|
|
|
// Build the insert SQL statement.
|
|
query := sqlbuilder.NewInsertBuilder()
|
|
query.InsertInto(userTableName)
|
|
query.Cols("id", "first_name", "last_name", "email", "password_hash", "password_salt", "timezone", "created_at", "updated_at")
|
|
query.Values(u.ID, u.FirstName, u.LastName, u.Email, u.PasswordHash, u.PasswordSalt, u.Timezone, u.CreatedAt, u.UpdatedAt)
|
|
|
|
// Execute the query with the provided context.
|
|
sql, args := query.Build()
|
|
sql = repo.DbConn.Rebind(sql)
|
|
_, err = repo.DbConn.ExecContext(ctx, sql, args...)
|
|
if err != nil {
|
|
err = errors.Wrapf(err, "query - %s", query.String())
|
|
err = errors.WithMessage(err, "create user failed")
|
|
return nil, err
|
|
}
|
|
|
|
return &u, nil
|
|
}
|
|
|
|
// Create invite inserts a new user into the database.
|
|
func (repo *Repository) CreateInvite(ctx context.Context, claims auth.Claims, req UserCreateInviteRequest, now time.Time) (*User, error) {
|
|
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.CreateInvite")
|
|
defer span.Finish()
|
|
|
|
v := webcontext.Validator()
|
|
|
|
// Validation email address is unique in the database.
|
|
uniq, err := UniqueEmail(ctx, repo.DbConn, req.Email, "")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ctx = context.WithValue(ctx, webcontext.KeyTagUnique, uniq)
|
|
|
|
// Validate the request.
|
|
err = v.StructCtx(ctx, req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// If the request has claims from a specific user, ensure that the user
|
|
// has the correct role for creating a new user.
|
|
if claims.Subject != "" {
|
|
// Users with the role of admin are ony allows to create users.
|
|
if !claims.HasRole(auth.RoleAdmin) {
|
|
err = errors.WithStack(ErrForbidden)
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
// If now empty set it to the current time.
|
|
if now.IsZero() {
|
|
now = time.Now()
|
|
}
|
|
|
|
// Always store the time as UTC.
|
|
now = now.UTC()
|
|
|
|
// Postgres truncates times to milliseconds when storing. We and do the same
|
|
// here so the value we return is consistent with what we store.
|
|
now = now.Truncate(time.Millisecond)
|
|
|
|
u := User{
|
|
ID: uuid.NewRandom().String(),
|
|
Email: req.Email,
|
|
CreatedAt: now,
|
|
UpdatedAt: now,
|
|
}
|
|
|
|
// Build the insert SQL statement.
|
|
query := sqlbuilder.NewInsertBuilder()
|
|
query.InsertInto(userTableName)
|
|
query.Cols("id", "email", "password_hash", "password_salt", "created_at", "updated_at")
|
|
query.Values(u.ID, u.Email, "", "", u.CreatedAt, u.UpdatedAt)
|
|
|
|
// Execute the query with the provided context.
|
|
sql, args := query.Build()
|
|
sql = repo.DbConn.Rebind(sql)
|
|
_, err = repo.DbConn.ExecContext(ctx, sql, args...)
|
|
if err != nil {
|
|
err = errors.Wrapf(err, "query - %s", query.String())
|
|
err = errors.WithMessage(err, "create user failed")
|
|
return nil, err
|
|
}
|
|
|
|
return &u, nil
|
|
}
|
|
|
|
// ReadByID gets the specified user by ID from the database.
|
|
func (repo *Repository) ReadByID(ctx context.Context, claims auth.Claims, id string) (*User, error) {
|
|
return repo.Read(ctx, claims, UserReadRequest{
|
|
ID: id,
|
|
IncludeArchived: false,
|
|
})
|
|
}
|
|
|
|
// Read gets the specified user from the database.
|
|
func (repo *Repository) Read(ctx context.Context, claims auth.Claims, req UserReadRequest) (*User, error) {
|
|
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Read")
|
|
defer span.Finish()
|
|
|
|
// Validate the request.
|
|
v := webcontext.Validator()
|
|
err := v.Struct(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Filter base select query by ID
|
|
query := selectQuery()
|
|
query.Where(query.Equal("id", req.ID))
|
|
|
|
res, err := find(ctx, claims, repo.DbConn, query, []interface{}{}, req.IncludeArchived)
|
|
if err != nil {
|
|
return nil, err
|
|
} else if res == nil || len(res) == 0 {
|
|
err = errors.WithMessagef(ErrNotFound, "user %s not found", req.ID)
|
|
return nil, err
|
|
}
|
|
u := res[0]
|
|
|
|
return u, nil
|
|
}
|
|
|
|
// ReadByEmail gets the specified user from the database.
|
|
func (repo *Repository) ReadByEmail(ctx context.Context, claims auth.Claims, email string, includedArchived bool) (*User, error) {
|
|
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.ReadByEmail")
|
|
defer span.Finish()
|
|
|
|
// Filter base select query by ID
|
|
query := selectQuery()
|
|
query.Where(query.Equal("email", email))
|
|
|
|
res, err := find(ctx, claims, repo.DbConn, query, []interface{}{}, includedArchived)
|
|
if err != nil {
|
|
return nil, err
|
|
} else if res == nil || len(res) == 0 {
|
|
err = errors.WithMessagef(ErrNotFound, "user %s not found", email)
|
|
return nil, err
|
|
}
|
|
u := res[0]
|
|
|
|
return u, nil
|
|
}
|
|
|
|
// Update replaces a user in the database.
|
|
func (repo *Repository) Update(ctx context.Context, claims auth.Claims, req UserUpdateRequest, now time.Time) error {
|
|
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Update")
|
|
defer span.Finish()
|
|
|
|
// Validation email address is unique in the database.
|
|
if req.Email != nil {
|
|
// Validation email address is unique in the database.
|
|
uniq, err := UniqueEmail(ctx, repo.DbConn, *req.Email, req.ID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
ctx = context.WithValue(ctx, webcontext.KeyTagUnique, uniq)
|
|
} else {
|
|
ctx = context.WithValue(ctx, webcontext.KeyTagUnique, true)
|
|
}
|
|
|
|
// Validate the request.
|
|
v := webcontext.Validator()
|
|
err := v.StructCtx(ctx, req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Ensure the claims can modify the user specified in the request.
|
|
err = repo.CanModifyUser(ctx, claims, req.ID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// If now empty set it to the current time.
|
|
if now.IsZero() {
|
|
now = time.Now()
|
|
}
|
|
|
|
// Always store the time as UTC.
|
|
now = now.UTC()
|
|
|
|
// Postgres truncates times to milliseconds when storing. We and do the same
|
|
// here so the value we return is consistent with what we store.
|
|
now = now.Truncate(time.Millisecond)
|
|
|
|
// Build the update SQL statement.
|
|
query := sqlbuilder.NewUpdateBuilder()
|
|
query.Update(userTableName)
|
|
|
|
var fields []string
|
|
if req.FirstName != nil {
|
|
fields = append(fields, query.Assign("first_name", req.FirstName))
|
|
}
|
|
if req.LastName != nil {
|
|
fields = append(fields, query.Assign("last_name", req.LastName))
|
|
}
|
|
if req.Email != nil {
|
|
fields = append(fields, query.Assign("email", req.Email))
|
|
}
|
|
if req.Timezone != nil && *req.Timezone != "" {
|
|
fields = append(fields, query.Assign("timezone", *req.Timezone))
|
|
}
|
|
|
|
// If there's nothing to update we can quit early.
|
|
if len(fields) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// Append the updated_at field
|
|
fields = append(fields, query.Assign("updated_at", now))
|
|
|
|
query.Set(fields...)
|
|
query.Where(query.Equal("id", req.ID))
|
|
|
|
// Execute the query with the provided context.
|
|
sql, args := query.Build()
|
|
sql = repo.DbConn.Rebind(sql)
|
|
_, err = repo.DbConn.ExecContext(ctx, sql, args...)
|
|
if err != nil {
|
|
err = errors.Wrapf(err, "query - %s", query.String())
|
|
err = errors.WithMessagef(err, "update user %s failed", req.ID)
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Update changes the password for a user in the database.
|
|
func (repo *Repository) UpdatePassword(ctx context.Context, claims auth.Claims, req UserUpdatePasswordRequest, now time.Time) error {
|
|
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.UpdatePassword")
|
|
defer span.Finish()
|
|
|
|
// Validate the request.
|
|
v := webcontext.Validator()
|
|
err := v.Struct(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Ensure the claims can modify the user specified in the request.
|
|
err = repo.CanModifyUser(ctx, claims, req.ID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// If now empty set it to the current time.
|
|
if now.IsZero() {
|
|
now = time.Now()
|
|
}
|
|
|
|
// Always store the time as UTC.
|
|
now = now.UTC()
|
|
|
|
// Postgres truncates times to milliseconds when storing. We and do the same
|
|
// here so the value we return is consistent with what we store.
|
|
now = now.Truncate(time.Millisecond)
|
|
|
|
// Generate new password hash for the provided password.
|
|
passwordSalt := uuid.NewRandom()
|
|
saltedPassword := req.Password + passwordSalt.String()
|
|
passwordHash, err := bcrypt.GenerateFromPassword([]byte(saltedPassword), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return errors.Wrap(err, "generating password hash")
|
|
}
|
|
|
|
// Build the update SQL statement.
|
|
query := sqlbuilder.NewUpdateBuilder()
|
|
query.Update(userTableName)
|
|
query.Set(
|
|
query.Assign("password_hash", passwordHash),
|
|
query.Assign("password_salt", passwordSalt),
|
|
query.Assign("updated_at", now),
|
|
)
|
|
query.Where(query.Equal("id", req.ID))
|
|
|
|
// Execute the query with the provided context.
|
|
sql, args := query.Build()
|
|
sql = repo.DbConn.Rebind(sql)
|
|
_, err = repo.DbConn.ExecContext(ctx, sql, args...)
|
|
if err != nil {
|
|
err = errors.Wrapf(err, "query - %s", query.String())
|
|
err = errors.WithMessagef(err, "update password for user %s failed", req.ID)
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Archive soft deleted the user from the database.
|
|
func (repo *Repository) Archive(ctx context.Context, claims auth.Claims, req UserArchiveRequest, now time.Time) error {
|
|
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Archive")
|
|
defer span.Finish()
|
|
|
|
// Validate the request.
|
|
v := webcontext.Validator()
|
|
err := v.Struct(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Ensure the claims can modify the user specified in the request.
|
|
err = repo.CanModifyUser(ctx, claims, req.ID)
|
|
if err != nil {
|
|
return err
|
|
} else if claims.Subject != "" && claims.Subject == req.ID && !req.force {
|
|
return errors.WithStack(ErrForbidden)
|
|
}
|
|
|
|
// If now empty set it to the current time.
|
|
if now.IsZero() {
|
|
now = time.Now()
|
|
}
|
|
|
|
// Always store the time as UTC.
|
|
now = now.UTC()
|
|
|
|
// Postgres truncates times to milliseconds when storing. We and do the same
|
|
// here so the value we return is consistent with what we store.
|
|
now = now.Truncate(time.Millisecond)
|
|
|
|
// Build the update SQL statement.
|
|
query := sqlbuilder.NewUpdateBuilder()
|
|
query.Update(userTableName)
|
|
query.Set(
|
|
query.Assign("archived_at", now),
|
|
)
|
|
query.Where(query.Equal("id", req.ID))
|
|
|
|
// Execute the query with the provided context.
|
|
sql, args := query.Build()
|
|
sql = repo.DbConn.Rebind(sql)
|
|
_, err = repo.DbConn.ExecContext(ctx, sql, args...)
|
|
if err != nil {
|
|
err = errors.Wrapf(err, "query - %s", query.String())
|
|
err = errors.WithMessagef(err, "archive user %s failed", req.ID)
|
|
return err
|
|
}
|
|
|
|
// Archive all the associated user accounts
|
|
{
|
|
// Build the update SQL statement.
|
|
query := sqlbuilder.NewUpdateBuilder()
|
|
query.Update(userAccountTableName)
|
|
query.Set(query.Assign("archived_at", now))
|
|
query.Where(query.And(
|
|
query.Equal("user_id", req.ID),
|
|
))
|
|
|
|
// Execute the query with the provided context.
|
|
sql, args := query.Build()
|
|
sql = repo.DbConn.Rebind(sql)
|
|
_, err = repo.DbConn.ExecContext(ctx, sql, args...)
|
|
if err != nil {
|
|
err = errors.Wrapf(err, "query - %s", query.String())
|
|
err = errors.WithMessagef(err, "archive accounts for user %s failed", req.ID)
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Restore undeletes the user from the database.
|
|
func (repo *Repository) Restore(ctx context.Context, claims auth.Claims, req UserRestoreRequest, now time.Time) error {
|
|
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Restore")
|
|
defer span.Finish()
|
|
|
|
// Validate the request.
|
|
v := webcontext.Validator()
|
|
err := v.Struct(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Ensure the claims can modify the user specified in the request.
|
|
err = repo.CanModifyUser(ctx, claims, req.ID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// If now empty set it to the current time.
|
|
if now.IsZero() {
|
|
now = time.Now()
|
|
}
|
|
|
|
// Always store the time as UTC.
|
|
now = now.UTC()
|
|
|
|
// Postgres truncates times to milliseconds when storing. We and do the same
|
|
// here so the value we return is consistent with what we store.
|
|
now = now.Truncate(time.Millisecond)
|
|
|
|
// Build the update SQL statement.
|
|
query := sqlbuilder.NewUpdateBuilder()
|
|
query.Update(userTableName)
|
|
query.Set(
|
|
query.Assign("archived_at", nil),
|
|
)
|
|
query.Where(query.Equal("id", req.ID))
|
|
|
|
// Execute the query with the provided context.
|
|
sql, args := query.Build()
|
|
sql = repo.DbConn.Rebind(sql)
|
|
_, err = repo.DbConn.ExecContext(ctx, sql, args...)
|
|
if err != nil {
|
|
err = errors.Wrapf(err, "query - %s", query.String())
|
|
err = errors.WithMessagef(err, "unarchive user %s failed", req.ID)
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Delete removes a user from the database.
|
|
func (repo *Repository) Delete(ctx context.Context, claims auth.Claims, req UserDeleteRequest) error {
|
|
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Delete")
|
|
defer span.Finish()
|
|
|
|
// Validate the request.
|
|
v := webcontext.Validator()
|
|
err := v.Struct(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Ensure the claims can modify the user specified in the request.
|
|
err = repo.CanModifyUser(ctx, claims, req.ID)
|
|
if err != nil {
|
|
return err
|
|
} else if claims.Subject != "" && claims.Subject == req.ID && !req.force {
|
|
return errors.WithStack(ErrForbidden)
|
|
}
|
|
|
|
// Start a new transaction to handle rollbacks on error.
|
|
tx, err := repo.DbConn.Begin()
|
|
if err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
// Delete all the associated user accounts.
|
|
// Required to execute first to avoid foreign key constraints.
|
|
{
|
|
// Build the delete SQL statement.
|
|
query := sqlbuilder.NewDeleteBuilder()
|
|
query.DeleteFrom(userAccountTableName)
|
|
query.Where(query.And(
|
|
query.Equal("user_id", req.ID),
|
|
))
|
|
|
|
// Execute the query with the provided context.
|
|
sql, args := query.Build()
|
|
sql = repo.DbConn.Rebind(sql)
|
|
_, err = tx.ExecContext(ctx, sql, args...)
|
|
if err != nil {
|
|
tx.Rollback()
|
|
|
|
err = errors.Wrapf(err, "query - %s", query.String())
|
|
err = errors.WithMessagef(err, "delete accounts for user %s failed", req.ID)
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Build the delete SQL statement.
|
|
query := sqlbuilder.NewDeleteBuilder()
|
|
query.DeleteFrom(userTableName)
|
|
query.Where(query.Equal("id", req.ID))
|
|
|
|
// Execute the query with the provided context.
|
|
sql, args := query.Build()
|
|
sql = repo.DbConn.Rebind(sql)
|
|
_, err = tx.ExecContext(ctx, sql, args...)
|
|
if err != nil {
|
|
tx.Rollback()
|
|
|
|
err = errors.Wrapf(err, "query - %s", query.String())
|
|
err = errors.WithMessagef(err, "delete user %s failed", req.ID)
|
|
return err
|
|
}
|
|
|
|
err = tx.Commit()
|
|
if err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ResetPassword sends en email to the user to allow them to reset their password.
|
|
func (repo *Repository) ResetPassword(ctx context.Context, req UserResetPasswordRequest, now time.Time) (string, error) {
|
|
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.ResetPassword")
|
|
defer span.Finish()
|
|
|
|
v := webcontext.Validator()
|
|
|
|
// Validate the request.
|
|
err := v.StructCtx(ctx, req)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
// Find user by email address.
|
|
var u *User
|
|
{
|
|
query := selectQuery()
|
|
query.Where(query.Equal("email", req.Email))
|
|
|
|
res, err := find(ctx, auth.Claims{}, repo.DbConn, query, []interface{}{}, false)
|
|
if err != nil {
|
|
return "", err
|
|
} else if res == nil || len(res) == 0 {
|
|
err = errors.WithMessagef(ErrNotFound, "No user found using '%s'.", req.Email)
|
|
return "", err
|
|
}
|
|
u = res[0]
|
|
}
|
|
|
|
// Update the user with a random string used to confirm the password reset.
|
|
resetId := uuid.NewRandom().String()
|
|
{
|
|
// Always store the time as UTC.
|
|
now = now.UTC()
|
|
|
|
// Postgres truncates times to milliseconds when storing. We and do the same
|
|
// here so the value we return is consistent with what we store.
|
|
now = now.Truncate(time.Millisecond)
|
|
|
|
// Build the update SQL statement.
|
|
query := sqlbuilder.NewUpdateBuilder()
|
|
query.Update(userTableName)
|
|
query.Set(
|
|
query.Assign("password_reset", resetId),
|
|
query.Assign("updated_at", now),
|
|
)
|
|
query.Where(query.Equal("id", u.ID))
|
|
|
|
// Execute the query with the provided context.
|
|
sql, args := query.Build()
|
|
sql = repo.DbConn.Rebind(sql)
|
|
_, err = repo.DbConn.ExecContext(ctx, sql, args...)
|
|
if err != nil {
|
|
err = errors.Wrapf(err, "query - %s", query.String())
|
|
err = errors.WithMessagef(err, "Update user %s failed.", u.ID)
|
|
return "", err
|
|
}
|
|
}
|
|
|
|
if req.TTL.Seconds() == 0 {
|
|
req.TTL = time.Minute * 90
|
|
}
|
|
|
|
// Load the current IP makings the request.
|
|
var requestIp string
|
|
if vals, _ := webcontext.ContextValues(ctx); vals != nil {
|
|
requestIp = vals.RequestIP
|
|
}
|
|
|
|
encrypted, err := NewResetHash(ctx, repo.secretKey, resetId, requestIp, req.TTL, now)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
data := map[string]interface{}{
|
|
"Name": u.FirstName,
|
|
"Url": repo.ResetUrl(encrypted),
|
|
"Minutes": req.TTL.Minutes(),
|
|
}
|
|
|
|
err = repo.Notify.Send(ctx, u.Email, "Reset your Password", "user_reset_password", data)
|
|
if err != nil {
|
|
err = errors.WithMessagef(err, "Send password reset email to %s failed.", u.Email)
|
|
return "", err
|
|
}
|
|
|
|
return encrypted, nil
|
|
}
|
|
|
|
// ResetConfirm updates the password for a user using the provided reset password ID.
|
|
func (repo *Repository) ResetConfirm(ctx context.Context, req UserResetConfirmRequest, now time.Time) (*User, error) {
|
|
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.ResetConfirm")
|
|
defer span.Finish()
|
|
|
|
v := webcontext.Validator()
|
|
|
|
// Validate the request.
|
|
err := v.StructCtx(ctx, req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
hash, err := ParseResetHash(ctx, repo.secretKey, req.ResetHash, now)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Find user by password_reset.
|
|
var u *User
|
|
{
|
|
query := selectQuery()
|
|
query.Where(query.Equal("password_reset", hash.ResetID))
|
|
|
|
res, err := find(ctx, auth.Claims{}, repo.DbConn, query, []interface{}{}, false)
|
|
if err != nil {
|
|
return nil, err
|
|
} else if res == nil || len(res) == 0 {
|
|
err = errors.WithMessage(ErrNotFound, "Invalid password reset.")
|
|
return nil, err
|
|
}
|
|
u = res[0]
|
|
}
|
|
|
|
// Save the new password for the user.
|
|
{
|
|
// Always store the time as UTC.
|
|
now = now.UTC()
|
|
|
|
// Postgres truncates times to milliseconds when storing. We and do the same
|
|
// here so the value we return is consistent with what we store.
|
|
now = now.Truncate(time.Millisecond)
|
|
|
|
// Generate new password hash for the provided password.
|
|
passwordSalt := uuid.NewRandom()
|
|
saltedPassword := req.Password + passwordSalt.String()
|
|
passwordHash, err := bcrypt.GenerateFromPassword([]byte(saltedPassword), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "generating password hash")
|
|
}
|
|
|
|
// Build the update SQL statement.
|
|
query := sqlbuilder.NewUpdateBuilder()
|
|
query.Update(userTableName)
|
|
query.Set(
|
|
query.Assign("password_reset", nil),
|
|
query.Assign("password_hash", passwordHash),
|
|
query.Assign("password_salt", passwordSalt),
|
|
query.Assign("updated_at", now),
|
|
)
|
|
query.Where(query.Equal("id", u.ID))
|
|
|
|
// Execute the query with the provided context.
|
|
sql, args := query.Build()
|
|
sql = repo.DbConn.Rebind(sql)
|
|
_, err = repo.DbConn.ExecContext(ctx, sql, args...)
|
|
if err != nil {
|
|
err = errors.Wrapf(err, "query - %s", query.String())
|
|
err = errors.WithMessagef(err, "update password for user %s failed", u.ID)
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return u, nil
|
|
}
|
|
|
|
type MockUserResponse struct {
|
|
*User
|
|
Password string
|
|
}
|
|
|
|
// MockUser returns a fake User for testing.
|
|
func MockUser(ctx context.Context, dbConn *sqlx.DB, now time.Time) (*MockUserResponse, error) {
|
|
pass := uuid.NewRandom().String()
|
|
|
|
repo := &Repository{
|
|
DbConn: dbConn,
|
|
}
|
|
|
|
req := UserCreateRequest{
|
|
FirstName: "Lee",
|
|
LastName: "Brown",
|
|
Email: uuid.NewRandom().String() + "@geeksinthewoods.com",
|
|
Password: pass,
|
|
PasswordConfirm: pass,
|
|
}
|
|
u, err := repo.Create(ctx, auth.Claims{}, req, now)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &MockUserResponse{
|
|
User: u,
|
|
Password: pass,
|
|
}, nil
|
|
}
|
|
|
|
func MockRepository(dbConn *sqlx.DB) *Repository {
|
|
// Mock the methods needed to make a password reset.
|
|
resetUrl := func(string) string {
|
|
return ""
|
|
}
|
|
notify := ¬ify.MockEmail{}
|
|
secretKey := "6368616e676520746869732070617373"
|
|
|
|
return NewRepository(dbConn, resetUrl, notify, secretKey)
|
|
}
|