You've already forked golang-saas-starter-kit
mirror of
https://github.com/raseels-repos/golang-saas-starter-kit.git
synced 2025-08-08 22:36:41 +02:00
Completed implimentation of forgot password
This commit is contained in:
271
internal/user_account/invite/invite.go
Normal file
271
internal/user_account/invite/invite.go
Normal file
@ -0,0 +1,271 @@
|
||||
package invite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/internal/account"
|
||||
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
|
||||
"geeks-accelerator/oss/saas-starter-kit/internal/platform/notify"
|
||||
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
|
||||
"geeks-accelerator/oss/saas-starter-kit/internal/user"
|
||||
"geeks-accelerator/oss/saas-starter-kit/internal/user_account"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sudo-suhas/symcrypto"
|
||||
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrInviteExpired occurs when the the reset hash exceeds the expiration.
|
||||
ErrInviteExpired = errors.New("Invite expired")
|
||||
|
||||
// ErrInviteUserPasswordSet occurs when the the reset hash exceeds the expiration.
|
||||
ErrInviteUserPasswordSet = errors.New("User password set")
|
||||
)
|
||||
|
||||
// InviteUsers sends emails to the users inviting them to join an account.
|
||||
func InviteUsers(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, resetUrl func(string) string, notify notify.Email, req InviteUsersRequest, secretKey string, now time.Time) ([]string, error) {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.invite.InviteUsers")
|
||||
defer span.Finish()
|
||||
|
||||
v := webcontext.Validator()
|
||||
|
||||
// Validate the request.
|
||||
err := v.StructCtx(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Ensure the claims can modify the account specified in the request.
|
||||
err = user_account.CanModifyAccount(ctx, claims, dbConn, req.AccountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Find all the users by email address.
|
||||
emailUserIDs := make(map[string]string)
|
||||
{
|
||||
// Find all users without passing in claims to search all users.
|
||||
where := fmt.Sprintf("email in ('%s')", strings.Join(req.Emails, "','"))
|
||||
users, err := user.Find(ctx, auth.Claims{}, dbConn, user.UserFindRequest{
|
||||
Where: &where,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, u := range users {
|
||||
emailUserIDs[u.Email] = u.ID
|
||||
}
|
||||
}
|
||||
|
||||
// Find users that are already active for this account.
|
||||
activelUserIDs := make(map[string]bool)
|
||||
{
|
||||
var args []string
|
||||
for _, userID := range emailUserIDs {
|
||||
args = append(args, userID)
|
||||
}
|
||||
|
||||
where := fmt.Sprintf("user_id in ('%s') and status = '%s'", strings.Join(args, "','"), user_account.UserAccountStatus_Active.String())
|
||||
userAccs, err := user_account.Find(ctx, claims, dbConn, user_account.UserAccountFindRequest{
|
||||
Where: &where,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, userAcc := range userAccs {
|
||||
activelUserIDs[userAcc.UserID] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Always store the time as UTC.
|
||||
now = now.UTC()
|
||||
|
||||
// Postgres truncates times to milliseconds when storing. We and do the same
|
||||
// here so the value we return is consistent with what we store.
|
||||
now = now.Truncate(time.Millisecond)
|
||||
|
||||
// Create any users that don't already exist.
|
||||
for _, email := range req.Emails {
|
||||
if uId, ok := emailUserIDs[email]; ok && uId != "" {
|
||||
continue
|
||||
}
|
||||
|
||||
u, err := user.CreateInvite(ctx, claims, dbConn, user.UserCreateInviteRequest{
|
||||
Email: email,
|
||||
}, now)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
emailUserIDs[email] = u.ID
|
||||
}
|
||||
|
||||
// Loop through all the existing users who either do not have an user_account record or
|
||||
// have an existing record, but the status is disabled.
|
||||
for _, userID := range emailUserIDs {
|
||||
// User already is active, skip.
|
||||
if activelUserIDs[userID] {
|
||||
continue
|
||||
}
|
||||
|
||||
status := user_account.UserAccountStatus_Invited
|
||||
_, err = user_account.Create(ctx, claims, dbConn, user_account.UserAccountCreateRequest{
|
||||
UserID: userID,
|
||||
AccountID: req.AccountID,
|
||||
Roles: req.Roles,
|
||||
Status: &status,
|
||||
}, now)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if req.TTL.Seconds() == 0 {
|
||||
req.TTL = time.Minute * 90
|
||||
}
|
||||
|
||||
fromUser, err := user.Read(ctx, claims, dbConn, req.UserID, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
account, err := account.Read(ctx, claims, dbConn, req.AccountID, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Load the current IP makings the request.
|
||||
var requestIp string
|
||||
if vals, _ := webcontext.ContextValues(ctx); vals != nil {
|
||||
requestIp = vals.RequestIP
|
||||
}
|
||||
|
||||
var inviteHashes []string
|
||||
for email, userID := range emailUserIDs {
|
||||
|
||||
// Generate a string that embeds additional information.
|
||||
hashPts := []string{
|
||||
userID,
|
||||
strconv.Itoa(int(now.UTC().Unix())),
|
||||
strconv.Itoa(int(now.UTC().Add(req.TTL).Unix())),
|
||||
requestIp,
|
||||
}
|
||||
hashStr := strings.Join(hashPts, "|")
|
||||
|
||||
// This returns the nonce appended with the encrypted string.
|
||||
crypto, err := symcrypto.New(secretKey)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
encrypted, err := crypto.Encrypt(hashStr)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"FromUser": fromUser.Response(ctx),
|
||||
"Account": account.Response(ctx),
|
||||
"Url": resetUrl(encrypted),
|
||||
"Minutes": req.TTL.Minutes(),
|
||||
}
|
||||
|
||||
subject := fmt.Sprintf("%s %s has invited you to %s", fromUser.FirstName, fromUser.LastName, account.Name)
|
||||
|
||||
err = notify.Send(ctx, email, subject, "user_invite", data)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Send invite to %s failed.", email)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inviteHashes = append(inviteHashes, encrypted)
|
||||
}
|
||||
|
||||
return inviteHashes, nil
|
||||
}
|
||||
|
||||
// InviteAccept updates the password for a user using the provided reset password ID.
|
||||
func InviteAccept(ctx context.Context, dbConn *sqlx.DB, req InviteAcceptRequest, secretKey string, now time.Time) error {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user_account.invite.InviteAccept")
|
||||
defer span.Finish()
|
||||
|
||||
v := webcontext.Validator()
|
||||
|
||||
// Validate the request.
|
||||
err := v.StructCtx(ctx, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
crypto, err := symcrypto.New(secretKey)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
hashStr, err := crypto.Decrypt(req.InviteHash)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
hashPts := strings.Split(hashStr, "|")
|
||||
|
||||
var hash InviteHash
|
||||
if len(hashPts) == 4 {
|
||||
hash.UserID = hashPts[0]
|
||||
hash.CreatedAt, _ = strconv.Atoi(hashPts[1])
|
||||
hash.ExpiresAt, _ = strconv.Atoi(hashPts[2])
|
||||
hash.RequestIP = hashPts[3]
|
||||
}
|
||||
|
||||
// Validate the hash.
|
||||
err = v.StructCtx(ctx, hash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if int64(hash.ExpiresAt) < now.UTC().Unix() {
|
||||
err = errors.WithMessage(ErrInviteExpired, "Invite has expired.")
|
||||
return err
|
||||
}
|
||||
|
||||
u, err := user.Read(ctx, auth.Claims{}, dbConn, hash.UserID, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if u.ArchivedAt != nil && !u.ArchivedAt.Time.IsZero() {
|
||||
err = user.Unarchive(ctx, auth.Claims{}, dbConn, user.UserUnarchiveRequest{ID: hash.UserID}, now)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if len(u.PasswordHash) > 0 {
|
||||
// Do not update the password for a user that already has a password set.
|
||||
err = errors.WithMessage(ErrInviteUserPasswordSet, "Invite user already has a password set.")
|
||||
return err
|
||||
}
|
||||
|
||||
err = user.Update(ctx, auth.Claims{}, dbConn, user.UserUpdateRequest{
|
||||
ID: hash.UserID,
|
||||
FirstName: &req.FirstName,
|
||||
LastName: &req.LastName,
|
||||
Timezone: req.Timezone,
|
||||
}, now)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = user.UpdatePassword(ctx, auth.Claims{}, dbConn, user.UserUpdatePasswordRequest{
|
||||
ID: hash.UserID,
|
||||
Password: req.Password,
|
||||
PasswordConfirm: req.PasswordConfirm,
|
||||
}, now)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
Reference in New Issue
Block a user