1
0
mirror of https://github.com/raseels-repos/golang-saas-starter-kit.git synced 2025-06-15 00:15:15 +02:00

finish internal/user user and auth unittests

This commit is contained in:
Lee Brown
2019-05-28 04:44:01 -05:00
parent 895128bbbe
commit c121b7d289
11 changed files with 913 additions and 471 deletions

View File

@ -3,6 +3,7 @@ package user
import (
"context"
"database/sql"
"fmt"
"time"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
@ -43,7 +44,7 @@ func mapRowsToUser(rows *sql.Rows) (*User, error) {
u User
err error
)
err = rows.Scan(&u.ID, &u.Email, &u.PasswordSalt, &u.PasswordHash, &u.PasswordReset, &u.Status, &u.Timezone, &u.CreatedAt, &u.UpdatedAt, &u.ArchivedAt)
err = rows.Scan(&u.ID, &u.Name, &u.Email, &u.PasswordSalt, &u.PasswordHash, &u.PasswordReset, &u.Status, &u.Timezone, &u.CreatedAt, &u.UpdatedAt, &u.ArchivedAt)
if err != nil {
return nil, errors.WithStack(err)
}
@ -59,18 +60,17 @@ func CanReadUserId(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, use
// When the claims Subject - UserId - does not match the requested user, the
// claims audience - AccountId - should have a record.
if claims.Subject != userID {
query := sqlbuilder.NewSelectBuilder().Select("id").From(usersAccountsTableName)
query.Where(query.And(
query.Where(query.Or(
query.Equal("account_id", claims.Audience),
query.Equal("user_id", userID),
))
sql, args := query.Build()
sql = dbConn.Rebind(sql)
queryStr, args := query.Build()
queryStr = dbConn.Rebind(queryStr)
var userAccountId string
err := dbConn.QueryRowContext(ctx, sql, args...).Scan(&userAccountId)
if err != nil {
err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&userAccountId)
if err != nil && err != sql.ErrNoRows {
err = errors.Wrapf(err, "query - %s", query.String())
return err
}
@ -130,7 +130,7 @@ func applyClaimsUserSelect(ctx context.Context, claims auth.Claims, query *sqlbu
if claims.Subject != "" {
or = append(or, subQuery.Equal("user_id", claims.Subject))
}
subQuery.Where(or...)
subQuery.Where(subQuery.Or(or...))
// Append sub query
query.Where(query.In("id", subQuery))
@ -147,10 +147,12 @@ func selectQuery() *sqlbuilder.SelectBuilder {
}
// userFindRequestQuery generates the select query for the given find request.
func userFindRequestQuery(req UserFindRequest) *sqlbuilder.SelectBuilder {
// TODO: Need to figure out why can't parse the args when appending the where
// to the query.
func userFindRequestQuery(req UserFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) {
query := selectQuery()
if req.Where != nil {
query.Where(*req.Where)
query.Where(query.And(*req.Where))
}
if len(req.Order) > 0 {
query.OrderBy(req.Order...)
@ -162,90 +164,17 @@ func userFindRequestQuery(req UserFindRequest) *sqlbuilder.SelectBuilder {
query.Offset(int(*req.Offset))
}
b := sqlbuilder.Buildf(query.String(), req.Args...)
query.BuilderAs(b, usersTableName)
return query
return query, req.Args
}
// List enables streaming retrieval of Users from the database. The query results
// will be written to the interface{} resultReceiver channel enabling processing the results while
// they're still being fetched. After all pages have been processed the channel is closed
// Possible types sent to the channel are limited to:
// - error
// - User
//
// rr := make(chan interface{})
//
// go List(ctx, claims, db, rr)
//
// for r := range rr {
// switch v := r.(type) {
// case User:
// // v is of type User
// // process the user here
// case error:
// // v is of type error
// // handle the error here
// }
// }
func List(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindRequest, results chan<- interface{}) {
query := userFindRequestQuery(req)
list(ctx, claims, dbConn, query, req.IncludedArchived, results)
}
// List enables streaming retrieval of Users from the database for the supplied query.
func list(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, includedArchived bool, results chan<- interface{}) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.List")
defer span.Finish()
// Close the channel on complete
defer close(results)
query.Select(usersMapColumns)
query.From(usersTableName)
if !includedArchived {
query.Where(query.IsNull("archived_at"))
}
// Check to see if a sub query needs to be applied for the claims
err := applyClaimsUserSelect(ctx, claims, query)
if err != nil {
results <- err
return
}
sql, args := query.Build()
sql = dbConn.Rebind(sql)
// fetch all places from the db
rows, err := dbConn.QueryContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
results <- errors.WithMessage(err, "list users failed")
return
}
// iterate over each row
for rows.Next() {
u, err := mapRowsToUser(rows)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
results <- err
return
}
results <- u
}
}
// Find gets all the users from the database based on the request params
// Find gets all the users from the database based on the request params.
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindRequest) ([]*User, error) {
query := userFindRequestQuery(req)
return find(ctx, claims, dbConn, query, req.IncludedArchived)
query, args := userFindRequestQuery(req)
return find(ctx, claims, dbConn, query, args, req.IncludedArchived)
}
// find gets all the users from the database based on the query
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, includedArchived bool) ([]*User, error) {
// find internal method for getting all the users from the database using a select query.
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*User, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Find")
defer span.Finish()
@ -261,11 +190,15 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu
if err != nil {
return nil, err
}
sql, args := query.Build()
sql = dbConn.Rebind(sql)
queryStr, queryArgs := query.Build()
queryStr = dbConn.Rebind(queryStr)
args = append(args, queryArgs...)
fmt.Println(queryStr)
fmt.Println(args)
// fetch all places from the db
rows, err := dbConn.QueryContext(ctx, sql, args...)
rows, err := dbConn.QueryContext(ctx, queryStr, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "find users failed")
@ -283,6 +216,8 @@ func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbu
resp = append(resp, u)
}
fmt.Println("len", len(resp))
return resp, nil
}
@ -295,7 +230,7 @@ func FindById(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id strin
query := selectQuery()
query.Where(query.Equal("id", id))
res, err := find(ctx, claims, dbConn, query, includedArchived)
res, err := find(ctx, claims, dbConn, query, []interface{}{}, includedArchived)
if err != nil {
return nil, err
} else if res == nil || len(res) == 0 {
@ -309,7 +244,6 @@ func FindById(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id strin
// Validation an email address is unique excluding the current user ID.
func uniqueEmail(ctx context.Context, dbConn *sqlx.DB, email, userId string) (bool, error) {
query := sqlbuilder.NewSelectBuilder().Select("id").From(usersTableName)
query.Where(query.And(
query.Equal("email", email),