mirror of
https://github.com/raseels-repos/golang-saas-starter-kit.git
synced 2025-06-06 23:46:29 +02:00
204 lines
5.5 KiB
Go
package auth
|
|
|
|
import (
|
|
"crypto/rsa"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/dgrijalva/jwt-go"
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
// KeyFunc resolves a JWT key id (the "kid" header) to the RSA public key that
// verifies tokens signed under that id. An Authenticator cannot be built
// without one.
//
// Resolving keys by id is what makes private-key rotation workable: while a
// rotation is in progress, tokens signed with the outgoing key and tokens
// signed with the incoming key can both be verified, because every token
// names the key that signed it.
//
// The id-to-public-key mapping is commonly published through a public JWKS
// endpoint; see https://auth0.com/docs/jwks for more details.
type KeyFunc func(kid string) (*rsa.PublicKey, error)
|
|
|
|
// NewKeyFunc is a multiple implementation of KeyFunc that
|
|
// supports a map of keys.
|
|
func NewKeyFunc(keys map[string]*PrivateKey) KeyFunc {
|
|
return func(kid string) (*rsa.PublicKey, error) {
|
|
key, ok := keys[kid]
|
|
if !ok {
|
|
return nil, fmt.Errorf("unrecognized kid %q", kid)
|
|
}
|
|
return key.Public().(*rsa.PublicKey), nil
|
|
}
|
|
}
|
|
|
|
// Authenticator is used to authenticate clients. It can generate a token for a
// set of user claims and recreate the claims by parsing the token.
type Authenticator struct {
	// privateKey is the key GenerateToken signs new tokens with.
	privateKey *PrivateKey

	// keyID is written into the "kid" header of every token GenerateToken
	// issues, so a verifier can look up the matching public key.
	keyID string

	// algorithm names the JWT signing method GenerateToken uses.
	algorithm string

	// kf maps a token's kid to the public key ParseClaims verifies it with.
	kf KeyFunc

	// parser parses and verifies token strings in ParseClaims.
	parser *jwt.Parser

	// Storage is the key storage engine associated with this Authenticator.
	Storage Storage
}
|
|
|
|
// PrivateKey bundles an RSA private key with the key id (kid) it is published
// under and the name of the signing algorithm it is paired with. The embedded
// *rsa.PrivateKey exposes the RSA key's fields and methods directly.
type PrivateKey struct {
	*rsa.PrivateKey

	keyID, algorithm string
}
|
|
|
|
// NewAuthenticator creates an *Authenticator for use.
|
|
// key expiration is optional to filter out old keys
|
|
// It will error if:
|
|
// - The specified algorithm is unsupported.
|
|
// - No current private key exists.
|
|
func NewAuthenticator(storage Storage, now time.Time) (*Authenticator, error) {
|
|
|
|
// Lookup function to be used by the middleware to validate the kid and
|
|
// Return the associated public key.
|
|
publicKeyLookup := NewKeyFunc(storage.Keys())
|
|
|
|
// Validate the globally defined encryption algorithm is valid.
|
|
if jwt.GetSigningMethod(algorithm) == nil {
|
|
return nil, errors.Errorf("unknown algorithm %v", algorithm)
|
|
}
|
|
|
|
// Create the token parser to use. The algorithm used to sign the JWT must be
|
|
// validated to avoid a critical vulnerability:
|
|
// https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/
|
|
parser := jwt.Parser{
|
|
ValidMethods: []string{algorithm},
|
|
}
|
|
|
|
// Load the current key from the storage engine.
|
|
curKey := storage.Current()
|
|
if curKey == nil {
|
|
return nil, errors.New("Missing private key")
|
|
}
|
|
|
|
a := Authenticator{
|
|
privateKey: curKey,
|
|
keyID: curKey.keyID,
|
|
algorithm: algorithm,
|
|
kf: publicKeyLookup,
|
|
parser: &parser,
|
|
}
|
|
|
|
return &a, nil
|
|
}
|
|
|
|
// GenerateToken generates a signed JWT token string representing the user Claims.
|
|
func (a *Authenticator) GenerateToken(claims Claims) (string, error) {
|
|
method := jwt.GetSigningMethod(a.algorithm)
|
|
|
|
tkn := jwt.NewWithClaims(method, claims)
|
|
tkn.Header["kid"] = a.keyID
|
|
|
|
str, err := tkn.SignedString(a.privateKey.PrivateKey)
|
|
if err != nil {
|
|
return "", errors.Wrap(err, "signing token")
|
|
}
|
|
|
|
return str, nil
|
|
}
|
|
|
|
// ParseClaims recreates the Claims that were used to generate a token. It
|
|
// verifies that the token was signed using our key.
|
|
func (a *Authenticator) ParseClaims(tknStr string) (Claims, error) {
|
|
|
|
// f is a function that returns the public key for validating a token. We use
|
|
// the parsed (but unverified) token to find the key id. That ID is passed to
|
|
// our KeyFunc to find the public key to use for verification.
|
|
f := func(t *jwt.Token) (interface{}, error) {
|
|
kid, ok := t.Header["kid"]
|
|
if !ok {
|
|
return nil, errors.New("Missing key id (kid) in token header")
|
|
}
|
|
kidStr, ok := kid.(string)
|
|
if !ok {
|
|
return nil, errors.New("Token key id (kid) must be string")
|
|
}
|
|
|
|
return a.kf(kidStr)
|
|
}
|
|
|
|
var claims Claims
|
|
tkn, err := a.parser.ParseWithClaims(tknStr, &claims, f)
|
|
if err != nil {
|
|
return Claims{}, errors.Wrap(err, "parsing token")
|
|
}
|
|
|
|
if !tkn.Valid {
|
|
return Claims{}, errors.New("Invalid token")
|
|
}
|
|
|
|
return claims, nil
|
|
}
|
|
|
|
// MockTokenGenerator is used for testing that Authenticate calls its provided
// token generator in a specific way.
//
// It remembers only the key created by its most recent GenerateToken call.
// Tokens issued before a later GenerateToken call therefore no longer verify
// in ParseClaims.
type MockTokenGenerator struct {
	// key is the private key created by GenerateToken; ParseClaims needs it
	// to verify tokens.
	key *rsa.PrivateKey

	// algorithm is the JWT signing method used with key.
	algorithm string
}
|
|
|
|
// GenerateToken implements the TokenGenerator interface. It returns a "token"
|
|
// that includes some information about the claims it was passed.
|
|
func (g *MockTokenGenerator) GenerateToken(claims Claims) (string, error) {
|
|
privateKey, err := KeyGen()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
g.key, err = jwt.ParseRSAPrivateKeyFromPEM(privateKey)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
g.algorithm = "RS256"
|
|
method := jwt.GetSigningMethod(g.algorithm)
|
|
|
|
tkn := jwt.NewWithClaims(method, claims)
|
|
tkn.Header["kid"] = "1"
|
|
|
|
str, err := tkn.SignedString(g.key)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return str, nil
|
|
}
|
|
|
|
// ParseClaims recreates the Claims that were used to generate a token. It
|
|
// verifies that the token was signed using our key.
|
|
func (g *MockTokenGenerator) ParseClaims(tknStr string) (Claims, error) {
|
|
parser := jwt.Parser{
|
|
ValidMethods: []string{g.algorithm},
|
|
}
|
|
|
|
if g.key == nil {
|
|
return Claims{}, errors.New("Private key is empty.")
|
|
}
|
|
|
|
f := func(t *jwt.Token) (interface{}, error) {
|
|
return g.key.Public().(*rsa.PublicKey), nil
|
|
}
|
|
|
|
var claims Claims
|
|
tkn, err := parser.ParseWithClaims(tknStr, &claims, f)
|
|
if err != nil {
|
|
return Claims{}, errors.Wrap(err, "parsing token")
|
|
}
|
|
|
|
if !tkn.Valid {
|
|
return Claims{}, errors.New("Invalid token")
|
|
}
|
|
|
|
return claims, nil
|
|
}
|