mirror of
https://github.com/raseels-repos/golang-saas-starter-kit.git
synced 2025-06-08 23:56:37 +02:00
finish internal/user user and auth unittests
This commit is contained in:
parent
895128bbbe
commit
c121b7d289
@ -6,6 +6,7 @@ require (
|
|||||||
github.com/aws/aws-sdk-go v1.19.33
|
github.com/aws/aws-sdk-go v1.19.33
|
||||||
github.com/dgrijalva/jwt-go v3.2.0+incompatible
|
github.com/dgrijalva/jwt-go v3.2.0+incompatible
|
||||||
github.com/dimfeld/httptreemux v5.0.1+incompatible
|
github.com/dimfeld/httptreemux v5.0.1+incompatible
|
||||||
|
github.com/geeks-accelerator/sqlxmigrate v0.0.0-20190527223850-4a863a2d30db
|
||||||
github.com/gitwak/gondolier v0.0.0-20190521205431-504d297a6c42 // indirect
|
github.com/gitwak/gondolier v0.0.0-20190521205431-504d297a6c42 // indirect
|
||||||
github.com/gitwak/sqlxmigrate v0.0.0-20190527063335-e98d5d44fc0b
|
github.com/gitwak/sqlxmigrate v0.0.0-20190527063335-e98d5d44fc0b
|
||||||
github.com/go-playground/locales v0.12.1
|
github.com/go-playground/locales v0.12.1
|
||||||
|
@ -13,6 +13,8 @@ github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumC
|
|||||||
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
|
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
|
||||||
github.com/dimfeld/httptreemux v5.0.1+incompatible h1:Qj3gVcDNoOthBAqftuD596rm4wg/adLLz5xh5CmpiCA=
|
github.com/dimfeld/httptreemux v5.0.1+incompatible h1:Qj3gVcDNoOthBAqftuD596rm4wg/adLLz5xh5CmpiCA=
|
||||||
github.com/dimfeld/httptreemux v5.0.1+incompatible/go.mod h1:rbUlSV+CCpv/SuqUTP/8Bk2O3LyUV436/yaRGkhP6Z0=
|
github.com/dimfeld/httptreemux v5.0.1+incompatible/go.mod h1:rbUlSV+CCpv/SuqUTP/8Bk2O3LyUV436/yaRGkhP6Z0=
|
||||||
|
github.com/geeks-accelerator/sqlxmigrate v0.0.0-20190527223850-4a863a2d30db h1:mjErP7mTFHQ3cw/ibAkW3CvQ8gM4k19EkfzRzRINDAE=
|
||||||
|
github.com/geeks-accelerator/sqlxmigrate v0.0.0-20190527223850-4a863a2d30db/go.mod h1:dzpCjo4q7chhMVuHDzs/odROkieZ5Wjp70rNDuX83jU=
|
||||||
github.com/gitwak/gondolier v0.0.0-20190521205431-504d297a6c42 h1:+lo4HFeG6LlcgwvsvQC8H5FG8yr/kDn89E51BTw3loE=
|
github.com/gitwak/gondolier v0.0.0-20190521205431-504d297a6c42 h1:+lo4HFeG6LlcgwvsvQC8H5FG8yr/kDn89E51BTw3loE=
|
||||||
github.com/gitwak/gondolier v0.0.0-20190521205431-504d297a6c42/go.mod h1:ecEQ8e4eHeWKPf+g6ByatPM7l4QZgR3G5ZIZKvEAdCE=
|
github.com/gitwak/gondolier v0.0.0-20190521205431-504d297a6c42/go.mod h1:ecEQ8e4eHeWKPf+g6ByatPM7l4QZgR3G5ZIZKvEAdCE=
|
||||||
github.com/gitwak/sqlxmigrate v0.0.0-20190522211042-9625063dea5d h1:oaUPMY0F+lNUkyB5tzsQS3EC0m9Cxdglesp63i3UPso=
|
github.com/gitwak/sqlxmigrate v0.0.0-20190522211042-9625063dea5d h1:oaUPMY0F+lNUkyB5tzsQS3EC0m9Cxdglesp63i3UPso=
|
||||||
|
@ -185,7 +185,7 @@ func NewAuthenticator(awsSession *session.Session, awsSecretID string, now time.
|
|||||||
// refreshed on instance launch. Could store keys in a kv store and update that value
|
// refreshed on instance launch. Could store keys in a kv store and update that value
|
||||||
// when new keys are generated
|
// when new keys are generated
|
||||||
if len(keyContents) == 0 || curKeyId == "" {
|
if len(keyContents) == 0 || curKeyId == "" {
|
||||||
privateKey, err := keygen()
|
privateKey, err := Keygen()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "failed to generate new private key")
|
return nil, errors.Wrap(err, "failed to generate new private key")
|
||||||
}
|
}
|
||||||
@ -307,8 +307,8 @@ func (a *Authenticator) ParseClaims(tknStr string) (Claims, error) {
|
|||||||
return claims, nil
|
return claims, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// keygen creates an x509 private key for signing auth tokens.
|
// Keygen creates an x509 private key for signing auth tokens.
|
||||||
func keygen() ([]byte, error) {
|
func Keygen() ([]byte, error) {
|
||||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return []byte{}, errors.Wrap(err, "generating keys")
|
return []byte{}, errors.Wrap(err, "generating keys")
|
||||||
|
@ -57,8 +57,6 @@ func New() *Test {
|
|||||||
|
|
||||||
dbHost := fmt.Sprintf("postgres://%s:%s@127.0.0.1:%s/%s?timezone=UTC&sslmode=disable", container.User, container.Pass, container.Port, container.Database)
|
dbHost := fmt.Sprintf("postgres://%s:%s@127.0.0.1:%s/%s?timezone=UTC&sslmode=disable", container.User, container.Pass, container.Port, container.Database)
|
||||||
|
|
||||||
fmt.Println(dbHost)
|
|
||||||
|
|
||||||
// ============================================================
|
// ============================================================
|
||||||
// Start Postgres
|
// Start Postgres
|
||||||
|
|
||||||
|
@ -4,7 +4,7 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"log"
|
"log"
|
||||||
|
|
||||||
"github.com/gitwak/sqlxmigrate"
|
"github.com/geeks-accelerator/sqlxmigrate"
|
||||||
"github.com/jmoiron/sqlx"
|
"github.com/jmoiron/sqlx"
|
||||||
_ "github.com/lib/pq"
|
_ "github.com/lib/pq"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
@ -106,7 +106,7 @@ func migrationList(db *sqlx.DB, log *log.Logger) []*sqlxmigrate.Migration {
|
|||||||
{
|
{
|
||||||
ID: "20190522-01c",
|
ID: "20190522-01c",
|
||||||
Migrate: func(tx *sql.Tx) error {
|
Migrate: func(tx *sql.Tx) error {
|
||||||
q1 := `CREATE TYPE user_account_role_t as enum('admin', 'user')`
|
q1 := `CREATE TYPE user_account_role_t as enum('ADMIN', 'USER')`
|
||||||
if _, err := tx.Exec(q1); err != nil {
|
if _, err := tx.Exec(q1); err != nil {
|
||||||
return errors.WithMessagef(err, "Query failed %s", q1)
|
return errors.WithMessagef(err, "Query failed %s", q1)
|
||||||
}
|
}
|
||||||
|
@ -3,7 +3,7 @@ package schema
|
|||||||
import (
|
import (
|
||||||
"log"
|
"log"
|
||||||
|
|
||||||
"github.com/gitwak/sqlxmigrate"
|
"github.com/geeks-accelerator/sqlxmigrate"
|
||||||
"github.com/jmoiron/sqlx"
|
"github.com/jmoiron/sqlx"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -29,7 +29,7 @@ func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, n
|
|||||||
query.Where(query.Equal("email", email))
|
query.Where(query.Equal("email", email))
|
||||||
|
|
||||||
// Run the find, use empty claims to bypass ACLs
|
// Run the find, use empty claims to bypass ACLs
|
||||||
res, err := find(ctx, auth.Claims{}, dbConn, query, false)
|
res, err := find(ctx, auth.Claims{}, dbConn, query, []interface{}{}, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Token{}, err
|
return Token{}, err
|
||||||
} else if res == nil || len(res) == 0 {
|
} else if res == nil || len(res) == 0 {
|
||||||
@ -39,7 +39,7 @@ func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, n
|
|||||||
u := res[0]
|
u := res[0]
|
||||||
|
|
||||||
// Append the salt from the user record to the supplied password.
|
// Append the salt from the user record to the supplied password.
|
||||||
saltedPassword := password + string(u.PasswordHash)
|
saltedPassword := password + string(u.PasswordSalt)
|
||||||
|
|
||||||
// Compare the provided password with the saved hash. Use the bcrypt
|
// Compare the provided password with the saved hash. Use the bcrypt
|
||||||
// comparison function so it is cryptographically secure.
|
// comparison function so it is cryptographically secure.
|
||||||
@ -61,15 +61,17 @@ func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, n
|
|||||||
roles []string
|
roles []string
|
||||||
)
|
)
|
||||||
if len(accounts) > 0 {
|
if len(accounts) > 0 {
|
||||||
accountId = accounts[0].ID
|
accountId = accounts[0].AccountID
|
||||||
roles = accounts[0].Roles
|
for _, r := range accounts[0].Roles {
|
||||||
|
roles = append(roles, r.String())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate a list of all the account IDs associated with the user so
|
// Generate a list of all the account IDs associated with the user so
|
||||||
// the use has the ability to switch between accounts.
|
// the use has the ability to switch between accounts.
|
||||||
accountIds := []string{}
|
accountIds := []string{}
|
||||||
for _, a := range accounts {
|
for _, a := range accounts {
|
||||||
accountIds = append(accountIds, a.ID)
|
accountIds = append(accountIds, a.AccountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we are this far the request is valid. Create some claims for the user.
|
// If we are this far the request is valid. Create some claims for the user.
|
||||||
@ -81,5 +83,5 @@ func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, n
|
|||||||
return Token{}, errors.Wrap(err, "generating token")
|
return Token{}, errors.Wrap(err, "generating token")
|
||||||
}
|
}
|
||||||
|
|
||||||
return Token{Token: tkn}, nil
|
return Token{Token: tkn, claims: claims}, nil
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"gopkg.in/go-playground/validator.v9"
|
"gopkg.in/go-playground/validator.v9"
|
||||||
@ -34,8 +35,8 @@ type CreateUserRequest struct {
|
|||||||
Email string `json:"email" validate:"required,email,unique"`
|
Email string `json:"email" validate:"required,email,unique"`
|
||||||
Password string `json:"password" validate:"required"`
|
Password string `json:"password" validate:"required"`
|
||||||
PasswordConfirm string `json:"password_confirm" validate:"eqfield=Password"`
|
PasswordConfirm string `json:"password_confirm" validate:"eqfield=Password"`
|
||||||
Status *UserStatus `json:"status" validate:"oneof=active disabled"`
|
Status *UserStatus `json:"status" validate:"omitempty,oneof=active disabled"`
|
||||||
Timezone *string `json:"timezone"`
|
Timezone *string `json:"timezone" validate:"omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateUserRequest defines what information may be provided to modify an existing
|
// UpdateUserRequest defines what information may be provided to modify an existing
|
||||||
@ -46,10 +47,10 @@ type CreateUserRequest struct {
|
|||||||
// marshalling/unmarshalling.
|
// marshalling/unmarshalling.
|
||||||
type UpdateUserRequest struct {
|
type UpdateUserRequest struct {
|
||||||
ID string `validate:"required,uuid"`
|
ID string `validate:"required,uuid"`
|
||||||
Name *string `json:"name"`
|
Name *string `json:"name" validate:"omitempty"`
|
||||||
Email *string `json:"email" validate:"email,unique"`
|
Email *string `json:"email" validate:"omitempty,email,unique"`
|
||||||
Status *UserStatus `json:"status" validate:"oneof=active disabled"`
|
Status *UserStatus `json:"status" validate:"omitempty,oneof=active disabled"`
|
||||||
Timezone *string `json:"timezone"`
|
Timezone *string `json:"timezone" validate:"omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdatePassword defines what information may be provided to update user password.
|
// UpdatePassword defines what information may be provided to update user password.
|
||||||
@ -76,7 +77,7 @@ type UserAccount struct {
|
|||||||
ID string `db:"id" json:"id"`
|
ID string `db:"id" json:"id"`
|
||||||
UserID string `db:"user_id" json:"user_id"`
|
UserID string `db:"user_id" json:"user_id"`
|
||||||
AccountID string `db:"account_id" json:"account_id"`
|
AccountID string `db:"account_id" json:"account_id"`
|
||||||
Roles []string `db:"roles" json:"roles"`
|
Roles UserAccountRoles `db:"roles" json:"roles"`
|
||||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||||
ArchivedAt pq.NullTime `db:"archived_at" json:"archived_at"`
|
ArchivedAt pq.NullTime `db:"archived_at" json:"archived_at"`
|
||||||
@ -86,7 +87,7 @@ type UserAccount struct {
|
|||||||
type AddAccountRequest struct {
|
type AddAccountRequest struct {
|
||||||
UserID string `validate:"required,uuid"`
|
UserID string `validate:"required,uuid"`
|
||||||
AccountID string `validate:"required,uuid"`
|
AccountID string `validate:"required,uuid"`
|
||||||
Roles []string `json:"roles" validate:"oneof=admin user"`
|
Roles UserAccountRoles `json:"roles" validate:"required,dive,oneof=ADMIN USER"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateAccountRequest defines the information needed to update the roles for
|
// UpdateAccountRequest defines the information needed to update the roles for
|
||||||
@ -94,7 +95,7 @@ type AddAccountRequest struct {
|
|||||||
type UpdateAccountRequest struct {
|
type UpdateAccountRequest struct {
|
||||||
UserID string `validate:"required,uuid"`
|
UserID string `validate:"required,uuid"`
|
||||||
AccountID string `validate:"required,uuid"`
|
AccountID string `validate:"required,uuid"`
|
||||||
Roles []string `json:"roles" validate:"oneof=admin user"`
|
Roles UserAccountRoles `json:"roles" validate:"oneof=ADMIN USER"`
|
||||||
unArchive bool
|
unArchive bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -122,6 +123,9 @@ type UserAccountFindRequest struct {
|
|||||||
IncludedArchived bool
|
IncludedArchived bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UserStatus represents the status of a user.
|
||||||
|
type UserStatus string
|
||||||
|
|
||||||
// UserStatus values
|
// UserStatus values
|
||||||
const (
|
const (
|
||||||
UserStatus_Active UserStatus = "active"
|
UserStatus_Active UserStatus = "active"
|
||||||
@ -134,9 +138,6 @@ var UserStatus_Values = []UserStatus{
|
|||||||
UserStatus_Disabled,
|
UserStatus_Disabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
// UserStatus represents the status of a user.
|
|
||||||
type UserStatus string
|
|
||||||
|
|
||||||
// Scan supports reading the UserStatus value from the database.
|
// Scan supports reading the UserStatus value from the database.
|
||||||
func (s *UserStatus) Scan(value interface{}) error {
|
func (s *UserStatus) Scan(value interface{}) error {
|
||||||
asBytes, ok := value.([]byte)
|
asBytes, ok := value.([]byte)
|
||||||
@ -156,7 +157,6 @@ func (s UserStatus) Value() (driver.Value, error) {
|
|||||||
return nil, errs
|
return nil, errs
|
||||||
}
|
}
|
||||||
|
|
||||||
// validation would go here
|
|
||||||
return string(s), nil
|
return string(s), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -165,7 +165,61 @@ func (s UserStatus) String() string {
|
|||||||
return string(s)
|
return string(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UserAccountRole represents the role of a user for an account.
|
||||||
|
type UserAccountRole string
|
||||||
|
|
||||||
|
// UserAccountRole values
|
||||||
|
const (
|
||||||
|
UserAccountRole_Admin UserAccountRole = auth.RoleAdmin
|
||||||
|
UserAccountRole_User UserAccountRole = auth.RoleUser
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserAccountRole_Values provides list of valid UserAccountRole values
|
||||||
|
var UserAccountRole_Values = []UserAccountRole{
|
||||||
|
UserAccountRole_Admin,
|
||||||
|
UserAccountRole_User,
|
||||||
|
}
|
||||||
|
|
||||||
|
// String converts the UserAccountRole value to a string.
|
||||||
|
func (s UserAccountRole) String() string {
|
||||||
|
return string(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserAccountRoles represents a set of roles for a user for an account.
|
||||||
|
type UserAccountRoles []UserAccountRole
|
||||||
|
|
||||||
|
// Scan supports reading the UserAccountRole value from the database.
|
||||||
|
func (s *UserAccountRoles) Scan(value interface{}) error {
|
||||||
|
arr := &pq.StringArray{}
|
||||||
|
if err := arr.Scan(value); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range *arr {
|
||||||
|
*s = append(*s, UserAccountRole(v))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value converts the UserAccountRole value to be stored in the database.
|
||||||
|
func (s UserAccountRoles) Value() (driver.Value, error) {
|
||||||
|
v := validator.New()
|
||||||
|
|
||||||
|
var arr pq.StringArray
|
||||||
|
for _, r := range s {
|
||||||
|
errs := v.Var(r, "required,oneof=ADMIN USER")
|
||||||
|
if errs != nil {
|
||||||
|
return nil, errs
|
||||||
|
}
|
||||||
|
arr = append(arr, r.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
return arr.Value()
|
||||||
|
}
|
||||||
|
|
||||||
// Token is the payload we deliver to users when they authenticate.
|
// Token is the payload we deliver to users when they authenticate.
|
||||||
type Token struct {
|
type Token struct {
|
||||||
Token string `json:"token"`
|
Token string `json:"token"`
|
||||||
|
claims auth.Claims `json:"-"`
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,7 @@ package user
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
|
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
|
||||||
@ -43,7 +44,7 @@ func mapRowsToUser(rows *sql.Rows) (*User, error) {
|
|||||||
u User
|
u User
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
err = rows.Scan(&u.ID, &u.Email, &u.PasswordSalt, &u.PasswordHash, &u.PasswordReset, &u.Status, &u.Timezone, &u.CreatedAt, &u.UpdatedAt, &u.ArchivedAt)
|
err = rows.Scan(&u.ID, &u.Name, &u.Email, &u.PasswordSalt, &u.PasswordHash, &u.PasswordReset, &u.Status, &u.Timezone, &u.CreatedAt, &u.UpdatedAt, &u.ArchivedAt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.WithStack(err)
|
return nil, errors.WithStack(err)
|
||||||
}
|
}
|
||||||
@ -59,18 +60,17 @@ func CanReadUserId(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, use
|
|||||||
// When the claims Subject - UserId - does not match the requested user, the
|
// When the claims Subject - UserId - does not match the requested user, the
|
||||||
// claims audience - AccountId - should have a record.
|
// claims audience - AccountId - should have a record.
|
||||||
if claims.Subject != userID {
|
if claims.Subject != userID {
|
||||||
|
|
||||||
query := sqlbuilder.NewSelectBuilder().Select("id").From(usersAccountsTableName)
|
query := sqlbuilder.NewSelectBuilder().Select("id").From(usersAccountsTableName)
|
||||||
query.Where(query.And(
|
query.Where(query.Or(
|
||||||
query.Equal("account_id", claims.Audience),
|
query.Equal("account_id", claims.Audience),
|
||||||
query.Equal("user_id", userID),
|
query.Equal("user_id", userID),
|
||||||
))
|
))
|
||||||
sql, args := query.Build()
|
queryStr, args := query.Build()
|
||||||
sql = dbConn.Rebind(sql)
|
queryStr = dbConn.Rebind(queryStr)
|
||||||
|
|
||||||
var userAccountId string
|
var userAccountId string
|
||||||
err := dbConn.QueryRowContext(ctx, sql, args...).Scan(&userAccountId)
|
err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&userAccountId)
|
||||||
if err != nil {
|
if err != nil && err != sql.ErrNoRows {
|
||||||
err = errors.Wrapf(err, "query - %s", query.String())
|
err = errors.Wrapf(err, "query - %s", query.String())
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -130,7 +130,7 @@ func applyClaimsUserSelect(ctx context.Context, claims auth.Claims, query *sqlbu
|
|||||||
if claims.Subject != "" {
|
if claims.Subject != "" {
|
||||||
or = append(or, subQuery.Equal("user_id", claims.Subject))
|
or = append(or, subQuery.Equal("user_id", claims.Subject))
|
||||||
}
|
}
|
||||||
subQuery.Where(or...)
|
subQuery.Where(subQuery.Or(or...))
|
||||||
|
|
||||||
// Append sub query
|
// Append sub query
|
||||||
query.Where(query.In("id", subQuery))
|
query.Where(query.In("id", subQuery))
|
||||||
@ -147,10 +147,12 @@ func selectQuery() *sqlbuilder.SelectBuilder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// userFindRequestQuery generates the select query for the given find request.
|
// userFindRequestQuery generates the select query for the given find request.
|
||||||
func userFindRequestQuery(req UserFindRequest) *sqlbuilder.SelectBuilder {
|
// TODO: Need to figure out why can't parse the args when appending the where
|
||||||
|
// to the query.
|
||||||
|
func userFindRequestQuery(req UserFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) {
|
||||||
query := selectQuery()
|
query := selectQuery()
|
||||||
if req.Where != nil {
|
if req.Where != nil {
|
||||||
query.Where(*req.Where)
|
query.Where(query.And(*req.Where))
|
||||||
}
|
}
|
||||||
if len(req.Order) > 0 {
|
if len(req.Order) > 0 {
|
||||||
query.OrderBy(req.Order...)
|
query.OrderBy(req.Order...)
|
||||||
@ -162,90 +164,17 @@ func userFindRequestQuery(req UserFindRequest) *sqlbuilder.SelectBuilder {
|
|||||||
query.Offset(int(*req.Offset))
|
query.Offset(int(*req.Offset))
|
||||||
}
|
}
|
||||||
|
|
||||||
b := sqlbuilder.Buildf(query.String(), req.Args...)
|
return query, req.Args
|
||||||
query.BuilderAs(b, usersTableName)
|
|
||||||
|
|
||||||
return query
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// List enables streaming retrieval of Users from the database. The query results
|
// Find gets all the users from the database based on the request params.
|
||||||
// will be written to the interface{} resultReceiver channel enabling processing the results while
|
|
||||||
// they're still being fetched. After all pages have been processed the channel is closed
|
|
||||||
// Possible types sent to the channel are limited to:
|
|
||||||
// - error
|
|
||||||
// - User
|
|
||||||
//
|
|
||||||
// rr := make(chan interface{})
|
|
||||||
//
|
|
||||||
// go List(ctx, claims, db, rr)
|
|
||||||
//
|
|
||||||
// for r := range rr {
|
|
||||||
// switch v := r.(type) {
|
|
||||||
// case User:
|
|
||||||
// // v is of type User
|
|
||||||
// // process the user here
|
|
||||||
// case error:
|
|
||||||
// // v is of type error
|
|
||||||
// // handle the error here
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
func List(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindRequest, results chan<- interface{}) {
|
|
||||||
query := userFindRequestQuery(req)
|
|
||||||
list(ctx, claims, dbConn, query, req.IncludedArchived, results)
|
|
||||||
}
|
|
||||||
|
|
||||||
// List enables streaming retrieval of Users from the database for the supplied query.
|
|
||||||
func list(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, includedArchived bool, results chan<- interface{}) {
|
|
||||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.List")
|
|
||||||
defer span.Finish()
|
|
||||||
|
|
||||||
// Close the channel on complete
|
|
||||||
defer close(results)
|
|
||||||
|
|
||||||
query.Select(usersMapColumns)
|
|
||||||
query.From(usersTableName)
|
|
||||||
|
|
||||||
if !includedArchived {
|
|
||||||
query.Where(query.IsNull("archived_at"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check to see if a sub query needs to be applied for the claims
|
|
||||||
err := applyClaimsUserSelect(ctx, claims, query)
|
|
||||||
if err != nil {
|
|
||||||
results <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
sql, args := query.Build()
|
|
||||||
sql = dbConn.Rebind(sql)
|
|
||||||
|
|
||||||
// fetch all places from the db
|
|
||||||
rows, err := dbConn.QueryContext(ctx, sql, args...)
|
|
||||||
if err != nil {
|
|
||||||
err = errors.Wrapf(err, "query - %s", query.String())
|
|
||||||
results <- errors.WithMessage(err, "list users failed")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// iterate over each row
|
|
||||||
for rows.Next() {
|
|
||||||
u, err := mapRowsToUser(rows)
|
|
||||||
if err != nil {
|
|
||||||
err = errors.Wrapf(err, "query - %s", query.String())
|
|
||||||
results <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
results <- u
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find gets all the users from the database based on the request params
|
|
||||||
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindRequest) ([]*User, error) {
|
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindRequest) ([]*User, error) {
|
||||||
query := userFindRequestQuery(req)
|
query, args := userFindRequestQuery(req)
|
||||||
return find(ctx, claims, dbConn, query, req.IncludedArchived)
|
return find(ctx, claims, dbConn, query, args, req.IncludedArchived)
|
||||||
}
|
}
|
||||||
|
|
||||||
// find gets all the users from the database based on the query
|
// find internal method for getting all the users from the database using a select query.
|
||||||
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, includedArchived bool) ([]*User, error) {
|
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*User, error) {
|
||||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Find")
|
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Find")
|
||||||
defer span.Finish()
|
defer span.Finish()
|
||||||
|
|
||||||
@ -261,11 +190,15 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
sql, args := query.Build()
|
queryStr, queryArgs := query.Build()
|
||||||
sql = dbConn.Rebind(sql)
|
queryStr = dbConn.Rebind(queryStr)
|
||||||
|
args = append(args, queryArgs...)
|
||||||
|
|
||||||
|
fmt.Println(queryStr)
|
||||||
|
fmt.Println(args)
|
||||||
|
|
||||||
// fetch all places from the db
|
// fetch all places from the db
|
||||||
rows, err := dbConn.QueryContext(ctx, sql, args...)
|
rows, err := dbConn.QueryContext(ctx, queryStr, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = errors.Wrapf(err, "query - %s", query.String())
|
err = errors.Wrapf(err, "query - %s", query.String())
|
||||||
err = errors.WithMessage(err, "find users failed")
|
err = errors.WithMessage(err, "find users failed")
|
||||||
@ -283,6 +216,8 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu
|
|||||||
resp = append(resp, u)
|
resp = append(resp, u)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fmt.Println("len", len(resp))
|
||||||
|
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -295,7 +230,7 @@ func FindById(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id strin
|
|||||||
query := selectQuery()
|
query := selectQuery()
|
||||||
query.Where(query.Equal("id", id))
|
query.Where(query.Equal("id", id))
|
||||||
|
|
||||||
res, err := find(ctx, claims, dbConn, query, includedArchived)
|
res, err := find(ctx, claims, dbConn, query, []interface{}{}, includedArchived)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else if res == nil || len(res) == 0 {
|
} else if res == nil || len(res) == 0 {
|
||||||
@ -309,7 +244,6 @@ func FindById(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id strin
|
|||||||
|
|
||||||
// Validation an email address is unique excluding the current user ID.
|
// Validation an email address is unique excluding the current user ID.
|
||||||
func uniqueEmail(ctx context.Context, dbConn *sqlx.DB, email, userId string) (bool, error) {
|
func uniqueEmail(ctx context.Context, dbConn *sqlx.DB, email, userId string) (bool, error) {
|
||||||
|
|
||||||
query := sqlbuilder.NewSelectBuilder().Select("id").From(usersTableName)
|
query := sqlbuilder.NewSelectBuilder().Select("id").From(usersTableName)
|
||||||
query.Where(query.And(
|
query.Where(query.And(
|
||||||
query.Equal("email", email),
|
query.Equal("email", email),
|
||||||
|
@ -71,7 +71,9 @@ func accountSelectQuery() *sqlbuilder.SelectBuilder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// userFindRequestQuery generates the select query for the given find request.
|
// userFindRequestQuery generates the select query for the given find request.
|
||||||
func accountFindRequestQuery(req UserAccountFindRequest) *sqlbuilder.SelectBuilder {
|
// TODO: Need to figure out why can't parse the args when appending the where
|
||||||
|
// to the query.
|
||||||
|
func accountFindRequestQuery(req UserAccountFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) {
|
||||||
query := accountSelectQuery()
|
query := accountSelectQuery()
|
||||||
if req.Where != nil {
|
if req.Where != nil {
|
||||||
query.Where(*req.Where)
|
query.Where(*req.Where)
|
||||||
@ -89,17 +91,17 @@ func accountFindRequestQuery(req UserAccountFindRequest) *sqlbuilder.SelectBuild
|
|||||||
b := sqlbuilder.Buildf(query.String(), req.Args...)
|
b := sqlbuilder.Buildf(query.String(), req.Args...)
|
||||||
query.BuilderAs(b, usersAccountsMapColumns)
|
query.BuilderAs(b, usersAccountsMapColumns)
|
||||||
|
|
||||||
return query
|
return query, req.Args
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find gets all the users from the database based on the request params
|
// Find gets all the users from the database based on the request params
|
||||||
func FindAccounts(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountFindRequest) ([]*UserAccount, error) {
|
func FindAccounts(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountFindRequest) ([]*UserAccount, error) {
|
||||||
query := accountFindRequestQuery(req)
|
query, args := accountFindRequestQuery(req)
|
||||||
return findAccounts(ctx, claims, dbConn, query, req.IncludedArchived)
|
return findAccounts(ctx, claims, dbConn, query, args, req.IncludedArchived)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find gets all the users from the database based on the select query
|
// Find gets all the users from the database based on the select query
|
||||||
func findAccounts(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, includedArchived bool) ([]*UserAccount, error) {
|
func findAccounts(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*UserAccount, error) {
|
||||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.FindAccounts")
|
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.FindAccounts")
|
||||||
defer span.Finish()
|
defer span.Finish()
|
||||||
|
|
||||||
@ -115,11 +117,12 @@ func findAccounts(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, quer
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
sql, args := query.Build()
|
queryStr, queryArgs := query.Build()
|
||||||
sql = dbConn.Rebind(sql)
|
queryStr = dbConn.Rebind(queryStr)
|
||||||
|
args = append(args, queryArgs...)
|
||||||
|
|
||||||
// fetch all places from the db
|
// fetch all places from the db
|
||||||
rows, err := dbConn.QueryContext(ctx, sql, args...)
|
rows, err := dbConn.QueryContext(ctx, queryStr, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = errors.Wrapf(err, "query - %s", query.String())
|
err = errors.Wrapf(err, "query - %s", query.String())
|
||||||
err = errors.WithMessage(err, "find accounts failed")
|
err = errors.WithMessage(err, "find accounts failed")
|
||||||
@ -151,7 +154,7 @@ func FindAccountsByUserID(ctx context.Context, claims auth.Claims, dbConn *sqlx.
|
|||||||
query.OrderBy("id")
|
query.OrderBy("id")
|
||||||
|
|
||||||
// Execute the find accounts method.
|
// Execute the find accounts method.
|
||||||
res, err := findAccounts(ctx, claims, dbConn, query, includedArchived)
|
res, err := findAccounts(ctx, claims, dbConn, query, []interface{}{}, includedArchived)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -194,7 +197,7 @@ func AddAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Ad
|
|||||||
existQuery.Equal("account_id", req.AccountID),
|
existQuery.Equal("account_id", req.AccountID),
|
||||||
existQuery.Equal("user_id", req.UserID),
|
existQuery.Equal("user_id", req.UserID),
|
||||||
))
|
))
|
||||||
existing, err := findAccounts(ctx, claims, dbConn, existQuery, true)
|
existing, err := findAccounts(ctx, claims, dbConn, existQuery, []interface{}{}, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -217,7 +220,7 @@ func AddAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req Ad
|
|||||||
query := sqlbuilder.NewInsertBuilder()
|
query := sqlbuilder.NewInsertBuilder()
|
||||||
query.InsertInto(usersAccountsTableName)
|
query.InsertInto(usersAccountsTableName)
|
||||||
query.Cols("id", "user_id", "account_id", "roles", "created_at", "updated_at")
|
query.Cols("id", "user_id", "account_id", "roles", "created_at", "updated_at")
|
||||||
query.Values(1, id, req.UserID, req.AccountID, req.Roles, now, now)
|
query.Values(id, req.UserID, req.AccountID, req.Roles, now, now)
|
||||||
|
|
||||||
// Execute the query with the provided context.
|
// Execute the query with the provided context.
|
||||||
sql, args := query.Build()
|
sql, args := query.Build()
|
||||||
|
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user