1
0
mirror of https://github.com/raseels-repos/golang-saas-starter-kit.git synced 2025-06-17 00:17:59 +02:00

Splitting our users and accounts, creating new user_account package

This commit is contained in:
Lee Brown
2019-06-22 10:05:03 -08:00
parent d1ceb71265
commit abccc1f658
11 changed files with 222 additions and 1803 deletions

View File

@ -12,7 +12,7 @@ This project should not be considered a web framework. It is a starter toolkit t
There are five areas of expertise that an engineer or her engineering team must do for a project to grow and scale. Based on our experience, a few core decisions were made for each of these areas that help you focus initially on writing the business logic. There are five areas of expertise that an engineer or her engineering team must do for a project to grow and scale. Based on our experience, a few core decisions were made for each of these areas that help you focus initially on writing the business logic.
1. Micro level - Since SaaS requires transactions, project implements Postgres. Implementation facilitates the data semantics that define the data being captured and their relationships. 1. Micro level - Since SaaS requires transactions, project implements Postgres. Implementation facilitates the data semantics that define the data being captured and their relationships.
2. Macro level - Uses POD architecture and design that provides the project foundation. 2. Macro level - The project architecture and design, provides basic project structure and foundation for development.
3. Business logic - Defines an example Golang package that helps illustrate where value generating activities should reside and how the code will be delivered to clients. 3. Business logic - Defines an example Golang package that helps illustrate where value generating activities should reside and how the code will be delivered to clients.
4. Deployment and Operations - Integrates with GitLab for CI/CD and AWS for serverless deployments with AWS Fargate. 4. Deployment and Operations - Integrates with GitLab for CI/CD and AWS for serverless deployments with AWS Fargate.
5. Observability - Implements Datadog to facilitate exposing metrics, logs and request tracing that ensure stable and responsive service for clients. 5. Observability - Implements Datadog to facilitate exposing metrics, logs and request tracing that ensure stable and responsive service for clients.

View File

@ -1,35 +0,0 @@
# This is the official list of people who can contribute
# (and typically have contributed) code to the gotraining repository.
#
# Names should be added to this file only after verifying that
# the individual or the individual's organization has agreed to
# the appropriate Contributor License Agreement, found here:
#
# http://code.google.com/legal/individual-cla-v1.0.html
# http://code.google.com/legal/corporate-cla-v1.0.html
#
# The agreement for individuals can be filled out on the web.
# Names should be added to this file like so:
# Name <email address>
#
# An entry with two email addresses specifies that the
# first address should be used in the submit logs and
# that the second address should be recognized as the
# same person when interacting with Rietveld.
# Please keep the list sorted.
Arash Bina <arash@arash.io>
Askar Sagyndyk <superwhykz@gmail.com>
Bob Cao <3308031+bobintornado@users.noreply.github.com>
Ed Gonzo <Ed@ardanstudios.com>
Farrukh Kurbanov <farrukhkurbanov@Administrators-MacBook-Pro.local>
Jacob Walker <jacob@ardanlabs.com>
Jeremy Stone <slycrel@gmail.com>
Nick Stogner <nstogner@users.noreply.github.com>
William Kennedy <bill@ardanlabs.com>
Wyatt Johnson <wyattjoh@gmail.com>
Zachary Johnson <zachjohnsondev@gmail.com>
Lee Brown <lee@geeksinthewoods.com>
Lucas Brown <lucas@geeksinthewoods.com>

View File

@ -0,0 +1,127 @@
package account
import (
"database/sql"
"database/sql/driver"
"time"
"github.com/lib/pq"
"gopkg.in/go-playground/validator.v9"
"github.com/pkg/errors"
)
// Account represents someone with access to our system.
type Account struct {
ID string `json:"id"`
Name string `json:"name"`
Address1 string `json:"address1"`
Address2 string `json:"address2"`
City string `json:"city"`
Region string `json:"region"`
Country string `json:"country"`
Zipcode string `json:"zipcode"`
Status AccountStatus `json:"status"`
Timezone string `json:"timezone"`
SignupUserID sql.NullString `json:"signup_user_id"`
BillingUserID sql.NullString `json:"billing_user_id"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ArchivedAt pq.NullTime `json:"archived_at"`
}
// CreateAccountRequest contains information needed to create a new Account.
type CreateAccountRequest struct {
Name string `json:"name" validate:"required,unique"`
Address1 string `json:"address1" validate:"required"`
Address2 string `json:"address2" validate:"omitempty"`
City string `json:"city" validate:"required"`
Region string `json:"region" validate:"required"`
Country string `json:"country" validate:"required"`
Zipcode string `json:"zipcode" validate:"required"`
Status *AccountStatus `json:"status" validate:"omitempty,oneof=active pending disabled"`
Timezone *string `json:"timezone" validate:"omitempty"`
SignupUserID *string `json:"signup_user_id" validate:"omitempty,uuid"`
BillingUserID *string `json:"billing_user_id" validate:"omitempty,uuid"`
}
// UpdateAccountRequest defines what information may be provided to modify an existing
// Account. All fields are optional so clients can send just the fields they want
// changed. It uses pointer fields so we can differentiate between a field that
// was not provided and a field that was provided as explicitly blank. Normally
// we do not want to use pointers to basic types but we make exceptions around
// marshalling/unmarshalling.
type UpdateAccountRequest struct {
ID string `validate:"required,uuid"`
Name *string `json:"name" validate:"omitempty,unique"`
Address1 *string `json:"address1" validate:"omitempty"`
Address2 *string `json:"address2" validate:"omitempty"`
City *string `json:"city" validate:"omitempty"`
Region *string `json:"region" validate:"omitempty"`
Country *string `json:"country" validate:"omitempty"`
Zipcode *string `json:"zipcode" validate:"omitempty"`
Status *AccountStatus `json:"status" validate:"omitempty,oneof=active pending disabled"`
Timezone *string `json:"timezone" validate:"omitempty"`
SignupUserID *string `json:"signup_user_id" validate:"omitempty,uuid"`
BillingUserID *string `json:"billing_user_id" validate:"omitempty,uuid"`
}
// AccountFindRequest defines the possible options to search for accounts. By default
// archived accounts will be excluded from response.
type AccountFindRequest struct {
Where *string
Args []interface{}
Order []string
Limit *uint
Offset *uint
IncludedArchived bool
}
// AccountStatus represents the status of an account.
type AccountStatus string
// AccountStatus values define the status field of a user account.
const (
// AccountStatus_Active defines the state when a user can access an account.
AccountStatus_Active AccountStatus = "active"
// AccountStatus_Pending defined the state when an account was created but
// not activated.
AccountStatus_Pending AccountStatus = "pending"
// AccountStatus_Disabled defines the state when a user has been disabled from
// accessing an account.
AccountStatus_Disabled AccountStatus = "disabled"
)
// AccountStatus_Values provides list of valid AccountStatus values.
var AccountStatus_Values = []AccountStatus{
AccountStatus_Active,
AccountStatus_Pending,
AccountStatus_Disabled,
}
// Scan supports reading the AccountStatus value from the database.
func (s *AccountStatus) Scan(value interface{}) error {
asBytes, ok := value.([]byte)
if !ok {
return errors.New("Scan source is not []byte")
}
*s = AccountStatus(string(asBytes))
return nil
}
// Value converts the AccountStatus value to be stored in the database.
func (s AccountStatus) Value() (driver.Value, error) {
v := validator.New()
errs := v.Var(s, "required,oneof=active invited disabled")
if errs != nil {
return nil, errs
}
return string(s), nil
}
// String converts the AccountStatus value to a string.
func (s AccountStatus) String() string {
return string(s)
}

View File

@ -50,6 +50,7 @@ type Authenticator struct {
parser *jwt.Parser parser *jwt.Parser
} }
// NewAuthenticator creates an *Authenticator for use. // NewAuthenticator creates an *Authenticator for use.
// key expiration is optional to filter out old keys // key expiration is optional to filter out old keys
// It will error if: // It will error if:

View File

@ -1,638 +0,0 @@
package user
import (
"context"
"database/sql"
"time"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
"github.com/huandu/go-sqlbuilder"
"github.com/jmoiron/sqlx"
"github.com/pborman/uuid"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
"gopkg.in/go-playground/validator.v9"
)
// The database table for User
const usersTableName = "users"
var (
// ErrNotFound abstracts the mgo not found error.
ErrNotFound = errors.New("Entity not found")
// ErrInvalidID occurs when an ID is not in a valid form.
ErrInvalidID = errors.New("ID is not in its proper form")
// ErrAuthenticationFailure occurs when a user attempts to authenticate but
// anything goes wrong.
ErrAuthenticationFailure = errors.New("Authentication failed")
// ErrForbidden occurs when a user tries to do something that is forbidden to them according to our access control policies.
ErrForbidden = errors.New("Attempted action is not allowed")
)
// usersMapColumns is the list of columns needed for mapRowsToUser
var usersMapColumns = "id,name,email,password_salt,password_hash,password_reset,timezone,created_at,updated_at,archived_at"
// mapRowsToUser takes the SQL rows and maps it to the UserAccount struct
// with the columns defined by usersMapColumns
func mapRowsToUser(rows *sql.Rows) (*User, error) {
var (
u User
err error
)
err = rows.Scan(&u.ID, &u.Name, &u.Email, &u.PasswordSalt, &u.PasswordHash, &u.PasswordReset, &u.Timezone, &u.CreatedAt, &u.UpdatedAt, &u.ArchivedAt)
if err != nil {
return nil, errors.WithStack(err)
}
return &u, nil
}
// CanReadUser determines if claims has the authority to access the specified user ID.
func CanReadUser(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string) error {
// If the request has claims from a specific user, ensure that the user
// has the correct access to the user.
if claims.Subject != "" {
// When the claims Subject - UserId - does not match the requested user, the
// claims audience - AccountId - should have a record.
if claims.Subject != userID {
query := sqlbuilder.NewSelectBuilder().Select("id").From(usersAccountsTableName)
query.Where(query.Or(
query.Equal("account_id", claims.Audience),
query.Equal("user_id", userID),
))
queryStr, args := query.Build()
queryStr = dbConn.Rebind(queryStr)
var userAccountId string
err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&userAccountId)
if err != nil && err != sql.ErrNoRows {
err = errors.Wrapf(err, "query - %s", query.String())
return err
}
// When there is now userAccount ID returned, then the current user does not have access
// to the specified user.
if userAccountId == "" {
return errors.WithStack(ErrForbidden)
}
}
}
return nil
}
// CanModifyUser determines if claims has the authority to modify the specified user ID.
func CanModifyUser(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string) error {
// First check to see if claims can read the user ID
err := CanReadUser(ctx, claims, dbConn, userID)
if err != nil {
return err
}
// If the request has claims from a specific user, ensure that the user
// has the correct role for updating an existing user.
if claims.Subject != "" {
if claims.Subject == userID {
// All users are allowed to update their own record
} else if claims.HasRole(auth.RoleAdmin) {
// Admin users can update users they have access to.
} else {
return errors.WithStack(ErrForbidden)
}
}
return nil
}
// claimsSql applies a sub-query to the provided query to enforce ACL based on
// the claims provided.
// 1. All role types can access their user ID
// 2. Any user with the same account ID
// 3. No claims, request is internal, no ACL applied
func applyClaimsUserSelect(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder) error {
// Claims are empty, don't apply any ACL
if claims.Audience == "" && claims.Subject == "" {
return nil
}
// Build select statement for users_accounts table
subQuery := sqlbuilder.NewSelectBuilder().Select("user_id").From(usersAccountsTableName)
var or []string
if claims.Audience != "" {
or = append(or, subQuery.Equal("account_id", claims.Audience))
}
if claims.Subject != "" {
or = append(or, subQuery.Equal("user_id", claims.Subject))
}
subQuery.Where(subQuery.Or(or...))
// Append sub query
query.Where(query.In("id", subQuery))
return nil
}
// selectQuery constructs a base select query for User
func selectQuery() *sqlbuilder.SelectBuilder {
query := sqlbuilder.NewSelectBuilder()
query.Select(usersMapColumns)
query.From(usersTableName)
return query
}
// userFindRequestQuery generates the select query for the given find request.
// TODO: Need to figure out why can't parse the args when appending the where
// to the query.
func userFindRequestQuery(req UserFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) {
query := selectQuery()
if req.Where != nil {
query.Where(query.And(*req.Where))
}
if len(req.Order) > 0 {
query.OrderBy(req.Order...)
}
if req.Limit != nil {
query.Limit(int(*req.Limit))
}
if req.Offset != nil {
query.Offset(int(*req.Offset))
}
return query, req.Args
}
// Find gets all the users from the database based on the request params.
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindRequest) ([]*User, error) {
query, args := userFindRequestQuery(req)
return find(ctx, claims, dbConn, query, args, req.IncludedArchived)
}
// find internal method for getting all the users from the database using a select query.
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*User, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Find")
defer span.Finish()
query.Select(usersMapColumns)
query.From(usersTableName)
if !includedArchived {
query.Where(query.IsNull("archived_at"))
}
// Check to see if a sub query needs to be applied for the claims
err := applyClaimsUserSelect(ctx, claims, query)
if err != nil {
return nil, err
}
queryStr, queryArgs := query.Build()
queryStr = dbConn.Rebind(queryStr)
args = append(args, queryArgs...)
// fetch all places from the db
rows, err := dbConn.QueryContext(ctx, queryStr, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "find users failed")
return nil, err
}
// iterate over each row
resp := []*User{}
for rows.Next() {
u, err := mapRowsToUser(rows)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
return nil, err
}
resp = append(resp, u)
}
return resp, nil
}
// Retrieve gets the specified user from the database.
func FindById(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string, includedArchived bool) (*User, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.FindById")
defer span.Finish()
// Filter base select query by ID
query := selectQuery()
query.Where(query.Equal("id", id))
res, err := find(ctx, claims, dbConn, query, []interface{}{}, includedArchived)
if err != nil {
return nil, err
} else if res == nil || len(res) == 0 {
err = errors.WithMessagef(ErrNotFound, "user %s not found", id)
return nil, err
}
u := res[0]
return u, nil
}
// Validation an email address is unique excluding the current user ID.
func uniqueEmail(ctx context.Context, dbConn *sqlx.DB, email, userId string) (bool, error) {
query := sqlbuilder.NewSelectBuilder().Select("id").From(usersTableName)
query.Where(query.And(
query.Equal("email", email),
query.NotEqual("id", userId),
))
queryStr, args := query.Build()
queryStr = dbConn.Rebind(queryStr)
var existingId string
err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&existingId)
if err != nil && err != sql.ErrNoRows {
err = errors.Wrapf(err, "query - %s", query.String())
return false, err
}
// When an ID was found in the db, the email is not unique.
if existingId != "" {
return false, nil
}
return true, nil
}
// Create inserts a new user into the database.
func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req CreateUserRequest, now time.Time) (*User, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Create")
defer span.Finish()
v := validator.New()
// Validation email address is unique in the database.
uniq, err := uniqueEmail(ctx, dbConn, req.Email, "")
if err != nil {
return nil, err
}
f := func(fl validator.FieldLevel) bool {
if fl.Field().String() == "invalid" {
return false
}
return uniq
}
v.RegisterValidation("unique", f)
// Validate the request.
err = v.Struct(req)
if err != nil {
return nil, err
}
// If the request has claims from a specific user, ensure that the user
// has the correct role for creating a new user.
if claims.Subject != "" {
// Users with the role of admin are ony allows to create users.
if !claims.HasRole(auth.RoleAdmin) {
err = errors.WithStack(ErrForbidden)
return nil, err
}
}
// If now empty set it to the current time.
if now.IsZero() {
now = time.Now()
}
// Always store the time as UTC.
now = now.UTC()
// Postgres truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
passwordSalt := uuid.NewRandom().String()
saltedPassword := req.Password + passwordSalt
passwordHash, err := bcrypt.GenerateFromPassword([]byte(saltedPassword), bcrypt.DefaultCost)
if err != nil {
return nil, errors.Wrap(err, "generating password hash")
}
u := User{
ID: uuid.NewRandom().String(),
Name: req.Name,
Email: req.Email,
PasswordHash: passwordHash,
PasswordSalt: passwordSalt,
Timezone: "America/Anchorage",
CreatedAt: now,
UpdatedAt: now,
}
if req.Timezone != nil {
u.Timezone = *req.Timezone
}
// Build the insert SQL statement.
query := sqlbuilder.NewInsertBuilder()
query.InsertInto(usersTableName)
query.Cols("id", "name", "email", "password_hash", "password_salt", "timezone", "created_at", "updated_at")
query.Values(u.ID, u.Name, u.Email, u.PasswordHash, u.PasswordSalt, u.Timezone, u.CreatedAt, u.UpdatedAt)
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "create user failed")
return nil, err
}
return &u, nil
}
// Update replaces a user in the database.
func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UpdateUserRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Update")
defer span.Finish()
v := validator.New()
// Validation email address is unique in the database.
if req.Email != nil {
uniq, err := uniqueEmail(ctx, dbConn, *req.Email, req.ID)
if err != nil {
return err
}
f := func(fl validator.FieldLevel) bool {
if fl.Field().String() == "invalid" {
return false
}
return uniq
}
v.RegisterValidation("unique", f)
}
// Validate the request.
err := v.Struct(req)
if err != nil {
return err
}
// Ensure the claims can modify the user specified in the request.
err = CanModifyUser(ctx, claims, dbConn, req.ID)
if err != nil {
err = errors.WithMessagef(err, "Update %s failed", usersTableName)
return err
}
// If now empty set it to the current time.
if now.IsZero() {
now = time.Now()
}
// Always store the time as UTC.
now = now.UTC()
// Postgres truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
// Build the update SQL statement.
query := sqlbuilder.NewUpdateBuilder()
query.Update(usersTableName)
var fields []string
if req.Name != nil {
fields = append(fields, query.Assign("name", req.Name))
}
if req.Email != nil {
fields = append(fields, query.Assign("email", req.Email))
}
if req.Timezone != nil {
fields = append(fields, query.Assign("timezone", req.Timezone))
}
// If there's nothing to update we can quit early.
if len(fields) == 0 {
return nil
}
// Append the updated_at field
fields = append(fields, query.Assign("updated_at", now))
query.Set(fields...)
query.Where(query.Equal("id", req.ID))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "update user %s failed", req.ID)
return err
}
return nil
}
// Update replaces a user in the database.
func UpdatePassword(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UpdatePasswordRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Update")
defer span.Finish()
// Validate the request.
err := validator.New().Struct(req)
if err != nil {
return err
}
// Ensure the claims can modify the user specified in the request.
err = CanModifyUser(ctx, claims, dbConn, req.ID)
if err != nil {
return err
}
// If now empty set it to the current time.
if now.IsZero() {
now = time.Now()
}
// Always store the time as UTC.
now = now.UTC()
// Postgres truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
// Generate new password hash for the provided password.
passwordSalt := uuid.NewRandom()
saltedPassword := req.Password + passwordSalt.String()
passwordHash, err := bcrypt.GenerateFromPassword([]byte(saltedPassword), bcrypt.DefaultCost)
if err != nil {
return errors.Wrap(err, "generating password hash")
}
// Build the update SQL statement.
query := sqlbuilder.NewUpdateBuilder()
query.Update(usersTableName)
query.Set(
query.Assign("password_hash", passwordHash),
query.Assign("password_salt", passwordSalt),
query.Assign("updated_at", now),
)
query.Where(query.Equal("id", req.ID))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "update password for user %s failed", req.ID)
return err
}
return nil
}
// Archive soft deleted the user from the database.
func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Archive")
defer span.Finish()
// Defines the struct to apply validation
req := struct {
ID string `validate:"required,uuid"`
}{
ID: userID,
}
// Validate the request.
err := validator.New().Struct(req)
if err != nil {
return err
}
// Ensure the claims can modify the user specified in the request.
err = CanModifyUser(ctx, claims, dbConn, req.ID)
if err != nil {
return err
}
// If now empty set it to the current time.
if now.IsZero() {
now = time.Now()
}
// Always store the time as UTC.
now = now.UTC()
// Postgres truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
// Build the update SQL statement.
query := sqlbuilder.NewUpdateBuilder()
query.Update(usersTableName)
query.Set(
query.Assign("archived_at", now),
)
query.Where(query.Equal("id", req.ID))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "archive user %s failed", req.ID)
return err
}
// Archive all the associated user accounts
{
// Build the update SQL statement.
query := sqlbuilder.NewUpdateBuilder()
query.Update(usersAccountsTableName)
query.Set(query.Assign("archived_at", now))
query.Where(query.And(
query.Equal("user_id", req.ID),
))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "archive accounts for user %s failed", req.ID)
return err
}
}
return nil
}
// Delete removes a user from the database.
func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Delete")
defer span.Finish()
// Defines the struct to apply validation
req := struct {
ID string `validate:"required,uuid"`
}{
ID: userID,
}
// Validate the request.
err := validator.New().Struct(req)
if err != nil {
return err
}
// Ensure the claims can modify the user specified in the request.
err = CanModifyUser(ctx, claims, dbConn, req.ID)
if err != nil {
return err
}
// Build the delete SQL statement.
query := sqlbuilder.NewDeleteBuilder()
query.DeleteFrom(usersTableName)
query.Where(query.Equal("id", req.ID))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "delete user %s failed", req.ID)
return err
}
// Delete all the associated user accounts
{
// Build the delete SQL statement.
query := sqlbuilder.NewDeleteBuilder()
query.DeleteFrom(usersAccountsTableName)
query.Where(query.And(
query.Equal("user_id", req.ID),
))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "delete accounts for user %s failed", req.ID)
return err
}
}
return nil
}

File diff suppressed because it is too large Load Diff

View File

@ -1,7 +1,6 @@
package user package user
import ( import (
"database/sql"
"database/sql/driver" "database/sql/driver"
"time" "time"
@ -11,63 +10,6 @@ import (
"gopkg.in/go-playground/validator.v9" "gopkg.in/go-playground/validator.v9"
) )
// User represents someone with access to our system.
type User struct {
ID string `json:"id"`
Name string `json:"name"`
Email string `json:"email"`
PasswordSalt string `json:"-"`
PasswordHash []byte `json:"-"`
PasswordReset sql.NullString `json:"-"`
Timezone string `json:"timezone"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ArchivedAt pq.NullTime `json:"archived_at"`
}
// CreateUserRequest contains information needed to create a new User.
type CreateUserRequest struct {
Name string `json:"name" validate:"required"`
Email string `json:"email" validate:"required,email,unique"`
Password string `json:"password" validate:"required"`
PasswordConfirm string `json:"password_confirm" validate:"eqfield=Password"`
Timezone *string `json:"timezone" validate:"omitempty"`
}
// UpdateUserRequest defines what information may be provided to modify an existing
// User. All fields are optional so clients can send just the fields they want
// changed. It uses pointer fields so we can differentiate between a field that
// was not provided and a field that was provided as explicitly blank. Normally
// we do not want to use pointers to basic types but we make exceptions around
// marshalling/unmarshalling.
type UpdateUserRequest struct {
ID string `validate:"required,uuid"`
Name *string `json:"name" validate:"omitempty"`
Email *string `json:"email" validate:"omitempty,email,unique"`
Timezone *string `json:"timezone" validate:"omitempty"`
}
// UpdatePassword defines what information is required to update a user password.
type UpdatePasswordRequest struct {
ID string `validate:"required,uuid"`
Password string `json:"password" validate:"required"`
PasswordConfirm string `json:"password_confirm" validate:"omitempty,eqfield=Password"`
}
// UserFindRequest defines the possible options to search for users. By default
// archived users will be excluded from response.
type UserFindRequest struct {
Where *string
Args []interface{}
Order []string
Limit *uint
Offset *uint
IncludedArchived bool
}
// UserAccount defines the one to many relationship of an user to an account. This // UserAccount defines the one to many relationship of an user to an account. This
// will enable a single user access to multiple accounts without having duplicate // will enable a single user access to multiple accounts without having duplicate
// users. Each association of a user to an account has a set of roles and a status // users. Each association of a user to an account has a set of roles and a status
@ -85,20 +27,20 @@ type UserAccount struct {
ArchivedAt pq.NullTime `json:"archived_at"` ArchivedAt pq.NullTime `json:"archived_at"`
} }
// AddAccountRequest defines the information is needed to associate a user to an // CreateUserAccountRequest defines the information is needed to associate a user to an
// account. Users are global to the application and each users access can be managed // account. Users are global to the application and each users access can be managed
// on an account level. If a current entry exists in the database but is archived, // on an account level. If a current entry exists in the database but is archived,
// it will be un-archived. // it will be un-archived.
type AddAccountRequest struct { type CreateUserAccountRequest struct {
UserID string `validate:"required,uuid"` UserID string `validate:"required,uuid"`
AccountID string `validate:"required,uuid"` AccountID string `validate:"required,uuid"`
Roles UserAccountRoles `json:"roles" validate:"required,dive,oneof=admin user"` Roles UserAccountRoles `json:"roles" validate:"required,dive,oneof=admin user"`
Status *UserAccountStatus `json:"status" validate:"omitempty,oneof=active invited disabled"` Status *UserAccountStatus `json:"status" validate:"omitempty,oneof=active invited disabled"`
} }
// UpdateAccountRequest defines the information needed to update the roles or the // UpdateUserAccountRequest defines the information needed to update the roles or the
// status for an existing user account. // status for an existing user account.
type UpdateAccountRequest struct { type UpdateUserAccountRequest struct {
UserID string `validate:"required,uuid"` UserID string `validate:"required,uuid"`
AccountID string `validate:"required,uuid"` AccountID string `validate:"required,uuid"`
Roles *UserAccountRoles `json:"roles" validate:"required,dive,oneof=admin user"` Roles *UserAccountRoles `json:"roles" validate:"required,dive,oneof=admin user"`
@ -106,16 +48,16 @@ type UpdateAccountRequest struct {
unArchive bool `json:"-"` // Internal use only. unArchive bool `json:"-"` // Internal use only.
} }
// RemoveAccountRequest defines the information needed to remove an existing account // ArchiveUserAccountRequest defines the information needed to remove an existing account
// for a user. This will archive (soft-delete) the existing database entry. // for a user. This will archive (soft-delete) the existing database entry.
type RemoveAccountRequest struct { type ArchiveUserAccountRequest struct {
UserID string `validate:"required,uuid"` UserID string `validate:"required,uuid"`
AccountID string `validate:"required,uuid"` AccountID string `validate:"required,uuid"`
} }
// DeleteAccountRequest defines the information needed to delete an existing account // DeleteUserAccountRequest defines the information needed to delete an existing account
// for a user. This will hard delete the existing database entry. // for a user. This will hard delete the existing database entry.
type DeleteAccountRequest struct { type DeleteUserAccountRequest struct {
UserID string `validate:"required,uuid"` UserID string `validate:"required,uuid"`
AccountID string `validate:"required,uuid"` AccountID string `validate:"required,uuid"`
} }
@ -238,9 +180,3 @@ func (s UserAccountRoles) Value() (driver.Value, error) {
return arr.Value() return arr.Value()
} }
// Token is the payload we deliver to users when they authenticate.
type Token struct {
Token string `json:"token"`
claims auth.Claims `json:"-"`
}

View File

@ -3,6 +3,8 @@ package user
import ( import (
"context" "context"
"database/sql" "database/sql"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/account"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/user"
"github.com/lib/pq" "github.com/lib/pq"
"time" "time"
@ -15,14 +17,29 @@ import (
"gopkg.in/go-playground/validator.v9" "gopkg.in/go-playground/validator.v9"
) )
var (
// ErrNotFound abstracts the mgo not found error.
ErrNotFound = errors.New("Entity not found")
// ErrInvalidID occurs when an ID is not in a valid form.
ErrInvalidID = errors.New("ID is not in its proper form")
// ErrAuthenticationFailure occurs when a user attempts to authenticate but
// anything goes wrong.
ErrAuthenticationFailure = errors.New("Authentication failed")
// ErrForbidden occurs when a user tries to do something that is forbidden to them according to our access control policies.
ErrForbidden = errors.New("Attempted action is not allowed")
)
// The database table for UserAccount // The database table for UserAccount
const usersAccountsTableName = "users_accounts" const userAccountTableName = "users_accounts"
// The list of columns needed for mapRowsToUserAccount // The list of columns needed for mapRowsToUserAccount
var usersAccountsMapColumns = "id,user_id,account_id,roles,status,created_at,updated_at,archived_at" var userAccountMapColumns = "id,user_id,account_id,roles,status,created_at,updated_at,archived_at"
// mapRowsToUserAccount takes the SQL rows and maps it to the UserAccount struct // mapRowsToUserAccount takes the SQL rows and maps it to the UserAccount struct
// with the columns defined by usersAccountsMapColumns // with the columns defined by userAccountMapColumns
func mapRowsToUserAccount(rows *sql.Rows) (*UserAccount, error) { func mapRowsToUserAccount(rows *sql.Rows) (*UserAccount, error) {
var ( var (
ua UserAccount ua UserAccount
@ -36,10 +53,31 @@ func mapRowsToUserAccount(rows *sql.Rows) (*UserAccount, error) {
return &ua, nil return &ua, nil
} }
// CanReadUserAccount determines if claims has the authority to access the specified user account by user ID.
func CanReadUserAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID, accountID string) error {
// First check to see if claims can read the user ID
err := user.CanReadUser(ctx, claims, dbConn, userID)
if err != nil {
if claims.Audience != accountID {
return err
}
}
// Second check to see if claims can read the account ID
err = account.CanReadAccount(ctx, claims, dbConn, accountID)
if err != nil {
if claims.Audience != accountID {
return err
}
}
return nil
}
// CanModifyUserAccount determines if claims has the authority to modify the specified user ID. // CanModifyUserAccount determines if claims has the authority to modify the specified user ID.
func CanModifyUserAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID, accountID string) error { func CanModifyUserAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID, accountID string) error {
// First check to see if claims can read the user ID // First check to see if claims can read the user ID
err := CanReadUser(ctx, claims, dbConn, userID) err := CanReadUserAccount(ctx, claims, dbConn, userID, accountID)
if err != nil { if err != nil {
if claims.Audience != accountID { if claims.Audience != accountID {
return err return err
@ -71,7 +109,7 @@ func applyClaimsUserAccountSelect(ctx context.Context, claims auth.Claims, query
} }
// Build select statement for users_accounts table // Build select statement for users_accounts table
subQuery := sqlbuilder.NewSelectBuilder().Select("user_id").From(usersAccountsTableName) subQuery := sqlbuilder.NewSelectBuilder().Select("user_id").From(userAccountTableName)
var or []string var or []string
if claims.Audience != "" { if claims.Audience != "" {
@ -89,18 +127,18 @@ func applyClaimsUserAccountSelect(ctx context.Context, claims auth.Claims, query
} }
// AccountSelectQuery // AccountSelectQuery
func accountSelectQuery() *sqlbuilder.SelectBuilder { func userAccountSelectQuery() *sqlbuilder.SelectBuilder {
query := sqlbuilder.NewSelectBuilder() query := sqlbuilder.NewSelectBuilder()
query.Select(usersAccountsMapColumns) query.Select(userAccountMapColumns)
query.From(usersAccountsTableName) query.From(userAccountTableName)
return query return query
} }
// userFindRequestQuery generates the select query for the given find request. // userFindRequestQuery generates the select query for the given find request.
// TODO: Need to figure out why can't parse the args when appending the where // TODO: Need to figure out why can't parse the args when appending the where
// to the query. // to the query.
func accountFindRequestQuery(req UserAccountFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) { func userAccountFindRequestQuery(req UserAccountFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) {
query := accountSelectQuery() query := userAccountSelectQuery()
if req.Where != nil { if req.Where != nil {
query.Where(query.And(*req.Where)) query.Where(query.And(*req.Where))
} }
@ -117,19 +155,19 @@ func accountFindRequestQuery(req UserAccountFindRequest) (*sqlbuilder.SelectBuil
return query, req.Args return query, req.Args
} }
// Find gets all the users from the database based on the request params // Find gets all the user accounts from the database based on the request params
func FindAccounts(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountFindRequest) ([]*UserAccount, error) { func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountFindRequest) ([]*UserAccount, error) {
query, args := accountFindRequestQuery(req) query, args := userAccountFindRequestQuery(req)
return findAccounts(ctx, claims, dbConn, query, args, req.IncludedArchived) return find(ctx, claims, dbConn, query, args, req.IncludedArchived)
} }
// Find gets all the users from the database based on the select query // Find gets all the user accounts from the database based on the select query
func findAccounts(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*UserAccount, error) { func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*UserAccount, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.FindAccounts") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Find")
defer span.Finish() defer span.Finish()
query.Select(usersAccountsMapColumns) query.Select(userAccountMapColumns)
query.From(usersAccountsTableName) query.From(userAccountTableName)
if !includedArchived { if !includedArchived {
query.Where(query.IsNull("archived_at")) query.Where(query.IsNull("archived_at"))
@ -148,7 +186,7 @@ func findAccounts(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, quer
rows, err := dbConn.QueryContext(ctx, queryStr, args...) rows, err := dbConn.QueryContext(ctx, queryStr, args...)
if err != nil { if err != nil {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "find accounts failed") err = errors.WithMessage(err, "find user accounts failed")
return nil, err return nil, err
} }
@ -167,8 +205,8 @@ func findAccounts(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, quer
} }
// Retrieve gets the specified user from the database. // Retrieve gets the specified user from the database.
func FindAccountsByUserID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string, includedArchived bool) ([]*UserAccount, error) { func FindByUserID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string, includedArchived bool) ([]*UserAccount, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.FindAccountsByUserId") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.FindByUserID")
defer span.Finish() defer span.Finish()
// Filter base select query by ID // Filter base select query by ID
@ -177,7 +215,7 @@ func FindAccountsByUserID(ctx context.Context, claims auth.Claims, dbConn *sqlx.
query.OrderBy("created_at") query.OrderBy("created_at")
// Execute the find accounts method. // Execute the find accounts method.
res, err := findAccounts(ctx, claims, dbConn, query, []interface{}{}, includedArchived) res, err := find(ctx, claims, dbConn, query, []interface{}{}, includedArchived)
if err != nil { if err != nil {
return nil, err return nil, err
} else if res == nil || len(res) == 0 { } else if res == nil || len(res) == 0 {
@ -189,8 +227,8 @@ func FindAccountsByUserID(ctx context.Context, claims auth.Claims, dbConn *sqlx.
} }
// AddAccount an account for a given user with specified roles. // AddAccount an account for a given user with specified roles.
func AddAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AddAccountRequest, now time.Time) (*UserAccount, error) { func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req CreateUserAccountRequest, now time.Time) (*UserAccount, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.AddAccount") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Create")
defer span.Finish() defer span.Finish()
// Validate the request. // Validate the request.
@ -218,25 +256,25 @@ func AddAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Ad
now = now.Truncate(time.Millisecond) now = now.Truncate(time.Millisecond)
// Check to see if there is an existing user account, including archived. // Check to see if there is an existing user account, including archived.
existQuery := accountSelectQuery() existQuery := userAccountSelectQuery()
existQuery.Where(existQuery.And( existQuery.Where(existQuery.And(
existQuery.Equal("account_id", req.AccountID), existQuery.Equal("account_id", req.AccountID),
existQuery.Equal("user_id", req.UserID), existQuery.Equal("user_id", req.UserID),
)) ))
existing, err := findAccounts(ctx, claims, dbConn, existQuery, []interface{}{}, true) existing, err := find(ctx, claims, dbConn, existQuery, []interface{}{}, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// If there is an existing entry, then update instead of insert. // If there is an existing entry, then update instead of insert.
if len(existing) > 0 { if len(existing) > 0 {
upReq := UpdateAccountRequest{ upReq := UpdateUserAccountRequest{
UserID: req.UserID, UserID: req.UserID,
AccountID: req.AccountID, AccountID: req.AccountID,
Roles: &req.Roles, Roles: &req.Roles,
unArchive: true, unArchive: true,
} }
err = UpdateAccount(ctx, claims, dbConn, upReq, now) err = Update(ctx, claims, dbConn, upReq, now)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -265,7 +303,7 @@ func AddAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Ad
// Build the insert SQL statement. // Build the insert SQL statement.
query := sqlbuilder.NewInsertBuilder() query := sqlbuilder.NewInsertBuilder()
query.InsertInto(usersAccountsTableName) query.InsertInto(userAccountTableName)
query.Cols("id", "user_id", "account_id", "roles", "status", "created_at", "updated_at") query.Cols("id", "user_id", "account_id", "roles", "status", "created_at", "updated_at")
query.Values(ua.ID, ua.UserID, ua.AccountID, ua.Roles, ua.Status.String(), ua.CreatedAt, ua.UpdatedAt) query.Values(ua.ID, ua.UserID, ua.AccountID, ua.Roles, ua.Status.String(), ua.CreatedAt, ua.UpdatedAt)
@ -283,8 +321,8 @@ func AddAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Ad
} }
// UpdateAccount... // UpdateAccount...
func UpdateAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UpdateAccountRequest, now time.Time) error { func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UpdateUserAccountRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Update") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Update")
defer span.Finish() defer span.Finish()
// Validate the request. // Validate the request.
@ -313,7 +351,7 @@ func UpdateAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req
// Build the update SQL statement. // Build the update SQL statement.
query := sqlbuilder.NewUpdateBuilder() query := sqlbuilder.NewUpdateBuilder()
query.Update(usersAccountsTableName) query.Update(userAccountTableName)
fields := []string{} fields := []string{}
if req.Roles != nil { if req.Roles != nil {
@ -351,9 +389,9 @@ func UpdateAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req
return nil return nil
} }
// RemoveAccount soft deleted the user account from the database. // Archive soft deleted the user account from the database.
func RemoveAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req RemoveAccountRequest, now time.Time) error { func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ArchiveUserAccountRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.RemoveAccount") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Archive")
defer span.Finish() defer span.Finish()
// Validate the request. // Validate the request.
@ -382,7 +420,7 @@ func RemoveAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req
// Build the update SQL statement. // Build the update SQL statement.
query := sqlbuilder.NewUpdateBuilder() query := sqlbuilder.NewUpdateBuilder()
query.Update(usersAccountsTableName) query.Update(userAccountTableName)
query.Set(query.Assign("archived_at", now)) query.Set(query.Assign("archived_at", now))
query.Where(query.And( query.Where(query.And(
query.Equal("user_id", req.UserID), query.Equal("user_id", req.UserID),
@ -402,9 +440,9 @@ func RemoveAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req
return nil return nil
} }
// DeleteAccount removes a user account from the database. // Delete removes a user account from the database.
func DeleteAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req DeleteAccountRequest) error { func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req DeleteUserAccountRequest) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.RemoveAccount") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.Delete")
defer span.Finish() defer span.Finish()
// Validate the request. // Validate the request.
@ -421,7 +459,7 @@ func DeleteAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req
// Build the delete SQL statement. // Build the delete SQL statement.
query := sqlbuilder.NewDeleteBuilder() query := sqlbuilder.NewDeleteBuilder()
query.DeleteFrom(usersAccountsTableName) query.DeleteFrom(userAccountTableName)
query.Where(query.And( query.Where(query.And(
query.Equal("user_id", req.UserID), query.Equal("user_id", req.UserID),
query.Equal("account_id", req.AccountID), query.Equal("account_id", req.AccountID),

View File

@ -37,9 +37,9 @@ func TestAccountFindRequestQuery(t *testing.T) {
Limit: &limit, Limit: &limit,
Offset: &offset, Offset: &offset,
} }
expected := "SELECT " + usersAccountsMapColumns + " FROM " + usersAccountsTableName + " WHERE (account_id = ? or user_id = ?) ORDER BY id asc, created_at desc LIMIT 12 OFFSET 34" expected := "SELECT " + usersAccountsMapColumns + " FROM " + userAccountTableName + " WHERE (account_id = ? or user_id = ?) ORDER BY id asc, created_at desc LIMIT 12 OFFSET 34"
res, args := accountFindRequestQuery(req) res, args := userAccountFindRequestQuery(req)
if diff := cmp.Diff(res.String(), expected); diff != "" { if diff := cmp.Diff(res.String(), expected); diff != "" {
t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff) t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff)
@ -59,7 +59,7 @@ func TestApplyClaimsUserAccountSelect(t *testing.T) {
}{ }{
{"EmptyClaims", {"EmptyClaims",
auth.Claims{}, auth.Claims{},
"SELECT " + usersAccountsMapColumns + " FROM " + usersAccountsTableName, "SELECT " + usersAccountsMapColumns + " FROM " + userAccountTableName,
nil, nil,
}, },
{"RoleUser", {"RoleUser",
@ -70,7 +70,7 @@ func TestApplyClaimsUserAccountSelect(t *testing.T) {
Audience: "acc1", Audience: "acc1",
}, },
}, },
"SELECT " + usersAccountsMapColumns + " FROM " + usersAccountsTableName + " WHERE user_id IN (SELECT user_id FROM " + usersAccountsTableName + " WHERE (account_id = 'acc1' OR user_id = 'user1'))", "SELECT " + usersAccountsMapColumns + " FROM " + userAccountTableName + " WHERE user_id IN (SELECT user_id FROM " + userAccountTableName + " WHERE (account_id = 'acc1' OR user_id = 'user1'))",
nil, nil,
}, },
{"RoleAdmin", {"RoleAdmin",
@ -81,7 +81,7 @@ func TestApplyClaimsUserAccountSelect(t *testing.T) {
Audience: "acc1", Audience: "acc1",
}, },
}, },
"SELECT " + usersAccountsMapColumns + " FROM " + usersAccountsTableName + " WHERE user_id IN (SELECT user_id FROM " + usersAccountsTableName + " WHERE (account_id = 'acc1' OR user_id = 'user1'))", "SELECT " + usersAccountsMapColumns + " FROM " + userAccountTableName + " WHERE user_id IN (SELECT user_id FROM " + userAccountTableName + " WHERE (account_id = 'acc1' OR user_id = 'user1'))",
nil, nil,
}, },
} }
@ -93,7 +93,7 @@ func TestApplyClaimsUserAccountSelect(t *testing.T) {
{ {
ctx := tests.Context() ctx := tests.Context()
query := accountSelectQuery() query := userAccountSelectQuery()
err := applyClaimsUserAccountSelect(ctx, tt.claims, query) err := applyClaimsUserAccountSelect(ctx, tt.claims, query)
if err != tt.error { if err != tt.error {