1
0
mirror of https://github.com/raseels-repos/golang-saas-starter-kit.git synced 2025-06-08 23:56:37 +02:00

finish internal/user user and auth unittests

This commit is contained in:
Lee Brown 2019-05-28 04:44:01 -05:00
parent 895128bbbe
commit c121b7d289
11 changed files with 913 additions and 471 deletions

View File

@ -6,6 +6,7 @@ require (
github.com/aws/aws-sdk-go v1.19.33 github.com/aws/aws-sdk-go v1.19.33
github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/dgrijalva/jwt-go v3.2.0+incompatible
github.com/dimfeld/httptreemux v5.0.1+incompatible github.com/dimfeld/httptreemux v5.0.1+incompatible
github.com/geeks-accelerator/sqlxmigrate v0.0.0-20190527223850-4a863a2d30db
github.com/gitwak/gondolier v0.0.0-20190521205431-504d297a6c42 // indirect github.com/gitwak/gondolier v0.0.0-20190521205431-504d297a6c42 // indirect
github.com/gitwak/sqlxmigrate v0.0.0-20190527063335-e98d5d44fc0b github.com/gitwak/sqlxmigrate v0.0.0-20190527063335-e98d5d44fc0b
github.com/go-playground/locales v0.12.1 github.com/go-playground/locales v0.12.1

View File

@ -13,6 +13,8 @@ github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumC
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
github.com/dimfeld/httptreemux v5.0.1+incompatible h1:Qj3gVcDNoOthBAqftuD596rm4wg/adLLz5xh5CmpiCA= github.com/dimfeld/httptreemux v5.0.1+incompatible h1:Qj3gVcDNoOthBAqftuD596rm4wg/adLLz5xh5CmpiCA=
github.com/dimfeld/httptreemux v5.0.1+incompatible/go.mod h1:rbUlSV+CCpv/SuqUTP/8Bk2O3LyUV436/yaRGkhP6Z0= github.com/dimfeld/httptreemux v5.0.1+incompatible/go.mod h1:rbUlSV+CCpv/SuqUTP/8Bk2O3LyUV436/yaRGkhP6Z0=
github.com/geeks-accelerator/sqlxmigrate v0.0.0-20190527223850-4a863a2d30db h1:mjErP7mTFHQ3cw/ibAkW3CvQ8gM4k19EkfzRzRINDAE=
github.com/geeks-accelerator/sqlxmigrate v0.0.0-20190527223850-4a863a2d30db/go.mod h1:dzpCjo4q7chhMVuHDzs/odROkieZ5Wjp70rNDuX83jU=
github.com/gitwak/gondolier v0.0.0-20190521205431-504d297a6c42 h1:+lo4HFeG6LlcgwvsvQC8H5FG8yr/kDn89E51BTw3loE= github.com/gitwak/gondolier v0.0.0-20190521205431-504d297a6c42 h1:+lo4HFeG6LlcgwvsvQC8H5FG8yr/kDn89E51BTw3loE=
github.com/gitwak/gondolier v0.0.0-20190521205431-504d297a6c42/go.mod h1:ecEQ8e4eHeWKPf+g6ByatPM7l4QZgR3G5ZIZKvEAdCE= github.com/gitwak/gondolier v0.0.0-20190521205431-504d297a6c42/go.mod h1:ecEQ8e4eHeWKPf+g6ByatPM7l4QZgR3G5ZIZKvEAdCE=
github.com/gitwak/sqlxmigrate v0.0.0-20190522211042-9625063dea5d h1:oaUPMY0F+lNUkyB5tzsQS3EC0m9Cxdglesp63i3UPso= github.com/gitwak/sqlxmigrate v0.0.0-20190522211042-9625063dea5d h1:oaUPMY0F+lNUkyB5tzsQS3EC0m9Cxdglesp63i3UPso=

View File

@ -185,7 +185,7 @@ func NewAuthenticator(awsSession *session.Session, awsSecretID string, now time.
// refreshed on instance launch. Could store keys in a kv store and update that value // refreshed on instance launch. Could store keys in a kv store and update that value
// when new keys are generated // when new keys are generated
if len(keyContents) == 0 || curKeyId == "" { if len(keyContents) == 0 || curKeyId == "" {
privateKey, err := keygen() privateKey, err := Keygen()
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to generate new private key") return nil, errors.Wrap(err, "failed to generate new private key")
} }
@ -307,8 +307,8 @@ func (a *Authenticator) ParseClaims(tknStr string) (Claims, error) {
return claims, nil return claims, nil
} }
// keygen creates an x509 private key for signing auth tokens. // Keygen creates an x509 private key for signing auth tokens.
func keygen() ([]byte, error) { func Keygen() ([]byte, error) {
key, err := rsa.GenerateKey(rand.Reader, 2048) key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil { if err != nil {
return []byte{}, errors.Wrap(err, "generating keys") return []byte{}, errors.Wrap(err, "generating keys")

View File

@ -57,8 +57,6 @@ func New() *Test {
dbHost := fmt.Sprintf("postgres://%s:%s@127.0.0.1:%s/%s?timezone=UTC&sslmode=disable", container.User, container.Pass, container.Port, container.Database) dbHost := fmt.Sprintf("postgres://%s:%s@127.0.0.1:%s/%s?timezone=UTC&sslmode=disable", container.User, container.Pass, container.Port, container.Database)
fmt.Println(dbHost)
// ============================================================ // ============================================================
// Start Postgres // Start Postgres

View File

@ -4,7 +4,7 @@ import (
"database/sql" "database/sql"
"log" "log"
"github.com/gitwak/sqlxmigrate" "github.com/geeks-accelerator/sqlxmigrate"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
_ "github.com/lib/pq" _ "github.com/lib/pq"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -106,7 +106,7 @@ func migrationList(db *sqlx.DB, log *log.Logger) []*sqlxmigrate.Migration {
{ {
ID: "20190522-01c", ID: "20190522-01c",
Migrate: func(tx *sql.Tx) error { Migrate: func(tx *sql.Tx) error {
q1 := `CREATE TYPE user_account_role_t as enum('admin', 'user')` q1 := `CREATE TYPE user_account_role_t as enum('ADMIN', 'USER')`
if _, err := tx.Exec(q1); err != nil { if _, err := tx.Exec(q1); err != nil {
return errors.WithMessagef(err, "Query failed %s", q1) return errors.WithMessagef(err, "Query failed %s", q1)
} }

View File

@ -3,7 +3,7 @@ package schema
import ( import (
"log" "log"
"github.com/gitwak/sqlxmigrate" "github.com/geeks-accelerator/sqlxmigrate"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
) )

View File

@ -29,7 +29,7 @@ func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, n
query.Where(query.Equal("email", email)) query.Where(query.Equal("email", email))
// Run the find, use empty claims to bypass ACLs // Run the find, use empty claims to bypass ACLs
res, err := find(ctx, auth.Claims{}, dbConn, query, false) res, err := find(ctx, auth.Claims{}, dbConn, query, []interface{}{}, false)
if err != nil { if err != nil {
return Token{}, err return Token{}, err
} else if res == nil || len(res) == 0 { } else if res == nil || len(res) == 0 {
@ -39,7 +39,7 @@ func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, n
u := res[0] u := res[0]
// Append the salt from the user record to the supplied password. // Append the salt from the user record to the supplied password.
saltedPassword := password + string(u.PasswordHash) saltedPassword := password + string(u.PasswordSalt)
// Compare the provided password with the saved hash. Use the bcrypt // Compare the provided password with the saved hash. Use the bcrypt
// comparison function so it is cryptographically secure. // comparison function so it is cryptographically secure.
@ -61,15 +61,17 @@ func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, n
roles []string roles []string
) )
if len(accounts) > 0 { if len(accounts) > 0 {
accountId = accounts[0].ID accountId = accounts[0].AccountID
roles = accounts[0].Roles for _, r := range accounts[0].Roles {
roles = append(roles, r.String())
}
} }
// Generate a list of all the account IDs associated with the user so // Generate a list of all the account IDs associated with the user so
// the use has the ability to switch between accounts. // the use has the ability to switch between accounts.
accountIds := []string{} accountIds := []string{}
for _, a := range accounts { for _, a := range accounts {
accountIds = append(accountIds, a.ID) accountIds = append(accountIds, a.AccountID)
} }
// If we are this far the request is valid. Create some claims for the user. // If we are this far the request is valid. Create some claims for the user.
@ -81,5 +83,5 @@ func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, n
return Token{}, errors.Wrap(err, "generating token") return Token{}, errors.Wrap(err, "generating token")
} }
return Token{Token: tkn}, nil return Token{Token: tkn, claims: claims}, nil
} }

View File

@ -5,6 +5,7 @@ import (
"database/sql/driver" "database/sql/driver"
"time" "time"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/pkg/errors" "github.com/pkg/errors"
"gopkg.in/go-playground/validator.v9" "gopkg.in/go-playground/validator.v9"
@ -34,8 +35,8 @@ type CreateUserRequest struct {
Email string `json:"email" validate:"required,email,unique"` Email string `json:"email" validate:"required,email,unique"`
Password string `json:"password" validate:"required"` Password string `json:"password" validate:"required"`
PasswordConfirm string `json:"password_confirm" validate:"eqfield=Password"` PasswordConfirm string `json:"password_confirm" validate:"eqfield=Password"`
Status *UserStatus `json:"status" validate:"oneof=active disabled"` Status *UserStatus `json:"status" validate:"omitempty,oneof=active disabled"`
Timezone *string `json:"timezone"` Timezone *string `json:"timezone" validate:"omitempty"`
} }
// UpdateUserRequest defines what information may be provided to modify an existing // UpdateUserRequest defines what information may be provided to modify an existing
@ -46,10 +47,10 @@ type CreateUserRequest struct {
// marshalling/unmarshalling. // marshalling/unmarshalling.
type UpdateUserRequest struct { type UpdateUserRequest struct {
ID string `validate:"required,uuid"` ID string `validate:"required,uuid"`
Name *string `json:"name"` Name *string `json:"name" validate:"omitempty"`
Email *string `json:"email" validate:"email,unique"` Email *string `json:"email" validate:"omitempty,email,unique"`
Status *UserStatus `json:"status" validate:"oneof=active disabled"` Status *UserStatus `json:"status" validate:"omitempty,oneof=active disabled"`
Timezone *string `json:"timezone"` Timezone *string `json:"timezone" validate:"omitempty"`
} }
// UpdatePassword defines what information may be provided to update user password. // UpdatePassword defines what information may be provided to update user password.
@ -76,7 +77,7 @@ type UserAccount struct {
ID string `db:"id" json:"id"` ID string `db:"id" json:"id"`
UserID string `db:"user_id" json:"user_id"` UserID string `db:"user_id" json:"user_id"`
AccountID string `db:"account_id" json:"account_id"` AccountID string `db:"account_id" json:"account_id"`
Roles []string `db:"roles" json:"roles"` Roles UserAccountRoles `db:"roles" json:"roles"`
CreatedAt time.Time `db:"created_at" json:"created_at"` CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"` UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
ArchivedAt pq.NullTime `db:"archived_at" json:"archived_at"` ArchivedAt pq.NullTime `db:"archived_at" json:"archived_at"`
@ -86,7 +87,7 @@ type UserAccount struct {
type AddAccountRequest struct { type AddAccountRequest struct {
UserID string `validate:"required,uuid"` UserID string `validate:"required,uuid"`
AccountID string `validate:"required,uuid"` AccountID string `validate:"required,uuid"`
Roles []string `json:"roles" validate:"oneof=admin user"` Roles UserAccountRoles `json:"roles" validate:"required,dive,oneof=ADMIN USER"`
} }
// UpdateAccountRequest defines the information needed to update the roles for // UpdateAccountRequest defines the information needed to update the roles for
@ -94,7 +95,7 @@ type AddAccountRequest struct {
type UpdateAccountRequest struct { type UpdateAccountRequest struct {
UserID string `validate:"required,uuid"` UserID string `validate:"required,uuid"`
AccountID string `validate:"required,uuid"` AccountID string `validate:"required,uuid"`
Roles []string `json:"roles" validate:"oneof=admin user"` Roles UserAccountRoles `json:"roles" validate:"oneof=ADMIN USER"`
unArchive bool unArchive bool
} }
@ -122,6 +123,9 @@ type UserAccountFindRequest struct {
IncludedArchived bool IncludedArchived bool
} }
// UserStatus represents the status of a user.
type UserStatus string
// UserStatus values // UserStatus values
const ( const (
UserStatus_Active UserStatus = "active" UserStatus_Active UserStatus = "active"
@ -134,9 +138,6 @@ var UserStatus_Values = []UserStatus{
UserStatus_Disabled, UserStatus_Disabled,
} }
// UserStatus represents the status of a user.
type UserStatus string
// Scan supports reading the UserStatus value from the database. // Scan supports reading the UserStatus value from the database.
func (s *UserStatus) Scan(value interface{}) error { func (s *UserStatus) Scan(value interface{}) error {
asBytes, ok := value.([]byte) asBytes, ok := value.([]byte)
@ -156,7 +157,6 @@ func (s UserStatus) Value() (driver.Value, error) {
return nil, errs return nil, errs
} }
// validation would go here
return string(s), nil return string(s), nil
} }
@ -165,7 +165,61 @@ func (s UserStatus) String() string {
return string(s) return string(s)
} }
// UserAccountRole represents the role of a user for an account.
type UserAccountRole string
// UserAccountRole values
const (
UserAccountRole_Admin UserAccountRole = auth.RoleAdmin
UserAccountRole_User UserAccountRole = auth.RoleUser
)
// UserAccountRole_Values provides list of valid UserAccountRole values
var UserAccountRole_Values = []UserAccountRole{
UserAccountRole_Admin,
UserAccountRole_User,
}
// String converts the UserAccountRole value to a string.
func (s UserAccountRole) String() string {
return string(s)
}
// UserAccountRoles represents a set of roles for a user for an account.
type UserAccountRoles []UserAccountRole
// Scan supports reading the UserAccountRole value from the database.
func (s *UserAccountRoles) Scan(value interface{}) error {
arr := &pq.StringArray{}
if err := arr.Scan(value); err != nil {
return err
}
for _, v := range *arr {
*s = append(*s, UserAccountRole(v))
}
return nil
}
// Value converts the UserAccountRole value to be stored in the database.
func (s UserAccountRoles) Value() (driver.Value, error) {
v := validator.New()
var arr pq.StringArray
for _, r := range s {
errs := v.Var(r, "required,oneof=ADMIN USER")
if errs != nil {
return nil, errs
}
arr = append(arr, r.String())
}
return arr.Value()
}
// Token is the payload we deliver to users when they authenticate. // Token is the payload we deliver to users when they authenticate.
type Token struct { type Token struct {
Token string `json:"token"` Token string `json:"token"`
claims auth.Claims `json:"-"`
} }

View File

@ -3,6 +3,7 @@ package user
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"time" "time"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth" "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
@ -43,7 +44,7 @@ func mapRowsToUser(rows *sql.Rows) (*User, error) {
u User u User
err error err error
) )
err = rows.Scan(&u.ID, &u.Email, &u.PasswordSalt, &u.PasswordHash, &u.PasswordReset, &u.Status, &u.Timezone, &u.CreatedAt, &u.UpdatedAt, &u.ArchivedAt) err = rows.Scan(&u.ID, &u.Name, &u.Email, &u.PasswordSalt, &u.PasswordHash, &u.PasswordReset, &u.Status, &u.Timezone, &u.CreatedAt, &u.UpdatedAt, &u.ArchivedAt)
if err != nil { if err != nil {
return nil, errors.WithStack(err) return nil, errors.WithStack(err)
} }
@ -59,18 +60,17 @@ func CanReadUserId(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, use
// When the claims Subject - UserId - does not match the requested user, the // When the claims Subject - UserId - does not match the requested user, the
// claims audience - AccountId - should have a record. // claims audience - AccountId - should have a record.
if claims.Subject != userID { if claims.Subject != userID {
query := sqlbuilder.NewSelectBuilder().Select("id").From(usersAccountsTableName) query := sqlbuilder.NewSelectBuilder().Select("id").From(usersAccountsTableName)
query.Where(query.And( query.Where(query.Or(
query.Equal("account_id", claims.Audience), query.Equal("account_id", claims.Audience),
query.Equal("user_id", userID), query.Equal("user_id", userID),
)) ))
sql, args := query.Build() queryStr, args := query.Build()
sql = dbConn.Rebind(sql) queryStr = dbConn.Rebind(queryStr)
var userAccountId string var userAccountId string
err := dbConn.QueryRowContext(ctx, sql, args...).Scan(&userAccountId) err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&userAccountId)
if err != nil { if err != nil && err != sql.ErrNoRows {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
return err return err
} }
@ -130,7 +130,7 @@ func applyClaimsUserSelect(ctx context.Context, claims auth.Claims, query *sqlbu
if claims.Subject != "" { if claims.Subject != "" {
or = append(or, subQuery.Equal("user_id", claims.Subject)) or = append(or, subQuery.Equal("user_id", claims.Subject))
} }
subQuery.Where(or...) subQuery.Where(subQuery.Or(or...))
// Append sub query // Append sub query
query.Where(query.In("id", subQuery)) query.Where(query.In("id", subQuery))
@ -147,10 +147,12 @@ func selectQuery() *sqlbuilder.SelectBuilder {
} }
// userFindRequestQuery generates the select query for the given find request. // userFindRequestQuery generates the select query for the given find request.
func userFindRequestQuery(req UserFindRequest) *sqlbuilder.SelectBuilder { // TODO: Need to figure out why can't parse the args when appending the where
// to the query.
func userFindRequestQuery(req UserFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) {
query := selectQuery() query := selectQuery()
if req.Where != nil { if req.Where != nil {
query.Where(*req.Where) query.Where(query.And(*req.Where))
} }
if len(req.Order) > 0 { if len(req.Order) > 0 {
query.OrderBy(req.Order...) query.OrderBy(req.Order...)
@ -162,90 +164,17 @@ func userFindRequestQuery(req UserFindRequest) *sqlbuilder.SelectBuilder {
query.Offset(int(*req.Offset)) query.Offset(int(*req.Offset))
} }
b := sqlbuilder.Buildf(query.String(), req.Args...) return query, req.Args
query.BuilderAs(b, usersTableName)
return query
} }
// List enables streaming retrieval of Users from the database. The query results // Find gets all the users from the database based on the request params.
// will be written to the interface{} resultReceiver channel enabling processing the results while
// they're still being fetched. After all pages have been processed the channel is closed
// Possible types sent to the channel are limited to:
// - error
// - User
//
// rr := make(chan interface{})
//
// go List(ctx, claims, db, rr)
//
// for r := range rr {
// switch v := r.(type) {
// case User:
// // v is of type User
// // process the user here
// case error:
// // v is of type error
// // handle the error here
// }
// }
func List(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindRequest, results chan<- interface{}) {
query := userFindRequestQuery(req)
list(ctx, claims, dbConn, query, req.IncludedArchived, results)
}
// List enables streaming retrieval of Users from the database for the supplied query.
func list(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, includedArchived bool, results chan<- interface{}) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.List")
defer span.Finish()
// Close the channel on complete
defer close(results)
query.Select(usersMapColumns)
query.From(usersTableName)
if !includedArchived {
query.Where(query.IsNull("archived_at"))
}
// Check to see if a sub query needs to be applied for the claims
err := applyClaimsUserSelect(ctx, claims, query)
if err != nil {
results <- err
return
}
sql, args := query.Build()
sql = dbConn.Rebind(sql)
// fetch all places from the db
rows, err := dbConn.QueryContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
results <- errors.WithMessage(err, "list users failed")
return
}
// iterate over each row
for rows.Next() {
u, err := mapRowsToUser(rows)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
results <- err
return
}
results <- u
}
}
// Find gets all the users from the database based on the request params
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindRequest) ([]*User, error) { func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindRequest) ([]*User, error) {
query := userFindRequestQuery(req) query, args := userFindRequestQuery(req)
return find(ctx, claims, dbConn, query, req.IncludedArchived) return find(ctx, claims, dbConn, query, args, req.IncludedArchived)
} }
// find gets all the users from the database based on the query // find internal method for getting all the users from the database using a select query.
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, includedArchived bool) ([]*User, error) { func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*User, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Find") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Find")
defer span.Finish() defer span.Finish()
@ -261,11 +190,15 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu
if err != nil { if err != nil {
return nil, err return nil, err
} }
sql, args := query.Build() queryStr, queryArgs := query.Build()
sql = dbConn.Rebind(sql) queryStr = dbConn.Rebind(queryStr)
args = append(args, queryArgs...)
fmt.Println(queryStr)
fmt.Println(args)
// fetch all places from the db // fetch all places from the db
rows, err := dbConn.QueryContext(ctx, sql, args...) rows, err := dbConn.QueryContext(ctx, queryStr, args...)
if err != nil { if err != nil {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "find users failed") err = errors.WithMessage(err, "find users failed")
@ -283,6 +216,8 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu
resp = append(resp, u) resp = append(resp, u)
} }
fmt.Println("len", len(resp))
return resp, nil return resp, nil
} }
@ -295,7 +230,7 @@ func FindById(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id strin
query := selectQuery() query := selectQuery()
query.Where(query.Equal("id", id)) query.Where(query.Equal("id", id))
res, err := find(ctx, claims, dbConn, query, includedArchived) res, err := find(ctx, claims, dbConn, query, []interface{}{}, includedArchived)
if err != nil { if err != nil {
return nil, err return nil, err
} else if res == nil || len(res) == 0 { } else if res == nil || len(res) == 0 {
@ -309,7 +244,6 @@ func FindById(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id strin
// Validation an email address is unique excluding the current user ID. // Validation an email address is unique excluding the current user ID.
func uniqueEmail(ctx context.Context, dbConn *sqlx.DB, email, userId string) (bool, error) { func uniqueEmail(ctx context.Context, dbConn *sqlx.DB, email, userId string) (bool, error) {
query := sqlbuilder.NewSelectBuilder().Select("id").From(usersTableName) query := sqlbuilder.NewSelectBuilder().Select("id").From(usersTableName)
query.Where(query.And( query.Where(query.And(
query.Equal("email", email), query.Equal("email", email),

View File

@ -71,7 +71,9 @@ func accountSelectQuery() *sqlbuilder.SelectBuilder {
} }
// userFindRequestQuery generates the select query for the given find request. // userFindRequestQuery generates the select query for the given find request.
func accountFindRequestQuery(req UserAccountFindRequest) *sqlbuilder.SelectBuilder { // TODO: Need to figure out why can't parse the args when appending the where
// to the query.
func accountFindRequestQuery(req UserAccountFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) {
query := accountSelectQuery() query := accountSelectQuery()
if req.Where != nil { if req.Where != nil {
query.Where(*req.Where) query.Where(*req.Where)
@ -89,17 +91,17 @@ func accountFindRequestQuery(req UserAccountFindRequest) *sqlbuilder.SelectBuild
b := sqlbuilder.Buildf(query.String(), req.Args...) b := sqlbuilder.Buildf(query.String(), req.Args...)
query.BuilderAs(b, usersAccountsMapColumns) query.BuilderAs(b, usersAccountsMapColumns)
return query return query, req.Args
} }
// Find gets all the users from the database based on the request params // Find gets all the users from the database based on the request params
func FindAccounts(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountFindRequest) ([]*UserAccount, error) { func FindAccounts(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountFindRequest) ([]*UserAccount, error) {
query := accountFindRequestQuery(req) query, args := accountFindRequestQuery(req)
return findAccounts(ctx, claims, dbConn, query, req.IncludedArchived) return findAccounts(ctx, claims, dbConn, query, args, req.IncludedArchived)
} }
// Find gets all the users from the database based on the select query // Find gets all the users from the database based on the select query
func findAccounts(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, includedArchived bool) ([]*UserAccount, error) { func findAccounts(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*UserAccount, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.FindAccounts") span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.FindAccounts")
defer span.Finish() defer span.Finish()
@ -115,11 +117,12 @@ func findAccounts(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, quer
if err != nil { if err != nil {
return nil, err return nil, err
} }
sql, args := query.Build() queryStr, queryArgs := query.Build()
sql = dbConn.Rebind(sql) queryStr = dbConn.Rebind(queryStr)
args = append(args, queryArgs...)
// fetch all places from the db // fetch all places from the db
rows, err := dbConn.QueryContext(ctx, sql, args...) rows, err := dbConn.QueryContext(ctx, queryStr, args...)
if err != nil { if err != nil {
err = errors.Wrapf(err, "query - %s", query.String()) err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "find accounts failed") err = errors.WithMessage(err, "find accounts failed")
@ -151,7 +154,7 @@ func FindAccountsByUserID(ctx context.Context, claims auth.Claims, dbConn *sqlx.
query.OrderBy("id") query.OrderBy("id")
// Execute the find accounts method. // Execute the find accounts method.
res, err := findAccounts(ctx, claims, dbConn, query, includedArchived) res, err := findAccounts(ctx, claims, dbConn, query, []interface{}{}, includedArchived)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -194,7 +197,7 @@ func AddAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Ad
existQuery.Equal("account_id", req.AccountID), existQuery.Equal("account_id", req.AccountID),
existQuery.Equal("user_id", req.UserID), existQuery.Equal("user_id", req.UserID),
)) ))
existing, err := findAccounts(ctx, claims, dbConn, existQuery, true) existing, err := findAccounts(ctx, claims, dbConn, existQuery, []interface{}{}, true)
if err != nil { if err != nil {
return err return err
} }
@ -217,7 +220,7 @@ func AddAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Ad
query := sqlbuilder.NewInsertBuilder() query := sqlbuilder.NewInsertBuilder()
query.InsertInto(usersAccountsTableName) query.InsertInto(usersAccountsTableName)
query.Cols("id", "user_id", "account_id", "roles", "created_at", "updated_at") query.Cols("id", "user_id", "account_id", "roles", "created_at", "updated_at")
query.Values(1, id, req.UserID, req.AccountID, req.Roles, now, now) query.Values(id, req.UserID, req.AccountID, req.Roles, now, now)
// Execute the query with the provided context. // Execute the query with the provided context.
sql, args := query.Build() sql, args := query.Build()

File diff suppressed because it is too large Load Diff