1
0
mirror of https://github.com/raseels-repos/golang-saas-starter-kit.git synced 2025-06-06 23:46:29 +02:00

718 lines
20 KiB
Go
Raw Normal View History

package user
import (
"context"
"database/sql"
"time"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
"github.com/huandu/go-sqlbuilder"
2019-05-23 14:32:24 -05:00
"github.com/jmoiron/sqlx"
"github.com/pborman/uuid"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
2019-05-23 14:32:24 -05:00
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
"gopkg.in/go-playground/validator.v9"
)
// usersTableName is the database table that stores User records.
const usersTableName = "users"

var (
	// ErrNotFound occurs when a requested entity does not exist in the database.
	ErrNotFound = errors.New("Entity not found")

	// ErrInvalidID occurs when an ID is not in a valid form.
	ErrInvalidID = errors.New("ID is not in its proper form")

	// ErrAuthenticationFailure occurs when a user attempts to authenticate but
	// anything goes wrong.
	ErrAuthenticationFailure = errors.New("Authentication failed")

	// ErrForbidden occurs when a user tries to do something that is forbidden to
	// them according to our access control policies.
	ErrForbidden = errors.New("Attempted action is not allowed")
)

// usersMapColumns is the list of columns selected for mapRowsToUser. The order
// matters: it must match, one-for-one, the Scan destinations in mapRowsToUser.
var usersMapColumns = "id,name,email,password_salt,password_hash,password_reset,status,timezone,created_at,updated_at,archived_at"
// mapRowsToUser scans the current row of rows into a new User.
//
// The Scan destinations must line up, in order, with the columns listed in
// usersMapColumns (id, name, email, password_salt, password_hash,
// password_reset, status, timezone, created_at, updated_at, archived_at).
// Previously Name was missing, so the destination count (10) did not match the
// selected column count (11) and every Scan failed.
func mapRowsToUser(rows *sql.Rows) (*User, error) {
	var u User
	err := rows.Scan(
		&u.ID,
		&u.Name,
		&u.Email,
		&u.PasswordSalt,
		&u.PasswordHash,
		&u.PasswordReset,
		&u.Status,
		&u.Timezone,
		&u.CreatedAt,
		&u.UpdatedAt,
		&u.ArchivedAt,
	)
	if err != nil {
		return nil, errors.WithStack(err)
	}
	return &u, nil
}
// CanReadUserId determines if claims has the authority to access the specified user ID.
//
// Requests whose claims carry no Subject (not made on behalf of a specific user)
// and users reading their own record are always allowed. Otherwise a record
// linking userID to the claims' Audience (account ID) must exist in
// usersAccountsTableName; when it does not, ErrForbidden is returned. Any other
// database error is returned wrapped with the failing query.
func CanReadUserId(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string) error {
	// No user-specific claims, or the user is reading their own record.
	if claims.Subject == "" || claims.Subject == userID {
		return nil
	}

	// The requested user must be associated with the claims' account.
	query := sqlbuilder.NewSelectBuilder().Select("id").From(usersAccountsTableName)
	query.Where(query.And(
		query.Equal("account_id", claims.Audience),
		query.Equal("user_id", userID),
	))
	// Named queryStr (not sql) so the database/sql package stays in scope.
	queryStr, args := query.Build()
	queryStr = dbConn.Rebind(queryStr)

	var userAccountId string
	err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&userAccountId)
	if err == sql.ErrNoRows {
		// QueryRow reports a missing record as sql.ErrNoRows rather than an
		// empty ID: the user is not part of the claims' account.
		return errors.WithStack(ErrForbidden)
	}
	if err != nil {
		return errors.Wrapf(err, "query - %s", query.String())
	}

	// Defensive: treat an empty ID the same as a missing record.
	if userAccountId == "" {
		return errors.WithStack(ErrForbidden)
	}
	return nil
}
// CanModifyUserId determines if claims has the authority to modify the specified user ID.
//
// The claims must first pass CanReadUserId. After that, the modification is
// allowed when the claims carry no Subject, when the Subject is the user being
// modified, or when the claims hold auth.RoleAdmin; otherwise ErrForbidden is
// returned.
func CanModifyUserId(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string) error {
	if readErr := CanReadUserId(ctx, claims, dbConn, userID); readErr != nil {
		return readErr
	}

	switch {
	case claims.Subject == "":
		// Not tied to a specific user: no role restriction applies.
	case claims.Subject == userID:
		// Users may always update their own record.
	case claims.HasRole(auth.RoleAdmin):
		// Admins may update any user they have access to.
	default:
		return errors.WithStack(ErrForbidden)
	}
	return nil
}
// applyClaimsUserSelect applies a sub-query to the provided query to enforce ACL
// based on the claims provided. A user row is visible when ANY of these hold:
//  1. its ID is linked in usersAccountsTableName to claims.Subject's own user ID,
//  2. its ID is linked in usersAccountsTableName to the account claims.Audience.
//
// When both Audience and Subject are empty the request is internal and no ACL
// is applied. The returned error is currently always nil.
func applyClaimsUserSelect(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder) error {
	// Claims are empty, don't apply any ACL.
	if claims.Audience == "" && claims.Subject == "" {
		return nil
	}

	// Build select statement for the users/accounts association table.
	subQuery := sqlbuilder.NewSelectBuilder().Select("user_id").From(usersAccountsTableName)

	var or []string
	if claims.Audience != "" {
		or = append(or, subQuery.Equal("account_id", claims.Audience))
	}
	if claims.Subject != "" {
		or = append(or, subQuery.Equal("user_id", claims.Subject))
	}

	// SelectBuilder.Where joins its arguments with AND, which would require a
	// row to satisfy both rules at once. The rules are alternatives, so combine
	// them with OR explicitly.
	subQuery.Where(subQuery.Or(or...))

	// Append sub query.
	query.Where(query.In("id", subQuery))
	return nil
}
// selectQuery returns a fresh select builder that reads the usersMapColumns
// columns from the users table, with no filters applied yet.
func selectQuery() *sqlbuilder.SelectBuilder {
	return sqlbuilder.NewSelectBuilder().
		Select(usersMapColumns).
		From(usersTableName)
}
// userFindRequestQuery generates the select query for the given find request.
//
// req.Where, when set, is added verbatim as a WHERE expression. Each "?"
// placeholder in it is bound, in order, to the matching entry of req.Args, so the
// arguments are emitted by query.Build() together with the builder's own.
// Previously the args were passed to sqlbuilder.Buildf and attached with
// BuilderAs, whose returned expression was discarded, so they never reached the
// executed SQL. req.Order, req.Limit and req.Offset are applied when present.
func userFindRequestQuery(req UserFindRequest) *sqlbuilder.SelectBuilder {
	query := selectQuery()
	if req.Where != nil {
		query.Where(bindWherePlaceholders(query, *req.Where, req.Args))
	}
	if len(req.Order) > 0 {
		query.OrderBy(req.Order...)
	}
	if req.Limit != nil {
		query.Limit(int(*req.Limit))
	}
	if req.Offset != nil {
		query.Offset(int(*req.Offset))
	}
	return query
}

// bindWherePlaceholders replaces each "?" in expr with a builder variable holding
// the matching element of args. Any "?" beyond len(args) is left untouched. With
// no args, expr is returned unchanged.
func bindWherePlaceholders(query *sqlbuilder.SelectBuilder, expr string, args []interface{}) string {
	if len(args) == 0 {
		return expr
	}
	// Byte-wise scan is safe: '?' is ASCII and never appears inside a
	// multi-byte UTF-8 sequence.
	out := make([]byte, 0, len(expr))
	next := 0
	for i := 0; i < len(expr); i++ {
		if expr[i] == '?' && next < len(args) {
			out = append(out, query.Var(args[next])...)
			next++
			continue
		}
		out = append(out, expr[i])
	}
	return string(out)
}
// List enables streaming retrieval of Users from the database. The query results
// are written to the results channel, so they can be processed while they are
// still being fetched. The channel is closed once all rows have been sent, or
// right after an error has been sent.
//
// Possible types sent to the channel are limited to:
//   - error
//   - *User
//
// Example:
//
//	rr := make(chan interface{})
//
//	go List(ctx, claims, db, req, rr)
//
//	for r := range rr {
//		switch v := r.(type) {
//		case *User:
//			// v is of type *User
//			// process the user here
//		case error:
//			// v is of type error
//			// handle the error here
//		}
//	}
func List(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindRequest, results chan<- interface{}) {
	query := userFindRequestQuery(req)
	list(ctx, claims, dbConn, query, req.IncludedArchived, results)
}
// List enables streaming retrieval of Users from the database for the supplied query.
func list(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, includedArchived bool, results chan<- interface{}) {
2019-05-23 14:32:24 -05:00
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.List")
defer span.Finish()
// Close the channel on complete
defer close(results)
query.Select(usersMapColumns)
query.From(usersTableName)
if !includedArchived {
query.Where(query.IsNull("archived_at"))
}
// Check to see if a sub query needs to be applied for the claims
err := applyClaimsUserSelect(ctx, claims, query)
if err != nil {
results <- err
return
}
sql, args := query.Build()
sql = dbConn.Rebind(sql)
// fetch all places from the db
rows, err := dbConn.QueryContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
results <- errors.WithMessage(err, "list users failed")
return
}
// iterate over each row
for rows.Next() {
u, err := mapRowsToUser(rows)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
results <- err
return
}
results <- u
}
}
// Find returns every user matching the request params: it turns req into a
// select query and hands it, with req.IncludedArchived, to find.
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindRequest) ([]*User, error) {
	return find(ctx, claims, dbConn, userFindRequestQuery(req), req.IncludedArchived)
}
// find gets all the users from the database based on the query
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, includedArchived bool) ([]*User, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Find")
2019-05-23 14:32:24 -05:00
defer span.Finish()
query.Select(usersMapColumns)
query.From(usersTableName)
if !includedArchived {
query.Where(query.IsNull("archived_at"))
}
// Check to see if a sub query needs to be applied for the claims
err := applyClaimsUserSelect(ctx, claims, query)
if err != nil {
return nil, err
}
sql, args := query.Build()
sql = dbConn.Rebind(sql)
// fetch all places from the db
rows, err := dbConn.QueryContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "find users failed")
return nil, err
}
// iterate over each row
resp := []*User{}
for rows.Next() {
u, err := mapRowsToUser(rows)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
return nil, err
}
resp = append(resp, u)
}
return resp, nil
}
// FindById gets the user with the given ID from the database.
//
// The lookup goes through find, so the claims ACL applies and archived users
// are only returned when includedArchived is true. When no visible user
// matches, ErrNotFound is returned with a message naming the ID.
func FindById(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string, includedArchived bool) (*User, error) {
	span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.FindById")
	defer span.Finish()

	// Filter base select query by ID.
	query := selectQuery()
	query.Where(query.Equal("id", id))

	res, err := find(ctx, claims, dbConn, query, includedArchived)
	if err != nil {
		return nil, err
	}
	// len of a nil slice is 0, so this also covers a nil result.
	if len(res) == 0 {
		return nil, errors.WithMessagef(ErrNotFound, "user %s not found", id)
	}
	return res[0], nil
}
// uniqueEmail reports whether no user other than userId already uses email.
// Pass an empty userId to check against every user. A lookup failure other
// than "no rows" is returned, wrapped with the query, alongside false.
func uniqueEmail(ctx context.Context, dbConn *sqlx.DB, email, userId string) (bool, error) {
	query := sqlbuilder.NewSelectBuilder().Select("id").From(usersTableName)
	query.Where(query.And(
		query.Equal("email", email),
		query.NotEqual("id", userId),
	))
	stmt, stmtArgs := query.Build()

	var existingId string
	row := dbConn.QueryRowContext(ctx, dbConn.Rebind(stmt), stmtArgs...)
	switch scanErr := row.Scan(&existingId); scanErr {
	case nil, sql.ErrNoRows:
		// nil: another user matched; ErrNoRows: nobody matched. Neither is a failure.
	default:
		return false, errors.Wrapf(scanErr, "query - %s", query.String())
	}

	// An ID found in the db means the email is already taken.
	return existingId == "", nil
}
// Create inserts a new user into the database.
func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req CreateUserRequest, now time.Time) (*User, error) {
2019-05-23 14:32:24 -05:00
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Create")
defer span.Finish()
v := validator.New()
// Validation email address is unique in the database.
uniq, err := uniqueEmail(ctx, dbConn, req.Email, "")
if err != nil {
return nil, err
}
f := func(fl validator.FieldLevel) bool {
if fl.Field().String() == "invalid" {
return false
}
return uniq
}
v.RegisterValidation("unique", f)
// Validate the request.
err = v.Struct(req)
if err != nil {
return nil, err
}
// If the request has claims from a specific user, ensure that the user
// has the correct role for creating a new user.
if claims.Subject != "" {
// Users with the role of admin are ony allows to create users.
if !claims.HasRole(auth.RoleAdmin) {
err = errors.WithStack(ErrForbidden)
return nil, err
}
}
// If now empty set it to the current time.
if now.IsZero() {
now = time.Now()
}
// Always store the time as UTC.
now = now.UTC()
// Postgres truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
passwordSalt := uuid.NewRandom().String()
saltedPassword := req.Password + passwordSalt
passwordHash, err := bcrypt.GenerateFromPassword([]byte(saltedPassword), bcrypt.DefaultCost)
if err != nil {
return nil, errors.Wrap(err, "generating password hash")
}
u := User{
ID: uuid.NewRandom().String(),
Name: req.Name,
Email: req.Email,
PasswordHash: passwordHash,
PasswordSalt: passwordSalt,
Status: UserStatus_Active,
Timezone: "America/Anchorage",
CreatedAt: now,
UpdatedAt: now,
}
if req.Status != nil {
u.Status = *req.Status
}
if req.Timezone != nil {
u.Timezone = *req.Timezone
}
// Build the insert SQL statement.
query := sqlbuilder.NewInsertBuilder()
query.InsertInto(usersTableName)
query.Cols("id", "name", "email", "password_hash", "password_salt", "status", "timezone", "created_at", "updated_at")
query.Values(u.ID, u.Name, u.Email, u.PasswordHash, u.PasswordSalt, u.Status.String(), u.Timezone, u.CreatedAt, u.UpdatedAt)
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "create user failed")
return nil, err
}
return &u, nil
}
// Update replaces a user in the database.
func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UpdateUserRequest, now time.Time) error {
2019-05-23 14:32:24 -05:00
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Update")
defer span.Finish()
v := validator.New()
// Validation email address is unique in the database.
if req.Email != nil {
uniq, err := uniqueEmail(ctx, dbConn, *req.Email, req.ID)
if err != nil {
return err
}
f := func(fl validator.FieldLevel) bool {
if fl.Field().String() == "invalid" {
return false
}
return uniq
}
v.RegisterValidation("unique", f)
}
// Validate the request.
err := v.Struct(req)
if err != nil {
return err
}
// Ensure the claims can modify the user specified in the request.
err = CanModifyUserId(ctx, claims, dbConn, req.ID)
if err != nil {
err = errors.WithMessagef(err, "Update %s failed", usersTableName)
return err
}
// If now empty set it to the current time.
if now.IsZero() {
now = time.Now()
}
// Always store the time as UTC.
now = now.UTC()
// Postgres truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
// Build the update SQL statement.
query := sqlbuilder.NewUpdateBuilder()
query.Update(usersTableName)
fields := []string{}
if req.Name != nil {
fields = append(fields, query.Assign("name", req.Name))
}
if req.Email != nil {
fields = append(fields, query.Assign("email", req.Email))
}
if req.Status != nil {
fields = append(fields, query.Assign("status", req.Status))
}
if req.Timezone != nil {
fields = append(fields, query.Assign("timezone", req.Timezone))
}
// If there's nothing to update we can quit early.
if len(fields) == 0 {
return nil
}
// Append the updated_at field
fields = append(fields, query.Assign("updated_at", now))
query.Set(fields...)
query.Where(query.Equal("id", req.ID))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "update user %s failed", req.ID)
return err
}
return nil
}
// Update replaces a user in the database.
func UpdatePassword(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UpdatePasswordRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Update")
2019-05-23 14:32:24 -05:00
defer span.Finish()
// Validate the request.
err := validator.New().Struct(req)
if err != nil {
return err
}
// Ensure the claims can modify the user specified in the request.
err = CanModifyUserId(ctx, claims, dbConn, req.ID)
if err != nil {
return err
}
// If now empty set it to the current time.
if now.IsZero() {
now = time.Now()
}
// Always store the time as UTC.
now = now.UTC()
// Postgres truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
// Generate new password hash for the provided password.
passwordSalt := uuid.NewRandom()
saltedPassword := req.Password + passwordSalt.String()
passwordHash, err := bcrypt.GenerateFromPassword([]byte(saltedPassword), bcrypt.DefaultCost)
if err != nil {
return errors.Wrap(err, "generating password hash")
}
// Build the update SQL statement.
query := sqlbuilder.NewUpdateBuilder()
query.Update(usersTableName)
query.Set(
query.Assign("password_hash", passwordHash),
query.Assign("password_salt", passwordSalt),
query.Assign("updated_at", now),
)
query.Where(query.Equal("id", req.ID))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "update password for user %s failed", req.ID)
return err
}
return nil
}
// Archive soft deleted the user from the database.
func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Archive")
2019-05-23 14:32:24 -05:00
defer span.Finish()
// Defines the struct to apply validation
req := struct {
ID string `validate:"required,uuid"`
}{
ID: userID,
}
// Validate the request.
err := validator.New().Struct(req)
if err != nil {
return err
}
// Ensure the claims can modify the user specified in the request.
err = CanModifyUserId(ctx, claims, dbConn, req.ID)
if err != nil {
return err
}
// If now empty set it to the current time.
if now.IsZero() {
now = time.Now()
}
// Always store the time as UTC.
now = now.UTC()
// Postgres truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
// Build the update SQL statement.
query := sqlbuilder.NewUpdateBuilder()
query.Update(usersTableName)
query.Set(
query.Assign("archived_at", now),
)
query.Where(query.Equal("id", req.ID))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "archive user %s failed", req.ID)
return err
}
// Archive all the associated user accounts
{
// Build the update SQL statement.
query := sqlbuilder.NewUpdateBuilder()
query.Update(usersAccountsTableName)
query.Set(query.Assign("archived_at", now))
query.Where(query.And(
query.Equal("user_id", req.ID),
))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "archive accounts for user %s failed", req.ID)
return err
}
}
return nil
}
// Delete removes a user, and the user's records in usersAccountsTableName, from
// the database.
//
// userID must be a UUID, and the claims must pass CanModifyUserId. Both deletes
// run in one transaction. The dependent account records are removed first, so
// the operation also works if that table has a foreign key to the users table
// (NOTE(review): schema not visible here — confirm).
func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string) error {
	span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Delete")
	defer span.Finish()

	// Defines the struct to apply validation.
	req := struct {
		ID string `validate:"required,uuid"`
	}{
		ID: userID,
	}

	// Validate the request.
	err := validator.New().Struct(req)
	if err != nil {
		return err
	}

	// Ensure the claims can modify the user specified in the request.
	err = CanModifyUserId(ctx, claims, dbConn, req.ID)
	if err != nil {
		return err
	}

	tx, err := dbConn.BeginTx(ctx, nil)
	if err != nil {
		return errors.WithStack(err)
	}
	// Rolls back on any early return. After a successful Commit this is a
	// no-op returning sql.ErrTxDone, so its error is intentionally ignored.
	defer tx.Rollback()

	// Delete all the associated user accounts first (dependents before parent).
	{
		query := sqlbuilder.NewDeleteBuilder()
		query.DeleteFrom(usersAccountsTableName)
		query.Where(query.Equal("user_id", req.ID))

		queryStr, args := query.Build()
		queryStr = dbConn.Rebind(queryStr)
		if _, err := tx.ExecContext(ctx, queryStr, args...); err != nil {
			err = errors.Wrapf(err, "query - %s", query.String())
			return errors.WithMessagef(err, "delete accounts for user %s failed", req.ID)
		}
	}

	// Delete the user record.
	{
		query := sqlbuilder.NewDeleteBuilder()
		query.DeleteFrom(usersTableName)
		query.Where(query.Equal("id", req.ID))

		queryStr, args := query.Build()
		queryStr = dbConn.Rebind(queryStr)
		if _, err := tx.ExecContext(ctx, queryStr, args...); err != nil {
			err = errors.Wrapf(err, "query - %s", query.String())
			return errors.WithMessagef(err, "delete user %s failed", req.ID)
		}
	}

	if err := tx.Commit(); err != nil {
		return errors.WithMessagef(errors.WithStack(err), "delete user %s failed", req.ID)
	}
	return nil
}