2019-05-27 02:44:40 -05:00
package user
import (
"context"
2019-06-24 22:41:21 -08:00
"crypto/rsa"
2019-06-25 22:31:54 -08:00
"database/sql"
2019-07-13 12:16:28 -08:00
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web"
2019-06-24 22:41:21 -08:00
"github.com/dgrijalva/jwt-go"
2019-06-25 22:31:54 -08:00
"strings"
2019-05-27 02:44:40 -05:00
"time"
2019-07-13 12:16:28 -08:00
"geeks-accelerator/oss/saas-starter-kit/internal/platform/auth"
2019-05-27 02:44:40 -05:00
"github.com/huandu/go-sqlbuilder"
"github.com/jmoiron/sqlx"
2019-06-22 17:48:44 -08:00
"github.com/lib/pq"
2019-05-27 02:44:40 -05:00
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
)
2019-05-29 15:05:17 -05:00
// TokenGenerator is the behavior we need in our Authenticate to generate tokens for
// authenticated users.
2019-05-27 02:44:40 -05:00
type TokenGenerator interface {
GenerateToken ( auth . Claims ) ( string , error )
2019-05-29 15:05:17 -05:00
ParseClaims ( string ) ( auth . Claims , error )
2019-05-27 02:44:40 -05:00
}
2019-05-29 15:05:17 -05:00
// Authenticate finds a user by their email and verifies their password. On success
// it returns a Token that can be used to authenticate access to the application in
// the future.
2019-06-25 22:31:54 -08:00
func Authenticate ( ctx context . Context , dbConn * sqlx . DB , tknGen TokenGenerator , email , password string , expires time . Duration , now time . Time , scopes ... string ) ( Token , error ) {
2019-05-27 02:44:40 -05:00
span , ctx := tracer . StartSpanFromContext ( ctx , "internal.user.Authenticate" )
defer span . Finish ( )
2019-05-29 15:05:17 -05:00
// Generate sql query to select user by email address.
2019-05-27 02:44:40 -05:00
query := sqlbuilder . NewSelectBuilder ( )
query . Where ( query . Equal ( "email" , email ) )
2019-05-29 15:05:17 -05:00
// Run the find, use empty claims to bypass ACLs since this in an internal request
// and the current user is not authenticated at this point. If the email is
// invalid, return the same error as when an invalid password is supplied.
2019-05-28 04:44:01 -05:00
res , err := find ( ctx , auth . Claims { } , dbConn , query , [ ] interface { } { } , false )
2019-05-27 02:44:40 -05:00
if err != nil {
return Token { } , err
} else if res == nil || len ( res ) == 0 {
err = errors . WithStack ( ErrAuthenticationFailure )
return Token { } , err
}
u := res [ 0 ]
// Append the salt from the user record to the supplied password.
2019-05-29 15:05:17 -05:00
saltedPassword := password + u . PasswordSalt
2019-05-27 02:44:40 -05:00
2019-05-29 15:05:17 -05:00
// Compare the provided password with the saved hash. Use the bcrypt comparison
// function so it is cryptographically secure. Return authentication error for
// invalid password.
2019-05-27 02:44:40 -05:00
if err := bcrypt . CompareHashAndPassword ( u . PasswordHash , [ ] byte ( saltedPassword ) ) ; err != nil {
err = errors . WithStack ( ErrAuthenticationFailure )
return Token { } , err
}
2019-05-29 15:05:17 -05:00
// The user is successfully authenticated with the supplied email and password.
2019-06-25 22:31:54 -08:00
return generateToken ( ctx , dbConn , tknGen , auth . Claims { } , u . ID , "" , expires , now , scopes ... )
2019-05-29 15:05:17 -05:00
}
// Authenticate finds a user by their email and verifies their password. On success
// it returns a Token that can be used to authenticate access to the application in
// the future.
2019-06-25 22:31:54 -08:00
func SwitchAccount ( ctx context . Context , dbConn * sqlx . DB , tknGen TokenGenerator , claims auth . Claims , accountID string , expires time . Duration , now time . Time , scopes ... string ) ( Token , error ) {
2019-05-29 15:05:17 -05:00
span , ctx := tracer . StartSpanFromContext ( ctx , "internal.user.SwitchAccount" )
defer span . Finish ( )
// Defines struct to apply validation for the supplied claims and account ID.
req := struct {
2019-06-27 04:48:18 -08:00
UserID string ` json:"user_id" validate:"required,uuid" `
AccountID string ` json:"account_id" validate:"required,uuid" `
2019-05-29 15:05:17 -05:00
} {
UserID : claims . Subject ,
AccountID : accountID ,
}
// Validate the request.
2019-06-26 20:21:00 -08:00
v := web . NewValidator ( )
err := v . Struct ( req )
2019-05-29 15:05:17 -05:00
if err != nil {
return Token { } , err
}
// Generate a token for the user ID in supplied in claims as the Subject. Pass
// in the supplied claims as well to enforce ACLs when finding the current
// list of accounts for the user.
2019-06-25 22:31:54 -08:00
return generateToken ( ctx , dbConn , tknGen , claims , req . UserID , req . AccountID , expires , now , scopes ... )
2019-05-29 15:05:17 -05:00
}
// generateToken generates claims for the supplied user ID and account ID and then
// returns the token for the generated claims used for authentication.
2019-06-25 22:31:54 -08:00
func generateToken ( ctx context . Context , dbConn * sqlx . DB , tknGen TokenGenerator , claims auth . Claims , userID , accountID string , expires time . Duration , now time . Time , scopes ... string ) ( Token , error ) {
2019-06-22 17:48:44 -08:00
type userAccount struct {
AccountID string
Roles pq . StringArray
UserStatus string
UserArchived pq . NullTime
AccountStatus string
AccountArchived pq . NullTime
2019-06-25 22:31:54 -08:00
AccountTimezone sql . NullString
UserTimezone sql . NullString
2019-06-22 17:48:44 -08:00
}
// Build select statement for users_accounts table to find all the user accounts for the user
f := func ( ) ( [ ] userAccount , error ) {
2019-07-15 18:34:58 -08:00
query := sqlbuilder . NewSelectBuilder ( ) . Select ( "ua.account_id, ua.roles, ua.status as userStatus, ua.archived_at userArchived, a.status as accountStatus, a.archived_at, a.timezone, u.timezone as userTimezone" ) .
2019-06-22 17:48:44 -08:00
From ( userAccountTableName + " ua" ) .
2019-06-25 22:31:54 -08:00
Join ( accountTableName + " a" , "a.id = ua.account_id" ) .
Join ( userTableName + " u" , "u.id = ua.user_id" )
2019-06-22 17:48:44 -08:00
query . Where ( query . And (
query . Equal ( "ua.user_id" , userID ) ,
) )
query . OrderBy ( "ua.status, a.status, ua.created_at" )
// fetch all places from the db
queryStr , queryArgs := query . Build ( )
queryStr = dbConn . Rebind ( queryStr )
rows , err := dbConn . QueryContext ( ctx , queryStr , queryArgs ... )
if err != nil {
err = errors . Wrapf ( err , "query - %s" , query . String ( ) )
return nil , err
}
// iterate over each row
var resp [ ] userAccount
for rows . Next ( ) {
var ua userAccount
2019-06-25 22:31:54 -08:00
err = rows . Scan ( & ua . AccountID , & ua . Roles , & ua . UserStatus , & ua . UserArchived , & ua . AccountStatus , & ua . AccountArchived , & ua . AccountTimezone , & ua . UserTimezone )
2019-06-22 17:48:44 -08:00
if err != nil {
return nil , errors . WithStack ( err )
}
if err != nil {
err = errors . Wrapf ( err , "query - %s" , query . String ( ) )
return nil , err
}
resp = append ( resp , ua )
}
return resp , nil
}
accounts , err := f ( )
2019-05-27 02:44:40 -05:00
if err != nil {
2019-06-22 17:48:44 -08:00
err = errors . WithStack ( ErrAuthenticationFailure )
2019-05-27 02:44:40 -05:00
return Token { } , err
}
2019-06-22 17:48:44 -08:00
// Load the user account entry for the specified account ID. If none provided,
2019-05-29 15:05:17 -05:00
// choose the first.
2019-06-22 17:48:44 -08:00
var account userAccount
2019-05-29 15:05:17 -05:00
if accountID == "" {
2019-06-22 17:48:44 -08:00
// Try to choose the first active user account that has not been archived.
for _ , a := range accounts {
if a . AccountArchived . Valid && ! a . AccountArchived . Time . IsZero ( ) {
continue
} else if a . UserArchived . Valid && ! a . UserArchived . Time . IsZero ( ) {
continue
} else if a . AccountStatus != "active" {
continue
} else if a . UserStatus != "active" {
continue
}
account = accounts [ 0 ]
accountID = account . AccountID
break
}
2019-05-29 15:05:17 -05:00
// Select the first account associated with the user. For the login flow,
// users could be forced to select a specific account to override this.
2019-06-22 17:48:44 -08:00
if accountID == "" && len ( accounts ) > 0 {
2019-05-29 15:05:17 -05:00
account = accounts [ 0 ]
accountID = account . AccountID
}
} else {
// Loop through all the accounts found for the user and select the specified
// account.
for _ , a := range accounts {
if a . AccountID == accountID {
account = a
break
}
}
// If no matching entry was found for the specified account ID throw an error.
2019-06-22 17:48:44 -08:00
if account . AccountID == "" {
2019-05-29 15:05:17 -05:00
err = errors . WithStack ( ErrAuthenticationFailure )
return Token { } , err
}
}
2019-06-22 17:48:44 -08:00
// Validate the user account is completely active.
if account . AccountArchived . Valid && ! account . AccountArchived . Time . IsZero ( ) {
err = errors . WithMessage ( ErrAuthenticationFailure , "account is archived" )
return Token { } , err
} else if account . UserArchived . Valid && ! account . UserArchived . Time . IsZero ( ) {
err = errors . WithMessage ( ErrAuthenticationFailure , "user account is archived" )
return Token { } , err
} else if account . AccountStatus != "active" {
err = errors . WithMessagef ( ErrAuthenticationFailure , "account is not active with status of %s" , account . AccountStatus )
return Token { } , err
} else if account . UserStatus != "active" {
err = errors . WithMessagef ( ErrAuthenticationFailure , "user account is not active with status of %s" , account . UserStatus )
return Token { } , err
2019-05-27 02:44:40 -05:00
}
2019-05-29 15:05:17 -05:00
// Generate a list of all the account IDs associated with the user so the use
// has the ability to switch between accounts.
var accountIds [ ] string
2019-05-27 02:44:40 -05:00
for _ , a := range accounts {
2019-05-28 04:44:01 -05:00
accountIds = append ( accountIds , a . AccountID )
2019-05-27 02:44:40 -05:00
}
2019-06-25 22:31:54 -08:00
// Allow the scope to be defined for the claims. This enables testing via the API when a user has the role of admin
// and would like to limit their role to user.
var roles [ ] string
if len ( scopes ) > 0 && scopes [ 0 ] != "" {
// Parse scopes, handle when one value has a list of scopes
// separated by a space.
var scopeList [ ] string
for _ , vs := range scopes {
for _ , v := range strings . Split ( vs , " " ) {
v = strings . TrimSpace ( v )
if v == "" {
continue
}
scopeList = append ( scopeList , v )
}
}
for _ , s := range scopeList {
var scopeValid bool
for _ , r := range account . Roles {
if r == s || ( s == auth . RoleUser && r == auth . RoleAdmin ) {
scopeValid = true
break
}
}
if scopeValid {
roles = append ( roles , s )
} else {
err := errors . Errorf ( "invalid scope '%s'" , s )
return Token { } , err
}
}
} else {
roles = account . Roles
}
if len ( roles ) == 0 {
err := errors . New ( "no roles defined for user" )
return Token { } , err
}
// Set the timezone if one is specifically set on the user.
var tz * time . Location
if account . UserTimezone . Valid && account . UserTimezone . String != "" {
tz , _ = time . LoadLocation ( account . UserTimezone . String )
}
// If user timezone failed to parse or none is set, check the timezone set on the account.
if tz == nil && account . AccountTimezone . Valid && account . AccountTimezone . String != "" {
tz , _ = time . LoadLocation ( account . AccountTimezone . String )
}
2019-05-29 15:05:17 -05:00
// JWT claims requires both an audience and a subject. For this application:
// Subject: The ID of the user authenticated.
// Audience: The ID of the account the user is accessing. A list of account IDs
// will also be included to support the user switching between them.
2019-06-25 22:31:54 -08:00
claims = auth . NewClaims ( userID , accountID , accountIds , roles , tz , now , expires )
2019-05-27 02:44:40 -05:00
// Generate a token for the user with the defined claims.
2019-06-25 06:25:55 -08:00
tknStr , err := tknGen . GenerateToken ( claims )
2019-05-27 02:44:40 -05:00
if err != nil {
return Token { } , errors . Wrap ( err , "generating token" )
}
2019-06-25 06:25:55 -08:00
tkn := Token {
AccessToken : tknStr ,
TokenType : "Bearer" ,
claims : claims ,
}
if expires . Seconds ( ) > 0 {
tkn . Expiry = now . Add ( expires )
}
return tkn , nil
2019-05-27 02:44:40 -05:00
}
2019-06-24 22:41:21 -08:00
2019-06-26 01:16:57 -08:00
// AuthorizationHeader returns the value for an Authorization header: the access
// token prefixed with the "Bearer " scheme.
func (t Token) AuthorizationHeader() string {
	const scheme = "Bearer "
	return scheme + t.AccessToken
}
2019-06-24 22:41:21 -08:00
// mockTokenGenerator is used for testing that Authenticate calls its provided
// token generator in a specific way.
type MockTokenGenerator struct {
// Private key generated by GenerateToken that is need for ParseClaims
key * rsa . PrivateKey
// algorithm is the method used to generate the private key.
algorithm string
}
// GenerateToken implements the TokenGenerator interface. It signs the supplied
// claims as an RS256 JWT with a "kid" header of "1". The private key is generated
// on first use and then reused, so every token issued by g can later be verified
// by ParseClaims.
func (g *MockTokenGenerator) GenerateToken(claims auth.Claims) (string, error) {
	// Regenerating the key on every call would be slow and would invalidate all
	// tokens issued earlier, so only generate it once.
	if g.key == nil {
		privateKey, err := auth.KeyGen()
		if err != nil {
			return "", err
		}
		key, err := jwt.ParseRSAPrivateKeyFromPEM(privateKey)
		if err != nil {
			return "", err
		}
		g.key = key
	}
	g.algorithm = "RS256"

	method := jwt.GetSigningMethod(g.algorithm)
	tkn := jwt.NewWithClaims(method, claims)
	tkn.Header["kid"] = "1"

	str, err := tkn.SignedString(g.key)
	if err != nil {
		return "", err
	}
	return str, nil
}
// ParseClaims implements the TokenGenerator interface. It verifies that tknStr was
// signed with the key held by g, using g's algorithm, and returns the claims
// carried by the token. An error is returned if GenerateToken has not been called
// yet (no key), if the token cannot be parsed or verified, or if it is not valid.
func (g *MockTokenGenerator) ParseClaims(tknStr string) (auth.Claims, error) {
	if g.key == nil {
		return auth.Claims{}, errors.New("Private key is empty.")
	}

	parser := jwt.Parser{ValidMethods: []string{g.algorithm}}

	// Verify signatures with the public half of the generated key.
	keyFunc := func(*jwt.Token) (interface{}, error) {
		return &g.key.PublicKey, nil
	}

	var claims auth.Claims
	tkn, err := parser.ParseWithClaims(tknStr, &claims, keyFunc)
	switch {
	case err != nil:
		return auth.Claims{}, errors.Wrap(err, "parsing token")
	case !tkn.Valid:
		return auth.Claims{}, errors.New("Invalid token")
	}

	return claims, nil
}