You've already forked golang-saas-starter-kit
mirror of
https://github.com/raseels-repos/golang-saas-starter-kit.git
synced 2025-08-06 22:32:51 +02:00
moved example-project files back a directory
This commit is contained in:
139
internal/platform/auth/auth.go
Normal file
139
internal/platform/auth/auth.go
Normal file
@ -0,0 +1,139 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// KeyFunc is used to map a JWT key id (kid) to the corresponding public key.
|
||||
// It is a requirement for creating an Authenticator.
|
||||
//
|
||||
// * Private keys should be rotated. During the transition period, tokens
|
||||
// signed with the old and new keys can coexist by looking up the correct
|
||||
// public key by key id (kid).
|
||||
//
|
||||
// * Key-id-to-public-key resolution is usually accomplished via a public JWKS
|
||||
// endpoint. See https://auth0.com/docs/jwks for more details.
|
||||
type KeyFunc func(keyID string) (*rsa.PublicKey, error)
|
||||
|
||||
// NewKeyFunc is a multiple implementation of KeyFunc that
|
||||
// supports a map of keys.
|
||||
func NewKeyFunc(keys map[string]*PrivateKey) KeyFunc {
|
||||
return func(kid string) (*rsa.PublicKey, error) {
|
||||
key, ok := keys[kid]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unrecognized kid %q", kid)
|
||||
}
|
||||
return key.Public().(*rsa.PublicKey), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Authenticator is used to authenticate clients. It can generate a token for a
|
||||
// set of user claims and recreate the claims by parsing the token.
|
||||
type Authenticator struct {
|
||||
privateKey *PrivateKey
|
||||
keyID string
|
||||
algorithm string
|
||||
kf KeyFunc
|
||||
parser *jwt.Parser
|
||||
Storage Storage
|
||||
}
|
||||
|
||||
// PrivateKey is used to associate a private key with a keyID and algorithm.
|
||||
type PrivateKey struct {
|
||||
*rsa.PrivateKey
|
||||
keyID string
|
||||
algorithm string
|
||||
}
|
||||
|
||||
// NewAuthenticator creates an *Authenticator for use.
|
||||
// key expiration is optional to filter out old keys
|
||||
// It will error if:
|
||||
// - The specified algorithm is unsupported.
|
||||
// - No current private key exists.
|
||||
func NewAuthenticator(storage Storage, now time.Time) (*Authenticator, error) {
|
||||
|
||||
// Lookup function to be used by the middleware to validate the kid and
|
||||
// Return the associated public key.
|
||||
publicKeyLookup := NewKeyFunc(storage.Keys())
|
||||
|
||||
// Validate the globally defined encryption algorithm is valid.
|
||||
if jwt.GetSigningMethod(algorithm) == nil {
|
||||
return nil, errors.Errorf("unknown algorithm %v", algorithm)
|
||||
}
|
||||
|
||||
// Create the token parser to use. The algorithm used to sign the JWT must be
|
||||
// validated to avoid a critical vulnerability:
|
||||
// https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/
|
||||
parser := jwt.Parser{
|
||||
ValidMethods: []string{algorithm},
|
||||
}
|
||||
|
||||
// Load the current key from the storage engine.
|
||||
curKey := storage.Current()
|
||||
if curKey == nil {
|
||||
return nil, errors.New("Missing private key")
|
||||
}
|
||||
|
||||
a := Authenticator{
|
||||
privateKey: curKey,
|
||||
keyID: curKey.keyID,
|
||||
algorithm: algorithm,
|
||||
kf: publicKeyLookup,
|
||||
parser: &parser,
|
||||
}
|
||||
|
||||
return &a, nil
|
||||
}
|
||||
|
||||
// GenerateToken generates a signed JWT token string representing the user Claims.
|
||||
func (a *Authenticator) GenerateToken(claims Claims) (string, error) {
|
||||
method := jwt.GetSigningMethod(a.algorithm)
|
||||
|
||||
tkn := jwt.NewWithClaims(method, claims)
|
||||
tkn.Header["kid"] = a.keyID
|
||||
|
||||
str, err := tkn.SignedString(a.privateKey.PrivateKey)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "signing token")
|
||||
}
|
||||
|
||||
return str, nil
|
||||
}
|
||||
|
||||
// ParseClaims recreates the Claims that were used to generate a token. It
|
||||
// verifies that the token was signed using our key.
|
||||
func (a *Authenticator) ParseClaims(tknStr string) (Claims, error) {
|
||||
|
||||
// f is a function that returns the public key for validating a token. We use
|
||||
// the parsed (but unverified) token to find the key id. That ID is passed to
|
||||
// our KeyFunc to find the public key to use for verification.
|
||||
f := func(t *jwt.Token) (interface{}, error) {
|
||||
kid, ok := t.Header["kid"]
|
||||
if !ok {
|
||||
return nil, errors.New("Missing key id (kid) in token header")
|
||||
}
|
||||
kidStr, ok := kid.(string)
|
||||
if !ok {
|
||||
return nil, errors.New("Token key id (kid) must be string")
|
||||
}
|
||||
|
||||
return a.kf(kidStr)
|
||||
}
|
||||
|
||||
var claims Claims
|
||||
tkn, err := a.parser.ParseWithClaims(tknStr, &claims, f)
|
||||
if err != nil {
|
||||
return Claims{}, errors.Wrap(err, "parsing token")
|
||||
}
|
||||
|
||||
if !tkn.Valid {
|
||||
return Claims{}, errors.New("Invalid token")
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
165
internal/platform/auth/auth_test.go
Normal file
165
internal/platform/auth/auth_test.go
Normal file
@ -0,0 +1,165 @@
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/service/secretsmanager"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/tests"
|
||||
"github.com/pborman/uuid"
|
||||
)
|
||||
|
||||
var test *tests.Test
|
||||
|
||||
// TestMain is the entry point for testing.
|
||||
func TestMain(m *testing.M) {
|
||||
os.Exit(testMain(m))
|
||||
}
|
||||
|
||||
func testMain(m *testing.M) int {
|
||||
tests.DisableDb = true
|
||||
|
||||
test = tests.New()
|
||||
defer test.TearDown()
|
||||
|
||||
return m.Run()
|
||||
}
|
||||
|
||||
// TestAuthenticatorFile validates File storage.
|
||||
func TestAuthenticatorFile(t *testing.T) {
|
||||
|
||||
var authTests = []struct {
|
||||
name string
|
||||
now time.Time
|
||||
keyExpiration time.Duration
|
||||
error error
|
||||
}{
|
||||
{"NoKeyExpiration", time.Now(), time.Duration(0), nil},
|
||||
{"KeyExpirationOk", time.Now(), time.Duration(time.Second * 3600), nil},
|
||||
{"KeyExpirationDisabled", time.Now().Add(time.Second * 3600 * 3), time.Duration(time.Second * 3600), nil},
|
||||
}
|
||||
|
||||
// Generate the token.
|
||||
signedClaims := auth.Claims{
|
||||
Roles: []string{auth.RoleAdmin},
|
||||
}
|
||||
|
||||
t.Log("Given the need to validate initiating a new Authenticator using File storage by key expiration.")
|
||||
{
|
||||
for i, tt := range authTests {
|
||||
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
|
||||
{
|
||||
a, err := auth.NewAuthenticatorFile("", tt.now, tt.keyExpiration)
|
||||
if err != tt.error {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Log("\t\tWant:", tt.error)
|
||||
t.Fatalf("\t%s\tNewAuthenticatorFile failed.", tests.Failed)
|
||||
}
|
||||
|
||||
tknStr, err := a.GenerateToken(signedClaims)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tGenerateToken failed.", tests.Failed)
|
||||
}
|
||||
|
||||
parsedClaims, err := a.ParseClaims(tknStr)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tParseClaims failed.", tests.Failed)
|
||||
}
|
||||
|
||||
// Assert expected claims.
|
||||
if exp, got := len(signedClaims.Roles), len(parsedClaims.Roles); exp != got {
|
||||
t.Log("\t\tGot :", got)
|
||||
t.Log("\t\tWant:", exp)
|
||||
t.Fatalf("\t%s\tShould got the same number of roles.", tests.Failed)
|
||||
}
|
||||
if exp, got := signedClaims.Roles[0], parsedClaims.Roles[0]; exp != got {
|
||||
t.Log("\t\tGot :", got)
|
||||
t.Log("\t\tWant:", exp)
|
||||
t.Fatalf("\t%s\tShould got the same role name.", tests.Failed)
|
||||
}
|
||||
|
||||
t.Logf("\t%s\tNewAuthenticatorFile ok.", tests.Success)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthenticatorAws validates AWS storage.
|
||||
func TestAuthenticatorAws(t *testing.T) {
|
||||
|
||||
awsSecretID := "jwt-key" + uuid.NewRandom().String()
|
||||
|
||||
defer func() {
|
||||
// cleanup the secret after test is complete
|
||||
sm := secretsmanager.New(test.AwsSession)
|
||||
_, err := sm.DeleteSecret(&secretsmanager.DeleteSecretInput{
|
||||
SecretId: aws.String(awsSecretID),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
var authTests = []struct {
|
||||
name string
|
||||
awsSecretID string
|
||||
now time.Time
|
||||
keyExpiration time.Duration
|
||||
error error
|
||||
}{
|
||||
{"NoKeyExpiration", awsSecretID, time.Now(), time.Duration(0), nil},
|
||||
{"KeyExpirationOk", awsSecretID, time.Now(), time.Duration(time.Second * 3600), nil},
|
||||
{"KeyExpirationDisabled", awsSecretID, time.Now().Add(time.Second * 3600 * 3), time.Duration(time.Second * 3600), nil},
|
||||
}
|
||||
|
||||
// Generate the token.
|
||||
signedClaims := auth.Claims{
|
||||
Roles: []string{auth.RoleAdmin},
|
||||
}
|
||||
|
||||
t.Log("Given the need to validate initiating a new Authenticator using AWS storage by key expiration.")
|
||||
{
|
||||
for i, tt := range authTests {
|
||||
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
|
||||
{
|
||||
a, err := auth.NewAuthenticatorAws(test.AwsSession, tt.awsSecretID, tt.now, tt.keyExpiration)
|
||||
if err != tt.error {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Log("\t\tWant:", tt.error)
|
||||
t.Fatalf("\t%s\tNewAuthenticatorAws failed.", tests.Failed)
|
||||
}
|
||||
|
||||
tknStr, err := a.GenerateToken(signedClaims)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tGenerateToken failed.", tests.Failed)
|
||||
}
|
||||
|
||||
parsedClaims, err := a.ParseClaims(tknStr)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tParseClaims failed.", tests.Failed)
|
||||
}
|
||||
|
||||
// Assert expected claims.
|
||||
if exp, got := len(signedClaims.Roles), len(parsedClaims.Roles); exp != got {
|
||||
t.Log("\t\tGot :", got)
|
||||
t.Log("\t\tWant:", exp)
|
||||
t.Fatalf("\t%s\tShould got the same number of roles.", tests.Failed)
|
||||
}
|
||||
if exp, got := signedClaims.Roles[0], parsedClaims.Roles[0]; exp != got {
|
||||
t.Log("\t\tGot :", got)
|
||||
t.Log("\t\tWant:", exp)
|
||||
t.Fatalf("\t%s\tShould got the same role name.", tests.Failed)
|
||||
}
|
||||
|
||||
t.Logf("\t%s\tNewAuthenticatorAws ok.", tests.Success)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
86
internal/platform/auth/claims.go
Normal file
86
internal/platform/auth/claims.go
Normal file
@ -0,0 +1,86 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
jwt "github.com/dgrijalva/jwt-go"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// These are the expected values for Claims.Roles.
|
||||
const (
|
||||
RoleAdmin = "admin"
|
||||
RoleUser = "user"
|
||||
)
|
||||
|
||||
// ctxKey represents the type of value for the context key.
|
||||
type ctxKey int
|
||||
|
||||
// Key is used to store/retrieve a Claims value from a context.Context.
|
||||
const Key ctxKey = 1
|
||||
|
||||
// Claims represents the authorization claims transmitted via a JWT.
|
||||
type Claims struct {
|
||||
AccountIds []string `json:"accounts"`
|
||||
Roles []string `json:"roles"`
|
||||
Timezone string `json:"timezone"`
|
||||
tz *time.Location
|
||||
jwt.StandardClaims
|
||||
}
|
||||
|
||||
// NewClaims constructs a Claims value for the identified user. The Claims
|
||||
// expire within a specified duration of the provided time. Additional fields
|
||||
// of the Claims can be set after calling NewClaims is desired.
|
||||
func NewClaims(userId, accountId string, accountIds []string, roles []string, userTimezone *time.Location, now time.Time, expires time.Duration) Claims {
|
||||
c := Claims{
|
||||
AccountIds: accountIds,
|
||||
Roles: roles,
|
||||
StandardClaims: jwt.StandardClaims{
|
||||
Subject: userId,
|
||||
Audience: accountId,
|
||||
IssuedAt: now.Unix(),
|
||||
ExpiresAt: now.Add(expires).Unix(),
|
||||
},
|
||||
}
|
||||
|
||||
if userTimezone != nil {
|
||||
c.Timezone = userTimezone.String()
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// Valid is called during the parsing of a token.
|
||||
func (c Claims) Valid() error {
|
||||
for _, r := range c.Roles {
|
||||
switch r {
|
||||
case RoleAdmin, RoleUser: // Role is valid.
|
||||
default:
|
||||
return fmt.Errorf("invalid role %q", r)
|
||||
}
|
||||
}
|
||||
if err := c.StandardClaims.Valid(); err != nil {
|
||||
return errors.Wrap(err, "validating standard claims")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasRole returns true if the claims has at least one of the provided roles.
|
||||
func (c Claims) HasRole(roles ...string) bool {
|
||||
for _, has := range c.Roles {
|
||||
for _, want := range roles {
|
||||
if has == want {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c Claims) TimeLocation() *time.Location {
|
||||
if c.tz == nil && c.Timezone != "" {
|
||||
c.tz, _ = time.LoadLocation(c.Timezone)
|
||||
}
|
||||
return c.tz
|
||||
}
|
33
internal/platform/auth/key_gen.go
Normal file
33
internal/platform/auth/key_gen.go
Normal file
@ -0,0 +1,33 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Algorithm to be used to for the private key.
|
||||
const algorithm = "RS256"
|
||||
|
||||
// keyGen creates an x509 private key for signing auth tokens.
|
||||
func KeyGen() ([]byte, error) {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return []byte{}, errors.Wrap(err, "generating keys")
|
||||
}
|
||||
|
||||
block := pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(key),
|
||||
}
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
if err := pem.Encode(buf, &block); err != nil {
|
||||
return []byte{}, errors.Wrap(err, "encoding to private file")
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
73
internal/platform/auth/storage.go
Normal file
73
internal/platform/auth/storage.go
Normal file
@ -0,0 +1,73 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/pborman/uuid"
|
||||
"github.com/pkg/errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Storage provides the ability to persist keys to custom locations.
|
||||
type Storage interface {
|
||||
// Keys returns a map of private keys by kID.
|
||||
Keys() map[string]*PrivateKey
|
||||
// Current returns the most recently generated private key.
|
||||
Current() *PrivateKey
|
||||
}
|
||||
|
||||
// StorageMemory is a storage engine that stores a single private key in memory.
|
||||
type StorageMemory struct {
|
||||
privateKey *PrivateKey
|
||||
}
|
||||
|
||||
// Keys returns a map of private keys by kID.
|
||||
func (s *StorageMemory) Keys() map[string]*PrivateKey {
|
||||
if s == nil || s.privateKey == nil {
|
||||
return map[string]*PrivateKey{}
|
||||
}
|
||||
return map[string]*PrivateKey{
|
||||
s.privateKey.keyID: s.privateKey,
|
||||
}
|
||||
}
|
||||
|
||||
// Current returns the most recently generated private key.
|
||||
func (s *StorageMemory) Current() *PrivateKey {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return s.privateKey
|
||||
}
|
||||
|
||||
// NewAuthenticatorMemory is a help function that inits a new Authenticator with a single key stored in memory.
|
||||
func NewAuthenticatorMemory(now time.Time) (*Authenticator, error) {
|
||||
storage, err := NewStorageMemory()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewAuthenticator(storage, now)
|
||||
}
|
||||
|
||||
// NewStorageMemory implements the interface Storage to store a single key in memory.
|
||||
func NewStorageMemory() (*StorageMemory, error) {
|
||||
|
||||
privateKey, err := KeyGen()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to generate new private key")
|
||||
}
|
||||
|
||||
pk, err := jwt.ParseRSAPrivateKeyFromPEM(privateKey)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "parsing auth private key")
|
||||
}
|
||||
|
||||
storage := &StorageMemory{
|
||||
privateKey: &PrivateKey{
|
||||
PrivateKey: pk,
|
||||
keyID: uuid.NewRandom().String(),
|
||||
algorithm: algorithm,
|
||||
},
|
||||
}
|
||||
|
||||
return storage, nil
|
||||
}
|
236
internal/platform/auth/storage_aws.go
Normal file
236
internal/platform/auth/storage_aws.go
Normal file
@ -0,0 +1,236 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/service/secretsmanager"
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// StorageAws is a storage engine that uses AWS Secrets Manager to persist private keys.
|
||||
type StorageAws struct {
|
||||
keyExpiration time.Duration
|
||||
// Map of keys by kid (version id).
|
||||
keys map[string]*PrivateKey
|
||||
// The current active key to be used.
|
||||
curPrivateKey *PrivateKey
|
||||
}
|
||||
|
||||
// Keys returns a map of private keys by kID.
|
||||
func (s *StorageAws) Keys() map[string]*PrivateKey {
|
||||
if s == nil || s.keys == nil {
|
||||
return map[string]*PrivateKey{}
|
||||
}
|
||||
return s.keys
|
||||
}
|
||||
|
||||
// Current returns the most recently generated private key.
|
||||
func (s *StorageAws) Current() *PrivateKey {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return s.curPrivateKey
|
||||
}
|
||||
|
||||
// NewAuthenticatorAws is a help function that inits a new Authenticator
|
||||
// using the AWS storage.
|
||||
func NewAuthenticatorAws(awsSession *session.Session, awsSecretID string, now time.Time, keyExpiration time.Duration) (*Authenticator, error) {
|
||||
storage, err := NewStorageAws(awsSession, awsSecretID, now, keyExpiration)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewAuthenticator(storage, now)
|
||||
}
|
||||
|
||||
// NewStorageAws implements the interface Storage to support persisting private keys
|
||||
// to AWS Secrets Manager.
|
||||
// It will error if:
|
||||
// - The aws session is nil.
|
||||
// - The aws secret id is blank.
|
||||
func NewStorageAws(awsSession *session.Session, awsSecretID string, now time.Time, keyExpiration time.Duration) (*StorageAws, error) {
|
||||
if awsSession == nil {
|
||||
return nil, errors.New("aws session cannot be nil")
|
||||
}
|
||||
|
||||
if awsSecretID == "" {
|
||||
return nil, errors.New("aws secret id cannot be empty")
|
||||
}
|
||||
|
||||
storage := &StorageAws{
|
||||
keyExpiration: keyExpiration,
|
||||
keys: make(map[string]*PrivateKey),
|
||||
}
|
||||
|
||||
if now.IsZero() {
|
||||
now = time.Now().UTC()
|
||||
}
|
||||
|
||||
// Time threshold to stop loading keys, any key with a created date
|
||||
// before this value will not be loaded.
|
||||
var disabledCreatedDate time.Time
|
||||
|
||||
// Time threshold to create a new key. If a current key exists and the
|
||||
// created date of the key is before this value, a new key will be created.
|
||||
var activeCreatedDate time.Time
|
||||
|
||||
// If an expiration duration is included, convert to past time from now.
|
||||
if keyExpiration.Seconds() != 0 {
|
||||
// Ensure the expiration is a time in the past for comparison below.
|
||||
if keyExpiration.Seconds() > 0 {
|
||||
keyExpiration = keyExpiration * -1
|
||||
}
|
||||
// Stop loading keys when the created date exceeds two times the key expiration
|
||||
disabledCreatedDate = now.UTC().Add(keyExpiration * 2)
|
||||
|
||||
// Time used to determine when a new key should be created.
|
||||
activeCreatedDate = now.UTC().Add(keyExpiration)
|
||||
}
|
||||
|
||||
// Init new AWS Secret Manager using provided AWS session.
|
||||
secretManager := secretsmanager.New(awsSession)
|
||||
|
||||
// A List of version ids for the stored secret. All keys will be stored under
|
||||
// the same name in AWS secret manager. We still want to load old keys for a
|
||||
// short period of time to ensure any requests in flight have the opportunity
|
||||
// to be completed.
|
||||
var versionIds []string
|
||||
|
||||
// Exec call to AWS secret manager to return a list of version ids for the
|
||||
// provided secret ID.
|
||||
listParams := &secretsmanager.ListSecretVersionIdsInput{
|
||||
SecretId: aws.String(awsSecretID),
|
||||
}
|
||||
err := secretManager.ListSecretVersionIdsPages(listParams,
|
||||
func(page *secretsmanager.ListSecretVersionIdsOutput, lastPage bool) bool {
|
||||
for _, v := range page.Versions {
|
||||
// When disabled CreatedDate is not empty, compare the created date
|
||||
// for each key version to the disabled cut off time.
|
||||
if !disabledCreatedDate.IsZero() && v.CreatedDate != nil && !v.CreatedDate.IsZero() {
|
||||
// Skip any version ids that are less than the expiration time.
|
||||
if v.CreatedDate.UTC().Unix() < disabledCreatedDate.UTC().Unix() {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if v.VersionId != nil {
|
||||
versionIds = append(versionIds, *v.VersionId)
|
||||
}
|
||||
}
|
||||
return !lastPage
|
||||
},
|
||||
)
|
||||
|
||||
// Flag whether the secret exists and update needs to be used
|
||||
// instead of create.
|
||||
var awsSecretIDNotFound bool
|
||||
if err != nil {
|
||||
if aerr, ok := err.(awserr.Error); ok {
|
||||
switch aerr.Code() {
|
||||
case secretsmanager.ErrCodeResourceNotFoundException:
|
||||
awsSecretIDNotFound = true
|
||||
}
|
||||
}
|
||||
|
||||
if !awsSecretIDNotFound {
|
||||
return nil, errors.Wrapf(err, "aws list secret version ids for secret ID %s failed", awsSecretID)
|
||||
}
|
||||
}
|
||||
|
||||
// Map of keys stored by version id. version id is kid.
|
||||
keyContents := make(map[string][]byte)
|
||||
|
||||
// The current key id if there is an active one.
|
||||
var curKeyId string
|
||||
|
||||
// If the list of version ids is not empty, load the keys from secret manager.
|
||||
if len(versionIds) > 0 {
|
||||
// The max created data to determine the most recent key.
|
||||
var lastCreatedDate time.Time
|
||||
|
||||
for _, id := range versionIds {
|
||||
res, err := secretManager.GetSecretValue(&secretsmanager.GetSecretValueInput{
|
||||
SecretId: aws.String(awsSecretID),
|
||||
VersionId: aws.String(id),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "aws secret id %s, version id %s value failed", awsSecretID, id)
|
||||
}
|
||||
|
||||
if len(res.SecretBinary) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
keyContents[*res.VersionId] = res.SecretBinary
|
||||
|
||||
if lastCreatedDate.IsZero() || res.CreatedDate.UTC().Unix() > lastCreatedDate.UTC().Unix() {
|
||||
curKeyId = *res.VersionId
|
||||
lastCreatedDate = res.CreatedDate.UTC()
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
if !activeCreatedDate.IsZero() && lastCreatedDate.UTC().Unix() < activeCreatedDate.UTC().Unix() {
|
||||
curKeyId = ""
|
||||
}
|
||||
}
|
||||
|
||||
// If there are no keys stored in secret manager, create a new one or
|
||||
// if the current key needs to be rotated, generate a new key and update the secret.
|
||||
// @TODO: When a new key is generated and there are multiple instances of the service running
|
||||
// its possible based on the key expiration set that requests fail because keys are only
|
||||
// refreshed on instance launch. Could store keys in a kv store and update that value
|
||||
// when new keys are generated
|
||||
if len(keyContents) == 0 || curKeyId == "" {
|
||||
privateKey, err := KeyGen()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to generate new private key")
|
||||
}
|
||||
|
||||
if awsSecretIDNotFound {
|
||||
res, err := secretManager.CreateSecret(&secretsmanager.CreateSecretInput{
|
||||
Name: aws.String(awsSecretID),
|
||||
SecretBinary: privateKey,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create new secret with private key")
|
||||
}
|
||||
curKeyId = *res.VersionId
|
||||
} else {
|
||||
res, err := secretManager.UpdateSecret(&secretsmanager.UpdateSecretInput{
|
||||
SecretId: aws.String(awsSecretID),
|
||||
SecretBinary: privateKey,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create new secret with private key")
|
||||
}
|
||||
curKeyId = *res.VersionId
|
||||
}
|
||||
|
||||
keyContents[curKeyId] = privateKey
|
||||
}
|
||||
|
||||
// Loop through all the key bytes and load the private key.
|
||||
for kid, key := range keyContents {
|
||||
pk, err := jwt.ParseRSAPrivateKeyFromPEM(key)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "parsing auth private key")
|
||||
}
|
||||
|
||||
storage.keys[kid] = &PrivateKey{
|
||||
PrivateKey: pk,
|
||||
keyID: kid,
|
||||
algorithm: algorithm,
|
||||
}
|
||||
|
||||
if kid == curKeyId {
|
||||
storage.curPrivateKey = storage.keys[kid]
|
||||
}
|
||||
}
|
||||
|
||||
return storage, nil
|
||||
}
|
204
internal/platform/auth/storage_file.go
Normal file
204
internal/platform/auth/storage_file.go
Normal file
@ -0,0 +1,204 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/pborman/uuid"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// StorageFile is a storage engine that stores private keys on the local file system.
|
||||
type StorageFile struct {
|
||||
// Local directory for storing private keys.
|
||||
localDir string
|
||||
// Duration for keys to be valid.
|
||||
keyExpiration time.Duration
|
||||
// Map of keys by kid (version id).
|
||||
keys map[string]*PrivateKey
|
||||
// The current active key to be used.
|
||||
curPrivateKey *PrivateKey
|
||||
}
|
||||
|
||||
// Keys returns a map of private keys by kID.
|
||||
func (s *StorageFile) Keys() map[string]*PrivateKey {
|
||||
if s == nil || s.keys == nil {
|
||||
return map[string]*PrivateKey{}
|
||||
}
|
||||
return s.keys
|
||||
}
|
||||
|
||||
// Current returns the most recently generated private key.
|
||||
func (s *StorageFile) Current() *PrivateKey {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return s.curPrivateKey
|
||||
}
|
||||
|
||||
// NewAuthenticatorFile is a help function that inits a new Authenticator
|
||||
// using the file storage.
|
||||
func NewAuthenticatorFile(localDir string, now time.Time, keyExpiration time.Duration) (*Authenticator, error) {
|
||||
storage, err := NewStorageFile(localDir, now, keyExpiration)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewAuthenticator(storage, now)
|
||||
}
|
||||
|
||||
// NewStorageFile implements the interface Storage to support persisting private keys
|
||||
// to the local file system.
|
||||
func NewStorageFile(localDir string, now time.Time, keyExpiration time.Duration) (*StorageFile, error) {
|
||||
if localDir == "" {
|
||||
localDir = filepath.Join(os.TempDir(), "auth-private-keys")
|
||||
}
|
||||
|
||||
if _, err := os.Stat(localDir); os.IsNotExist(err) {
|
||||
err = os.MkdirAll(localDir, os.ModePerm)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "failed to create storage directory %s", localDir)
|
||||
}
|
||||
}
|
||||
|
||||
storage := &StorageFile{
|
||||
localDir: localDir,
|
||||
keyExpiration: keyExpiration,
|
||||
keys: make(map[string]*PrivateKey),
|
||||
}
|
||||
|
||||
if now.IsZero() {
|
||||
now = time.Now().UTC()
|
||||
}
|
||||
|
||||
// Time threshold to stop loading keys, any key with a created date
|
||||
// before this value will not be loaded.
|
||||
var disabledCreatedDate time.Time
|
||||
|
||||
// Time threshold to create a new key. If a current key exists and the
|
||||
// created date of the key is before this value, a new key will be created.
|
||||
var activeCreatedDate time.Time
|
||||
|
||||
// If an expiration duration is included, convert to past time from now.
|
||||
if keyExpiration.Seconds() != 0 {
|
||||
// Ensure the expiration is a time in the past for comparison below.
|
||||
if keyExpiration.Seconds() > 0 {
|
||||
keyExpiration = keyExpiration * -1
|
||||
}
|
||||
// Stop loading keys when the created date exceeds two times the key expiration
|
||||
disabledCreatedDate = now.UTC().Add(keyExpiration * 2)
|
||||
|
||||
// Time used to determine when a new key should be created.
|
||||
activeCreatedDate = now.UTC().Add(keyExpiration)
|
||||
}
|
||||
|
||||
// Values used to format filename.
|
||||
filePrefix := "sassauth_"
|
||||
fileExt := ".privatekey"
|
||||
|
||||
files, err := ioutil.ReadDir(localDir)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "failed to list files in directory %s", localDir)
|
||||
}
|
||||
|
||||
// Map of keys stored by version id. version id is kid.
|
||||
keyContents := make(map[string][]byte)
|
||||
|
||||
// The current key id if there is an active one.
|
||||
var curKeyId string
|
||||
|
||||
// The max created data to determine the most recent key.
|
||||
var lastCreatedDate time.Time
|
||||
|
||||
for _, f := range files {
|
||||
if !strings.HasPrefix(f.Name(), filePrefix) || !strings.HasSuffix(f.Name(), fileExt) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract the created timestamp and kID from the filename.
|
||||
fname := strings.TrimSuffix(f.Name(), fileExt)
|
||||
pts := strings.Split(fname, "_")
|
||||
if len(pts) != 3 {
|
||||
return nil, errors.Errorf("unable to parse filename %s", f.Name())
|
||||
}
|
||||
createdAt := pts[1]
|
||||
kID := pts[2]
|
||||
|
||||
// Covert string timestamp to int.
|
||||
createdAtSecs, err := strconv.Atoi(createdAt)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "failed parse timestamp from %s", f.Name())
|
||||
}
|
||||
ts := time.Unix(int64(createdAtSecs), 0)
|
||||
|
||||
// If the created time of the key is less than the disabled threshold, skip.
|
||||
if !disabledCreatedDate.IsZero() && ts.UTC().Unix() < disabledCreatedDate.UTC().Unix() {
|
||||
continue
|
||||
}
|
||||
|
||||
filePath := filepath.Join(localDir, f.Name())
|
||||
dat, err := ioutil.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "failed read file %s", f.Name())
|
||||
}
|
||||
|
||||
keyContents[kID] = dat
|
||||
|
||||
if lastCreatedDate.IsZero() || ts.UTC().Unix() > lastCreatedDate.UTC().Unix() {
|
||||
curKeyId = kID
|
||||
lastCreatedDate = ts.UTC()
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
if !activeCreatedDate.IsZero() && lastCreatedDate.UTC().Unix() < activeCreatedDate.UTC().Unix() {
|
||||
curKeyId = ""
|
||||
}
|
||||
|
||||
// If there are no keys or the current key needs to be rotated, generate a new key.
|
||||
if len(keyContents) == 0 || curKeyId == "" {
|
||||
privateKey, err := KeyGen()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to generate new private key")
|
||||
}
|
||||
|
||||
kID := uuid.NewRandom().String()
|
||||
|
||||
fname := fmt.Sprintf("%s%d_%s%s", filePrefix, now.UTC().Unix(), kID, fileExt)
|
||||
|
||||
filePath := filepath.Join(localDir, fname)
|
||||
|
||||
err = ioutil.WriteFile(filePath, privateKey, 0644)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "failed write file %s", filePath)
|
||||
}
|
||||
|
||||
keyContents[curKeyId] = privateKey
|
||||
}
|
||||
|
||||
// Loop through all the key bytes and load the private key.
|
||||
for kid, key := range keyContents {
|
||||
pk, err := jwt.ParseRSAPrivateKeyFromPEM(key)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "parsing auth private key")
|
||||
}
|
||||
|
||||
storage.keys[kid] = &PrivateKey{
|
||||
PrivateKey: pk,
|
||||
keyID: kid,
|
||||
algorithm: algorithm,
|
||||
}
|
||||
|
||||
if kid == curKeyId {
|
||||
storage.curPrivateKey = storage.keys[kid]
|
||||
}
|
||||
}
|
||||
|
||||
return storage, nil
|
||||
}
|
Reference in New Issue
Block a user