mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-03-27 22:01:28 +02:00
Merge pull request #539 from grnhse/encryption-efficiency-improvements
Encryption efficiency improvements
This commit is contained in:
commit
a197a17bc3
@ -55,6 +55,7 @@
|
||||
|
||||
## Changes since v5.1.1
|
||||
|
||||
- [#539](https://github.com/oauth2-proxy/oauth2-proxy/pull/539) Refactor encryption ciphers and add AES-GCM support (@NickMeves)
|
||||
- [#601](https://github.com/oauth2-proxy/oauth2-proxy/pull/601) Ensure decrypted user/email are valid UTF8 (@JoelSpeed)
|
||||
- [#560](https://github.com/oauth2-proxy/oauth2-proxy/pull/560) Fallback to UserInfo is User ID claim not present (@JoelSpeed)
|
||||
- [#598](https://github.com/oauth2-proxy/oauth2-proxy/pull/598) acr_values no longer sent to IdP when empty (@ScottGuymer)
|
||||
|
@ -4,9 +4,9 @@ import "github.com/oauth2-proxy/oauth2-proxy/pkg/encryption"
|
||||
|
||||
// SessionOptions contains configuration options for the SessionStore providers.
|
||||
type SessionOptions struct {
|
||||
Type string `flag:"session-store-type" cfg:"session_store_type"`
|
||||
Cipher *encryption.Cipher `cfg:",internal"`
|
||||
Redis RedisStoreOptions `cfg:",squash"`
|
||||
Type string `flag:"session-store-type" cfg:"session_store_type"`
|
||||
Cipher encryption.Cipher `cfg:",internal"`
|
||||
Redis RedisStoreOptions `cfg:",squash"`
|
||||
}
|
||||
|
||||
// CookieSessionStoreType is used to indicate the CookieSessionStore should be
|
||||
|
@ -60,7 +60,7 @@ func (s *SessionState) String() string {
|
||||
}
|
||||
|
||||
// EncodeSessionState returns string representation of the current session
|
||||
func (s *SessionState) EncodeSessionState(c *encryption.Cipher) (string, error) {
|
||||
func (s *SessionState) EncodeSessionState(c encryption.Cipher) (string, error) {
|
||||
var ss SessionState
|
||||
if c == nil {
|
||||
// Store only Email and User when cipher is unavailable
|
||||
@ -77,7 +77,7 @@ func (s *SessionState) EncodeSessionState(c *encryption.Cipher) (string, error)
|
||||
&ss.IDToken,
|
||||
&ss.RefreshToken,
|
||||
} {
|
||||
err := c.EncryptInto(s)
|
||||
err := into(s, c.Encrypt)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@ -89,7 +89,7 @@ func (s *SessionState) EncodeSessionState(c *encryption.Cipher) (string, error)
|
||||
}
|
||||
|
||||
// DecodeSessionState decodes the session cookie string into a SessionState
|
||||
func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) {
|
||||
func DecodeSessionState(v string, c encryption.Cipher) (*SessionState, error) {
|
||||
var ss SessionState
|
||||
err := json.Unmarshal([]byte(v), &ss)
|
||||
if err != nil {
|
||||
@ -104,24 +104,18 @@ func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) {
|
||||
PreferredUsername: ss.PreferredUsername,
|
||||
}
|
||||
} else {
|
||||
// Backward compatibility with using unencrypted Email
|
||||
if ss.Email != "" {
|
||||
decryptedEmail, errEmail := c.Decrypt(ss.Email)
|
||||
if errEmail == nil {
|
||||
if !utf8.ValidString(decryptedEmail) {
|
||||
return nil, errors.New("invalid value for decrypted email")
|
||||
}
|
||||
ss.Email = decryptedEmail
|
||||
// Backward compatibility with using unencrypted Email or User
|
||||
// Decryption errors will leave original string
|
||||
err = into(&ss.Email, c.Decrypt)
|
||||
if err == nil {
|
||||
if !utf8.ValidString(ss.Email) {
|
||||
return nil, errors.New("invalid value for decrypted email")
|
||||
}
|
||||
}
|
||||
// Backward compatibility with using unencrypted User
|
||||
if ss.User != "" {
|
||||
decryptedUser, errUser := c.Decrypt(ss.User)
|
||||
if errUser == nil {
|
||||
if !utf8.ValidString(decryptedUser) {
|
||||
return nil, errors.New("invalid value for decrypted user")
|
||||
}
|
||||
ss.User = decryptedUser
|
||||
err = into(&ss.User, c.Decrypt)
|
||||
if err == nil {
|
||||
if !utf8.ValidString(ss.User) {
|
||||
return nil, errors.New("invalid value for decrypted user")
|
||||
}
|
||||
}
|
||||
|
||||
@ -131,7 +125,7 @@ func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) {
|
||||
&ss.IDToken,
|
||||
&ss.RefreshToken,
|
||||
} {
|
||||
err := c.DecryptInto(s)
|
||||
err := into(s, c.Decrypt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -139,3 +133,20 @@ func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) {
|
||||
}
|
||||
return &ss, nil
|
||||
}
|
||||
|
||||
// codecFunc is a function that takes a []byte and encodes/decodes it
|
||||
type codecFunc func([]byte) ([]byte, error)
|
||||
|
||||
func into(s *string, f codecFunc) error {
|
||||
// Do not encrypt/decrypt nil or empty strings
|
||||
if s == nil || *s == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
d, err := f([]byte(*s))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*s = string(d)
|
||||
return nil
|
||||
}
|
||||
|
@ -1,11 +1,13 @@
|
||||
package sessions_test
|
||||
package sessions
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
mathrand "math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/encryption"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
@ -17,12 +19,16 @@ func timePtr(t time.Time) *time.Time {
|
||||
return &t
|
||||
}
|
||||
|
||||
func newTestCipher(secret []byte) (encryption.Cipher, error) {
|
||||
return encryption.NewBase64Cipher(encryption.NewCFBCipher, secret)
|
||||
}
|
||||
|
||||
func TestSessionStateSerialization(t *testing.T) {
|
||||
c, err := encryption.NewCipher([]byte(secret))
|
||||
c, err := newTestCipher([]byte(secret))
|
||||
assert.Equal(t, nil, err)
|
||||
c2, err := encryption.NewCipher([]byte(altSecret))
|
||||
c2, err := newTestCipher([]byte(altSecret))
|
||||
assert.Equal(t, nil, err)
|
||||
s := &sessions.SessionState{
|
||||
s := &SessionState{
|
||||
Email: "user@domain.com",
|
||||
PreferredUsername: "user",
|
||||
AccessToken: "token1234",
|
||||
@ -34,7 +40,7 @@ func TestSessionStateSerialization(t *testing.T) {
|
||||
encoded, err := s.EncodeSessionState(c)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
ss, err := sessions.DecodeSessionState(encoded, c)
|
||||
ss, err := DecodeSessionState(encoded, c)
|
||||
t.Logf("%#v", ss)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "", ss.User)
|
||||
@ -47,17 +53,17 @@ func TestSessionStateSerialization(t *testing.T) {
|
||||
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
||||
|
||||
// ensure a different cipher can't decode properly (ie: it gets gibberish)
|
||||
ss, err = sessions.DecodeSessionState(encoded, c2)
|
||||
ss, err = DecodeSessionState(encoded, c2)
|
||||
t.Logf("%#v", ss)
|
||||
assert.NotEqual(t, nil, err)
|
||||
}
|
||||
|
||||
func TestSessionStateSerializationWithUser(t *testing.T) {
|
||||
c, err := encryption.NewCipher([]byte(secret))
|
||||
c, err := newTestCipher([]byte(secret))
|
||||
assert.Equal(t, nil, err)
|
||||
c2, err := encryption.NewCipher([]byte(altSecret))
|
||||
c2, err := newTestCipher([]byte(altSecret))
|
||||
assert.Equal(t, nil, err)
|
||||
s := &sessions.SessionState{
|
||||
s := &SessionState{
|
||||
User: "just-user",
|
||||
PreferredUsername: "ju",
|
||||
Email: "user@domain.com",
|
||||
@ -69,7 +75,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
|
||||
encoded, err := s.EncodeSessionState(c)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
ss, err := sessions.DecodeSessionState(encoded, c)
|
||||
ss, err := DecodeSessionState(encoded, c)
|
||||
t.Logf("%#v", ss)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, s.User, ss.User)
|
||||
@ -81,13 +87,13 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
|
||||
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
||||
|
||||
// ensure a different cipher can't decode properly (ie: it gets gibberish)
|
||||
ss, err = sessions.DecodeSessionState(encoded, c2)
|
||||
ss, err = DecodeSessionState(encoded, c2)
|
||||
t.Logf("%#v", ss)
|
||||
assert.NotEqual(t, nil, err)
|
||||
}
|
||||
|
||||
func TestSessionStateSerializationNoCipher(t *testing.T) {
|
||||
s := &sessions.SessionState{
|
||||
s := &SessionState{
|
||||
Email: "user@domain.com",
|
||||
PreferredUsername: "user",
|
||||
AccessToken: "token1234",
|
||||
@ -99,7 +105,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// only email should have been serialized
|
||||
ss, err := sessions.DecodeSessionState(encoded, nil)
|
||||
ss, err := DecodeSessionState(encoded, nil)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "", ss.User)
|
||||
assert.Equal(t, s.Email, ss.Email)
|
||||
@ -109,7 +115,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
|
||||
s := &sessions.SessionState{
|
||||
s := &SessionState{
|
||||
User: "just-user",
|
||||
Email: "user@domain.com",
|
||||
PreferredUsername: "user",
|
||||
@ -122,7 +128,7 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// only email should have been serialized
|
||||
ss, err := sessions.DecodeSessionState(encoded, nil)
|
||||
ss, err := DecodeSessionState(encoded, nil)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, s.User, ss.User)
|
||||
assert.Equal(t, s.Email, ss.Email)
|
||||
@ -132,20 +138,20 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestExpired(t *testing.T) {
|
||||
s := &sessions.SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(-1) * time.Minute))}
|
||||
s := &SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(-1) * time.Minute))}
|
||||
assert.Equal(t, true, s.IsExpired())
|
||||
|
||||
s = &sessions.SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Minute))}
|
||||
s = &SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Minute))}
|
||||
assert.Equal(t, false, s.IsExpired())
|
||||
|
||||
s = &sessions.SessionState{}
|
||||
s = &SessionState{}
|
||||
assert.Equal(t, false, s.IsExpired())
|
||||
}
|
||||
|
||||
type testCase struct {
|
||||
sessions.SessionState
|
||||
SessionState
|
||||
Encoded string
|
||||
Cipher *encryption.Cipher
|
||||
Cipher encryption.Cipher
|
||||
Error bool
|
||||
}
|
||||
|
||||
@ -159,14 +165,14 @@ func TestEncodeSessionState(t *testing.T) {
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
SessionState: sessions.SessionState{
|
||||
SessionState: SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
},
|
||||
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
|
||||
},
|
||||
{
|
||||
SessionState: sessions.SessionState{
|
||||
SessionState: SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
AccessToken: "token1234",
|
||||
@ -181,7 +187,7 @@ func TestEncodeSessionState(t *testing.T) {
|
||||
|
||||
for i, tc := range testCases {
|
||||
encoded, err := tc.EncodeSessionState(tc.Cipher)
|
||||
t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err)
|
||||
t.Logf("i:%d Encoded:%#vSessionState:%#v Error:%#v", i, encoded, tc.SessionState, err)
|
||||
if tc.Error {
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, encoded)
|
||||
@ -201,39 +207,39 @@ func TestDecodeSessionState(t *testing.T) {
|
||||
eJSON, _ := e.MarshalJSON()
|
||||
eString := string(eJSON)
|
||||
|
||||
c, err := encryption.NewCipher([]byte(secret))
|
||||
c, err := newTestCipher([]byte(secret))
|
||||
assert.NoError(t, err)
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
SessionState: sessions.SessionState{
|
||||
SessionState: SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
},
|
||||
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
|
||||
},
|
||||
{
|
||||
SessionState: sessions.SessionState{
|
||||
SessionState: SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "",
|
||||
},
|
||||
Encoded: `{"Email":"user@domain.com"}`,
|
||||
},
|
||||
{
|
||||
SessionState: sessions.SessionState{
|
||||
SessionState: SessionState{
|
||||
User: "just-user",
|
||||
},
|
||||
Encoded: `{"User":"just-user"}`,
|
||||
},
|
||||
{
|
||||
SessionState: sessions.SessionState{
|
||||
SessionState: SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
},
|
||||
Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString),
|
||||
},
|
||||
{
|
||||
SessionState: sessions.SessionState{
|
||||
SessionState: SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
AccessToken: "token1234",
|
||||
@ -246,7 +252,7 @@ func TestDecodeSessionState(t *testing.T) {
|
||||
Cipher: c,
|
||||
},
|
||||
{
|
||||
SessionState: sessions.SessionState{
|
||||
SessionState: SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
},
|
||||
@ -264,7 +270,7 @@ func TestDecodeSessionState(t *testing.T) {
|
||||
Error: true,
|
||||
},
|
||||
{
|
||||
SessionState: sessions.SessionState{
|
||||
SessionState: SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "YmFzZTY0LWVuY29kZWQtdXNlcgo=", // Base64 encoding of base64-encoded-user
|
||||
},
|
||||
@ -274,8 +280,8 @@ func TestDecodeSessionState(t *testing.T) {
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
ss, err := sessions.DecodeSessionState(tc.Encoded, tc.Cipher)
|
||||
t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, tc.Encoded, ss, err)
|
||||
ss, err := DecodeSessionState(tc.Encoded, tc.Cipher)
|
||||
t.Logf("i:%d Encoded:%#vSessionState:%#v Error:%#v", i, tc.Encoded, ss, err)
|
||||
if tc.Error {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, ss)
|
||||
@ -297,7 +303,7 @@ func TestDecodeSessionState(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSessionStateAge(t *testing.T) {
|
||||
ss := &sessions.SessionState{}
|
||||
ss := &SessionState{}
|
||||
|
||||
// Created at unset so should be 0
|
||||
assert.Equal(t, time.Duration(0), ss.Age())
|
||||
@ -306,3 +312,44 @@ func TestSessionStateAge(t *testing.T) {
|
||||
ss.CreatedAt = timePtr(time.Now().Add(-1 * time.Hour))
|
||||
assert.Equal(t, time.Hour, ss.Age().Round(time.Minute))
|
||||
}
|
||||
|
||||
func TestIntoEncryptAndIntoDecrypt(t *testing.T) {
|
||||
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
|
||||
// Test all 3 valid AES sizes
|
||||
for _, secretSize := range []int{16, 24, 32} {
|
||||
t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) {
|
||||
secret := make([]byte, secretSize)
|
||||
_, err := io.ReadFull(rand.Reader, secret)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
c, err := newTestCipher(secret)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Check no errors with empty or nil strings
|
||||
empty := ""
|
||||
assert.Equal(t, nil, into(&empty, c.Encrypt))
|
||||
assert.Equal(t, nil, into(&empty, c.Decrypt))
|
||||
assert.Equal(t, nil, into(nil, c.Encrypt))
|
||||
assert.Equal(t, nil, into(nil, c.Decrypt))
|
||||
|
||||
// Test various sizes tokens might be
|
||||
for _, dataSize := range []int{10, 100, 1000, 5000, 10000} {
|
||||
t.Run(fmt.Sprintf("%d", dataSize), func(t *testing.T) {
|
||||
b := make([]byte, dataSize)
|
||||
for i := range b {
|
||||
b[i] = charset[mathrand.Intn(len(charset))]
|
||||
}
|
||||
data := string(b)
|
||||
originalData := data
|
||||
|
||||
assert.Equal(t, nil, into(&data, c.Encrypt))
|
||||
assert.NotEqual(t, originalData, data)
|
||||
|
||||
assert.Equal(t, nil, into(&data, c.Decrypt))
|
||||
assert.Equal(t, originalData, data)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -3,183 +3,134 @@ package encryption
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SecretBytes attempts to base64 decode the secret, if that fails it treats the secret as binary
|
||||
func SecretBytes(secret string) []byte {
|
||||
b, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(secret, "="))
|
||||
if err == nil {
|
||||
// Only return decoded form if a valid AES length
|
||||
// Don't want unintentional decoding resulting in invalid lengths confusing a user
|
||||
// that thought they used a 16, 24, 32 length string
|
||||
for _, i := range []int{16, 24, 32} {
|
||||
if len(b) == i {
|
||||
return b
|
||||
}
|
||||
}
|
||||
}
|
||||
// If decoding didn't work or resulted in non-AES compliant length,
|
||||
// assume the raw string was the intended secret
|
||||
return []byte(secret)
|
||||
// Cipher provides methods to encrypt and decrypt
|
||||
type Cipher interface {
|
||||
Encrypt(value []byte) ([]byte, error)
|
||||
Decrypt(ciphertext []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
// cookies are stored in a 3 part (value + timestamp + signature) to enforce that the values are as originally set.
|
||||
// additionally, the 'value' is encrypted so it's opaque to the browser
|
||||
|
||||
// Validate ensures a cookie is properly signed
|
||||
func Validate(cookie *http.Cookie, seed string, expiration time.Duration) (value string, t time.Time, ok bool) {
|
||||
// value, timestamp, sig
|
||||
parts := strings.Split(cookie.Value, "|")
|
||||
if len(parts) != 3 {
|
||||
return
|
||||
}
|
||||
if checkSignature(parts[2], seed, cookie.Name, parts[0], parts[1]) {
|
||||
ts, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// The expiration timestamp set when the cookie was created
|
||||
// isn't sent back by the browser. Hence, we check whether the
|
||||
// creation timestamp stored in the cookie falls within the
|
||||
// window defined by (Now()-expiration, Now()].
|
||||
t = time.Unix(int64(ts), 0)
|
||||
if t.After(time.Now().Add(expiration*-1)) && t.Before(time.Now().Add(time.Minute*5)) {
|
||||
// it's a valid cookie. now get the contents
|
||||
rawValue, err := base64.URLEncoding.DecodeString(parts[0])
|
||||
if err == nil {
|
||||
value = string(rawValue)
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
type base64Cipher struct {
|
||||
Cipher Cipher
|
||||
}
|
||||
|
||||
// SignedValue returns a cookie that is signed and can later be checked with Validate
|
||||
func SignedValue(seed string, key string, value string, now time.Time) string {
|
||||
encodedValue := base64.URLEncoding.EncodeToString([]byte(value))
|
||||
timeStr := fmt.Sprintf("%d", now.Unix())
|
||||
sig := cookieSignature(sha256.New, seed, key, encodedValue, timeStr)
|
||||
cookieVal := fmt.Sprintf("%s|%s|%s", encodedValue, timeStr, sig)
|
||||
return cookieVal
|
||||
}
|
||||
|
||||
func cookieSignature(signer func() hash.Hash, args ...string) string {
|
||||
h := hmac.New(signer, []byte(args[0]))
|
||||
for _, arg := range args[1:] {
|
||||
h.Write([]byte(arg))
|
||||
// NewBase64Cipher returns a new AES Cipher for encrypting cookie values
|
||||
// and wrapping them in Base64 -- Supports Legacy encryption scheme
|
||||
func NewBase64Cipher(initCipher func([]byte) (Cipher, error), secret []byte) (Cipher, error) {
|
||||
c, err := initCipher(secret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var b []byte
|
||||
b = h.Sum(b)
|
||||
return base64.URLEncoding.EncodeToString(b)
|
||||
return &base64Cipher{Cipher: c}, nil
|
||||
}
|
||||
|
||||
func checkSignature(signature string, args ...string) bool {
|
||||
checkSig := cookieSignature(sha256.New, args...)
|
||||
if checkHmac(signature, checkSig) {
|
||||
return true
|
||||
// Encrypt encrypts a value with the embedded Cipher & Base64 encodes it
|
||||
func (c *base64Cipher) Encrypt(value []byte) ([]byte, error) {
|
||||
encrypted, err := c.Cipher.Encrypt(value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO: After appropriate rollout window, remove support for SHA1
|
||||
legacySig := cookieSignature(sha1.New, args...)
|
||||
return checkHmac(signature, legacySig)
|
||||
return []byte(base64.StdEncoding.EncodeToString(encrypted)), nil
|
||||
}
|
||||
|
||||
func checkHmac(input, expected string) bool {
|
||||
inputMAC, err1 := base64.URLEncoding.DecodeString(input)
|
||||
if err1 == nil {
|
||||
expectedMAC, err2 := base64.URLEncoding.DecodeString(expected)
|
||||
if err2 == nil {
|
||||
return hmac.Equal(inputMAC, expectedMAC)
|
||||
}
|
||||
// Decrypt Base64 decodes a value & decrypts it with the embedded Cipher
|
||||
func (c *base64Cipher) Decrypt(ciphertext []byte) ([]byte, error) {
|
||||
encrypted, err := base64.StdEncoding.DecodeString(string(ciphertext))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to base64 decode value %s", err)
|
||||
}
|
||||
return false
|
||||
|
||||
return c.Cipher.Decrypt(encrypted)
|
||||
}
|
||||
|
||||
// Cipher provides methods to encrypt and decrypt cookie values
|
||||
type Cipher struct {
|
||||
type cfbCipher struct {
|
||||
cipher.Block
|
||||
}
|
||||
|
||||
// NewCipher returns a new aes Cipher for encrypting cookie values
|
||||
func NewCipher(secret []byte) (*Cipher, error) {
|
||||
// NewCFBCipher returns a new AES CFB Cipher
|
||||
func NewCFBCipher(secret []byte) (Cipher, error) {
|
||||
c, err := aes.NewCipher(secret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Cipher{Block: c}, err
|
||||
return &cfbCipher{Block: c}, err
|
||||
}
|
||||
|
||||
// Encrypt a value for use in a cookie
|
||||
func (c *Cipher) Encrypt(value string) (string, error) {
|
||||
// Encrypt with AES CFB
|
||||
func (c *cfbCipher) Encrypt(value []byte) ([]byte, error) {
|
||||
ciphertext := make([]byte, aes.BlockSize+len(value))
|
||||
iv := ciphertext[:aes.BlockSize]
|
||||
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
|
||||
return "", fmt.Errorf("failed to create initialization vector %s", err)
|
||||
return nil, fmt.Errorf("failed to create initialization vector %s", err)
|
||||
}
|
||||
|
||||
stream := cipher.NewCFBEncrypter(c.Block, iv)
|
||||
stream.XORKeyStream(ciphertext[aes.BlockSize:], []byte(value))
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
stream.XORKeyStream(ciphertext[aes.BlockSize:], value)
|
||||
return ciphertext, nil
|
||||
}
|
||||
|
||||
// Decrypt a value from a cookie to it's original string
|
||||
func (c *Cipher) Decrypt(s string) (string, error) {
|
||||
encrypted, err := base64.StdEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decrypt cookie value %s", err)
|
||||
// Decrypt an AES CFB ciphertext
|
||||
func (c *cfbCipher) Decrypt(ciphertext []byte) ([]byte, error) {
|
||||
if len(ciphertext) < aes.BlockSize {
|
||||
return nil, fmt.Errorf("encrypted value should be at least %d bytes, but is only %d bytes", aes.BlockSize, len(ciphertext))
|
||||
}
|
||||
|
||||
if len(encrypted) < aes.BlockSize {
|
||||
return "", fmt.Errorf("encrypted cookie value should be "+
|
||||
"at least %d bytes, but is only %d bytes",
|
||||
aes.BlockSize, len(encrypted))
|
||||
}
|
||||
|
||||
iv := encrypted[:aes.BlockSize]
|
||||
encrypted = encrypted[aes.BlockSize:]
|
||||
iv, ciphertext := ciphertext[:aes.BlockSize], ciphertext[aes.BlockSize:]
|
||||
plaintext := make([]byte, len(ciphertext))
|
||||
stream := cipher.NewCFBDecrypter(c.Block, iv)
|
||||
stream.XORKeyStream(encrypted, encrypted)
|
||||
stream.XORKeyStream(plaintext, ciphertext)
|
||||
|
||||
return string(encrypted), nil
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// EncryptInto encrypts the value and stores it back in the string pointer
|
||||
func (c *Cipher) EncryptInto(s *string) error {
|
||||
return into(c.Encrypt, s)
|
||||
type gcmCipher struct {
|
||||
cipher.Block
|
||||
}
|
||||
|
||||
// DecryptInto decrypts the value and stores it back in the string pointer
|
||||
func (c *Cipher) DecryptInto(s *string) error {
|
||||
return into(c.Decrypt, s)
|
||||
}
|
||||
|
||||
// codecFunc is a function that takes a string and encodes/decodes it
|
||||
type codecFunc func(string) (string, error)
|
||||
|
||||
func into(f codecFunc, s *string) error {
|
||||
// Do not encrypt/decrypt nil or empty strings
|
||||
if s == nil || *s == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
d, err := f(*s)
|
||||
// NewGCMCipher returns a new AES GCM Cipher
|
||||
func NewGCMCipher(secret []byte) (Cipher, error) {
|
||||
c, err := aes.NewCipher(secret)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
*s = d
|
||||
return nil
|
||||
return &gcmCipher{Block: c}, err
|
||||
}
|
||||
|
||||
// Encrypt with AES GCM on raw bytes
|
||||
func (c *gcmCipher) Encrypt(value []byte) ([]byte, error) {
|
||||
gcm, err := cipher.NewGCM(c.Block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Using nonce as Seal's dst argument results in it being the first
|
||||
// chunk of bytes in the ciphertext. Decrypt retrieves the nonce/IV from this.
|
||||
ciphertext := gcm.Seal(nonce, nonce, value, nil)
|
||||
return ciphertext, nil
|
||||
}
|
||||
|
||||
// Decrypt an AES GCM ciphertext
|
||||
func (c *gcmCipher) Decrypt(ciphertext []byte) ([]byte, error) {
|
||||
gcm, err := cipher.NewGCM(c.Block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nonceSize := gcm.NonceSize()
|
||||
nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
|
||||
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
|
@ -2,8 +2,6 @@ package encryption
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -12,107 +10,20 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSecretBytesEncoded(t *testing.T) {
|
||||
for _, secretSize := range []int{16, 24, 32} {
|
||||
t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) {
|
||||
secret := make([]byte, secretSize)
|
||||
_, err := io.ReadFull(rand.Reader, secret)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// We test both padded & raw Base64 to ensure we handle both
|
||||
// potential user input routes for Base64
|
||||
base64Padded := base64.URLEncoding.EncodeToString(secret)
|
||||
sb := SecretBytes(base64Padded)
|
||||
assert.Equal(t, secret, sb)
|
||||
assert.Equal(t, len(sb), secretSize)
|
||||
|
||||
base64Raw := base64.RawURLEncoding.EncodeToString(secret)
|
||||
sb = SecretBytes(base64Raw)
|
||||
assert.Equal(t, secret, sb)
|
||||
assert.Equal(t, len(sb), secretSize)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// A string that isn't intended as Base64 and still decodes (but to unintended length)
|
||||
// will return the original secret as bytes
|
||||
func TestSecretBytesEncodedWrongSize(t *testing.T) {
|
||||
for _, secretSize := range []int{15, 20, 28, 33, 44} {
|
||||
t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) {
|
||||
secret := make([]byte, secretSize)
|
||||
_, err := io.ReadFull(rand.Reader, secret)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// We test both padded & raw Base64 to ensure we handle both
|
||||
// potential user input routes for Base64
|
||||
base64Padded := base64.URLEncoding.EncodeToString(secret)
|
||||
sb := SecretBytes(base64Padded)
|
||||
assert.NotEqual(t, secret, sb)
|
||||
assert.NotEqual(t, len(sb), secretSize)
|
||||
// The given secret is returned as []byte
|
||||
assert.Equal(t, base64Padded, string(sb))
|
||||
|
||||
base64Raw := base64.RawURLEncoding.EncodeToString(secret)
|
||||
sb = SecretBytes(base64Raw)
|
||||
assert.NotEqual(t, secret, sb)
|
||||
assert.NotEqual(t, len(sb), secretSize)
|
||||
// The given secret is returned as []byte
|
||||
assert.Equal(t, base64Raw, string(sb))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecretBytesNonBase64(t *testing.T) {
|
||||
trailer := "equals=========="
|
||||
assert.Equal(t, trailer, string(SecretBytes(trailer)))
|
||||
|
||||
raw16 := "asdflkjhqwer)(*&"
|
||||
sb16 := SecretBytes(raw16)
|
||||
assert.Equal(t, raw16, string(sb16))
|
||||
assert.Equal(t, 16, len(sb16))
|
||||
|
||||
raw24 := "asdflkjhqwer)(*&CJEN#$%^"
|
||||
sb24 := SecretBytes(raw24)
|
||||
assert.Equal(t, raw24, string(sb24))
|
||||
assert.Equal(t, 24, len(sb24))
|
||||
|
||||
raw32 := "asdflkjhqwer)(*&1234lkjhqwer)(*&"
|
||||
sb32 := SecretBytes(raw32)
|
||||
assert.Equal(t, raw32, string(sb32))
|
||||
assert.Equal(t, 32, len(sb32))
|
||||
}
|
||||
|
||||
func TestSignAndValidate(t *testing.T) {
|
||||
seed := "0123456789abcdef"
|
||||
key := "cookie-name"
|
||||
value := base64.URLEncoding.EncodeToString([]byte("I am soooo encoded"))
|
||||
epoch := "123456789"
|
||||
|
||||
sha256sig := cookieSignature(sha256.New, seed, key, value, epoch)
|
||||
sha1sig := cookieSignature(sha1.New, seed, key, value, epoch)
|
||||
|
||||
assert.True(t, checkSignature(sha256sig, seed, key, value, epoch))
|
||||
// This should be switched to False after fully deprecating SHA1
|
||||
assert.True(t, checkSignature(sha1sig, seed, key, value, epoch))
|
||||
|
||||
assert.False(t, checkSignature(sha256sig, seed, key, "tampered", epoch))
|
||||
assert.False(t, checkSignature(sha1sig, seed, key, "tampered", epoch))
|
||||
}
|
||||
|
||||
func TestEncodeAndDecodeAccessToken(t *testing.T) {
|
||||
const secret = "0123456789abcdefghijklmnopqrstuv"
|
||||
const token = "my access token"
|
||||
c, err := NewCipher([]byte(secret))
|
||||
c, err := NewBase64Cipher(NewCFBCipher, []byte(secret))
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
encoded, err := c.Encrypt(token)
|
||||
encoded, err := c.Encrypt([]byte(token))
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
decoded, err := c.Decrypt(encoded)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
assert.NotEqual(t, token, encoded)
|
||||
assert.Equal(t, token, decoded)
|
||||
assert.NotEqual(t, []byte(token), encoded)
|
||||
assert.Equal(t, []byte(token), decoded)
|
||||
}
|
||||
|
||||
func TestEncodeAndDecodeAccessTokenB64(t *testing.T) {
|
||||
@ -121,37 +32,199 @@ func TestEncodeAndDecodeAccessTokenB64(t *testing.T) {
|
||||
|
||||
secret, err := base64.URLEncoding.DecodeString(secretBase64)
|
||||
assert.Equal(t, nil, err)
|
||||
c, err := NewCipher([]byte(secret))
|
||||
c, err := NewBase64Cipher(NewCFBCipher, []byte(secret))
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
encoded, err := c.Encrypt(token)
|
||||
encoded, err := c.Encrypt([]byte(token))
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
decoded, err := c.Decrypt(encoded)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
assert.NotEqual(t, token, encoded)
|
||||
assert.Equal(t, token, decoded)
|
||||
assert.NotEqual(t, []byte(token), encoded)
|
||||
assert.Equal(t, []byte(token), decoded)
|
||||
}
|
||||
|
||||
func TestEncodeIntoAndDecodeIntoAccessToken(t *testing.T) {
|
||||
const secret = "0123456789abcdefghijklmnopqrstuv"
|
||||
c, err := NewCipher([]byte(secret))
|
||||
func TestEncryptAndDecrypt(t *testing.T) {
|
||||
// Test our 2 cipher types
|
||||
cipherInits := map[string]func([]byte) (Cipher, error){
|
||||
"CFB": NewCFBCipher,
|
||||
"GCM": NewGCMCipher,
|
||||
}
|
||||
for name, initCipher := range cipherInits {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
// Test all 3 valid AES sizes
|
||||
for _, secretSize := range []int{16, 24, 32} {
|
||||
t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) {
|
||||
secret := make([]byte, secretSize)
|
||||
_, err := io.ReadFull(rand.Reader, secret)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// Test Standard & Base64 wrapped
|
||||
cstd, err := initCipher(secret)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
cb64, err := NewBase64Cipher(initCipher, secret)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
ciphers := map[string]Cipher{
|
||||
"Standard": cstd,
|
||||
"Base64": cb64,
|
||||
}
|
||||
|
||||
for cName, c := range ciphers {
|
||||
t.Run(cName, func(t *testing.T) {
|
||||
// Test various sizes sessions might be
|
||||
for _, dataSize := range []int{10, 100, 1000, 5000, 10000} {
|
||||
t.Run(fmt.Sprintf("%d", dataSize), func(t *testing.T) {
|
||||
runEncryptAndDecrypt(t, c, dataSize)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runEncryptAndDecrypt(t *testing.T, c Cipher, dataSize int) {
|
||||
data := make([]byte, dataSize)
|
||||
_, err := io.ReadFull(rand.Reader, data)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
token := "my access token"
|
||||
originalToken := token
|
||||
// Ensure our Encrypt function doesn't encrypt in place
|
||||
immutableData := make([]byte, len(data))
|
||||
copy(immutableData, data)
|
||||
|
||||
assert.Equal(t, nil, c.EncryptInto(&token))
|
||||
assert.NotEqual(t, originalToken, token)
|
||||
encrypted, err := c.Encrypt(data)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, encrypted, data)
|
||||
// Encrypt didn't operate in-place on []byte
|
||||
assert.Equal(t, data, immutableData)
|
||||
|
||||
assert.Equal(t, nil, c.DecryptInto(&token))
|
||||
assert.Equal(t, originalToken, token)
|
||||
// Ensure our Decrypt function doesn't decrypt in place
|
||||
immutableEnc := make([]byte, len(encrypted))
|
||||
copy(immutableEnc, encrypted)
|
||||
|
||||
// Check no errors with empty or nil strings
|
||||
empty := ""
|
||||
assert.Equal(t, nil, c.EncryptInto(&empty))
|
||||
assert.Equal(t, nil, c.DecryptInto(&empty))
|
||||
assert.Equal(t, nil, c.EncryptInto(nil))
|
||||
assert.Equal(t, nil, c.DecryptInto(nil))
|
||||
decrypted, err := c.Decrypt(encrypted)
|
||||
assert.Equal(t, nil, err)
|
||||
// Original data back
|
||||
assert.Equal(t, data, decrypted)
|
||||
// Decrypt didn't operate in-place on []byte
|
||||
assert.Equal(t, encrypted, immutableEnc)
|
||||
// Encrypt/Decrypt actually did something
|
||||
assert.NotEqual(t, encrypted, decrypted)
|
||||
}
|
||||
|
||||
func TestDecryptCFBWrongSecret(t *testing.T) {
|
||||
secret1 := []byte("0123456789abcdefghijklmnopqrstuv")
|
||||
secret2 := []byte("9876543210abcdefghijklmnopqrstuv")
|
||||
|
||||
c1, err := NewCFBCipher(secret1)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
c2, err := NewCFBCipher(secret2)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
data := []byte("f3928pufm982374dj02y485dsl34890u2t9nd4028s94dm58y2394087dhmsyt29h8df")
|
||||
|
||||
ciphertext, err := c1.Encrypt(data)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
wrongData, err := c2.Decrypt(ciphertext)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, data, wrongData)
|
||||
}
|
||||
|
||||
func TestDecryptGCMWrongSecret(t *testing.T) {
|
||||
secret1 := []byte("0123456789abcdefghijklmnopqrstuv")
|
||||
secret2 := []byte("9876543210abcdefghijklmnopqrstuv")
|
||||
|
||||
c1, err := NewGCMCipher(secret1)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
c2, err := NewGCMCipher(secret2)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
data := []byte("f3928pufm982374dj02y485dsl34890u2t9nd4028s94dm58y2394087dhmsyt29h8df")
|
||||
|
||||
ciphertext, err := c1.Encrypt(data)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// GCM is authenticated - this should lead to message authentication failed
|
||||
_, err = c2.Decrypt(ciphertext)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// Encrypt with GCM, Decrypt with CFB: Results in Garbage data
|
||||
func TestGCMtoCFBErrors(t *testing.T) {
|
||||
// Test all 3 valid AES sizes
|
||||
for _, secretSize := range []int{16, 24, 32} {
|
||||
t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) {
|
||||
secret := make([]byte, secretSize)
|
||||
_, err := io.ReadFull(rand.Reader, secret)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
gcm, err := NewGCMCipher(secret)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
cfb, err := NewCFBCipher(secret)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// Test various sizes sessions might be
|
||||
for _, dataSize := range []int{10, 100, 1000, 5000, 10000} {
|
||||
t.Run(fmt.Sprintf("%d", dataSize), func(t *testing.T) {
|
||||
data := make([]byte, dataSize)
|
||||
_, err := io.ReadFull(rand.Reader, data)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
encrypted, err := gcm.Encrypt(data)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, encrypted, data)
|
||||
|
||||
decrypted, err := cfb.Decrypt(encrypted)
|
||||
assert.Equal(t, nil, err)
|
||||
// Data is mangled
|
||||
assert.NotEqual(t, data, decrypted)
|
||||
assert.NotEqual(t, encrypted, decrypted)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Encrypt with CFB, Decrypt with GCM: Results in errors
|
||||
func TestCFBtoGCMErrors(t *testing.T) {
|
||||
// Test all 3 valid AES sizes
|
||||
for _, secretSize := range []int{16, 24, 32} {
|
||||
t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) {
|
||||
secret := make([]byte, secretSize)
|
||||
_, err := io.ReadFull(rand.Reader, secret)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
gcm, err := NewGCMCipher(secret)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
cfb, err := NewCFBCipher(secret)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// Test various sizes sessions might be
|
||||
for _, dataSize := range []int{10, 100, 1000, 5000, 10000} {
|
||||
t.Run(fmt.Sprintf("%d", dataSize), func(t *testing.T) {
|
||||
data := make([]byte, dataSize)
|
||||
_, err := io.ReadFull(rand.Reader, data)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
encrypted, err := cfb.Encrypt(data)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, encrypted, data)
|
||||
|
||||
// GCM is authenticated - this should lead to message authentication failed
|
||||
_, err = gcm.Decrypt(encrypted)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
106
pkg/encryption/utils.go
Normal file
106
pkg/encryption/utils.go
Normal file
@ -0,0 +1,106 @@
|
||||
package encryption
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"hash"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SecretBytes attempts to base64 decode the secret, if that fails it treats the secret as binary
|
||||
func SecretBytes(secret string) []byte {
|
||||
b, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(secret, "="))
|
||||
if err == nil {
|
||||
// Only return decoded form if a valid AES length
|
||||
// Don't want unintentional decoding resulting in invalid lengths confusing a user
|
||||
// that thought they used a 16, 24, 32 length string
|
||||
for _, i := range []int{16, 24, 32} {
|
||||
if len(b) == i {
|
||||
return b
|
||||
}
|
||||
}
|
||||
}
|
||||
// If decoding didn't work or resulted in non-AES compliant length,
|
||||
// assume the raw string was the intended secret
|
||||
return []byte(secret)
|
||||
}
|
||||
|
||||
// cookies are stored in a 3 part (value + timestamp + signature) to enforce that the values are as originally set.
|
||||
// additionally, the 'value' is encrypted so it's opaque to the browser
|
||||
|
||||
// Validate ensures a cookie is properly signed
|
||||
func Validate(cookie *http.Cookie, seed string, expiration time.Duration) (value []byte, t time.Time, ok bool) {
|
||||
// value, timestamp, sig
|
||||
parts := strings.Split(cookie.Value, "|")
|
||||
if len(parts) != 3 {
|
||||
return
|
||||
}
|
||||
if checkSignature(parts[2], seed, cookie.Name, parts[0], parts[1]) {
|
||||
ts, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// The expiration timestamp set when the cookie was created
|
||||
// isn't sent back by the browser. Hence, we check whether the
|
||||
// creation timestamp stored in the cookie falls within the
|
||||
// window defined by (Now()-expiration, Now()].
|
||||
t = time.Unix(int64(ts), 0)
|
||||
if t.After(time.Now().Add(expiration*-1)) && t.Before(time.Now().Add(time.Minute*5)) {
|
||||
// it's a valid cookie. now get the contents
|
||||
rawValue, err := base64.URLEncoding.DecodeString(parts[0])
|
||||
if err == nil {
|
||||
value = rawValue
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// SignedValue returns a cookie that is signed and can later be checked with Validate
|
||||
func SignedValue(seed string, key string, value []byte, now time.Time) string {
|
||||
encodedValue := base64.URLEncoding.EncodeToString(value)
|
||||
timeStr := fmt.Sprintf("%d", now.Unix())
|
||||
sig := cookieSignature(sha256.New, seed, key, encodedValue, timeStr)
|
||||
cookieVal := fmt.Sprintf("%s|%s|%s", encodedValue, timeStr, sig)
|
||||
return cookieVal
|
||||
}
|
||||
|
||||
func cookieSignature(signer func() hash.Hash, args ...string) string {
|
||||
h := hmac.New(signer, []byte(args[0]))
|
||||
for _, arg := range args[1:] {
|
||||
h.Write([]byte(arg))
|
||||
}
|
||||
var b []byte
|
||||
b = h.Sum(b)
|
||||
return base64.URLEncoding.EncodeToString(b)
|
||||
}
|
||||
|
||||
func checkSignature(signature string, args ...string) bool {
|
||||
checkSig := cookieSignature(sha256.New, args...)
|
||||
if checkHmac(signature, checkSig) {
|
||||
return true
|
||||
}
|
||||
|
||||
// TODO: After appropriate rollout window, remove support for SHA1
|
||||
legacySig := cookieSignature(sha1.New, args...)
|
||||
return checkHmac(signature, legacySig)
|
||||
}
|
||||
|
||||
func checkHmac(input, expected string) bool {
|
||||
inputMAC, err1 := base64.URLEncoding.DecodeString(input)
|
||||
if err1 == nil {
|
||||
expectedMAC, err2 := base64.URLEncoding.DecodeString(expected)
|
||||
if err2 == nil {
|
||||
return hmac.Equal(inputMAC, expectedMAC)
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
100
pkg/encryption/utils_test.go
Normal file
100
pkg/encryption/utils_test.go
Normal file
@ -0,0 +1,100 @@
|
||||
package encryption
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSecretBytesEncoded(t *testing.T) {
|
||||
for _, secretSize := range []int{16, 24, 32} {
|
||||
t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) {
|
||||
secret := make([]byte, secretSize)
|
||||
_, err := io.ReadFull(rand.Reader, secret)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// We test both padded & raw Base64 to ensure we handle both
|
||||
// potential user input routes for Base64
|
||||
base64Padded := base64.URLEncoding.EncodeToString(secret)
|
||||
sb := SecretBytes(base64Padded)
|
||||
assert.Equal(t, secret, sb)
|
||||
assert.Equal(t, len(sb), secretSize)
|
||||
|
||||
base64Raw := base64.RawURLEncoding.EncodeToString(secret)
|
||||
sb = SecretBytes(base64Raw)
|
||||
assert.Equal(t, secret, sb)
|
||||
assert.Equal(t, len(sb), secretSize)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// A string that isn't intended as Base64 and still decodes (but to unintended length)
|
||||
// will return the original secret as bytes
|
||||
func TestSecretBytesEncodedWrongSize(t *testing.T) {
|
||||
for _, secretSize := range []int{15, 20, 28, 33, 44} {
|
||||
t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) {
|
||||
secret := make([]byte, secretSize)
|
||||
_, err := io.ReadFull(rand.Reader, secret)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// We test both padded & raw Base64 to ensure we handle both
|
||||
// potential user input routes for Base64
|
||||
base64Padded := base64.URLEncoding.EncodeToString(secret)
|
||||
sb := SecretBytes(base64Padded)
|
||||
assert.NotEqual(t, secret, sb)
|
||||
assert.NotEqual(t, len(sb), secretSize)
|
||||
// The given secret is returned as []byte
|
||||
assert.Equal(t, base64Padded, string(sb))
|
||||
|
||||
base64Raw := base64.RawURLEncoding.EncodeToString(secret)
|
||||
sb = SecretBytes(base64Raw)
|
||||
assert.NotEqual(t, secret, sb)
|
||||
assert.NotEqual(t, len(sb), secretSize)
|
||||
// The given secret is returned as []byte
|
||||
assert.Equal(t, base64Raw, string(sb))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecretBytesNonBase64(t *testing.T) {
|
||||
trailer := "equals=========="
|
||||
assert.Equal(t, trailer, string(SecretBytes(trailer)))
|
||||
|
||||
raw16 := "asdflkjhqwer)(*&"
|
||||
sb16 := SecretBytes(raw16)
|
||||
assert.Equal(t, raw16, string(sb16))
|
||||
assert.Equal(t, 16, len(sb16))
|
||||
|
||||
raw24 := "asdflkjhqwer)(*&CJEN#$%^"
|
||||
sb24 := SecretBytes(raw24)
|
||||
assert.Equal(t, raw24, string(sb24))
|
||||
assert.Equal(t, 24, len(sb24))
|
||||
|
||||
raw32 := "asdflkjhqwer)(*&1234lkjhqwer)(*&"
|
||||
sb32 := SecretBytes(raw32)
|
||||
assert.Equal(t, raw32, string(sb32))
|
||||
assert.Equal(t, 32, len(sb32))
|
||||
}
|
||||
|
||||
func TestSignAndValidate(t *testing.T) {
|
||||
seed := "0123456789abcdef"
|
||||
key := "cookie-name"
|
||||
value := base64.URLEncoding.EncodeToString([]byte("I am soooo encoded"))
|
||||
epoch := "123456789"
|
||||
|
||||
sha256sig := cookieSignature(sha256.New, seed, key, value, epoch)
|
||||
sha1sig := cookieSignature(sha1.New, seed, key, value, epoch)
|
||||
|
||||
assert.True(t, checkSignature(sha256sig, seed, key, value, epoch))
|
||||
// This should be switched to False after fully deprecating SHA1
|
||||
assert.True(t, checkSignature(sha1sig, seed, key, value, epoch))
|
||||
|
||||
assert.False(t, checkSignature(sha256sig, seed, key, "tampered", epoch))
|
||||
assert.False(t, checkSignature(sha1sig, seed, key, "tampered", epoch))
|
||||
}
|
@ -28,7 +28,7 @@ var _ sessions.SessionStore = &SessionStore{}
|
||||
// interface that stores sessions in client side cookies
|
||||
type SessionStore struct {
|
||||
CookieOptions *options.CookieOptions
|
||||
CookieCipher *encryption.Cipher
|
||||
CookieCipher encryption.Cipher
|
||||
}
|
||||
|
||||
// Save takes a sessions.SessionState and stores the information from it
|
||||
@ -59,7 +59,7 @@ func (s *SessionStore) Load(req *http.Request) (*sessions.SessionState, error) {
|
||||
return nil, errors.New("cookie signature not valid")
|
||||
}
|
||||
|
||||
session, err := sessionFromCookie(val, s.CookieCipher)
|
||||
session, err := sessionFromCookie(string(val), s.CookieCipher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -84,12 +84,12 @@ func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error {
|
||||
}
|
||||
|
||||
// cookieForSession serializes a session state for storage in a cookie
|
||||
func cookieForSession(s *sessions.SessionState, c *encryption.Cipher) (string, error) {
|
||||
func cookieForSession(s *sessions.SessionState, c encryption.Cipher) (string, error) {
|
||||
return s.EncodeSessionState(c)
|
||||
}
|
||||
|
||||
// sessionFromCookie deserializes a session from a cookie value
|
||||
func sessionFromCookie(v string, c *encryption.Cipher) (s *sessions.SessionState, err error) {
|
||||
func sessionFromCookie(v string, c encryption.Cipher) (s *sessions.SessionState, err error) {
|
||||
return sessions.DecodeSessionState(v, c)
|
||||
}
|
||||
|
||||
@ -104,7 +104,7 @@ func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Reques
|
||||
// authentication details
|
||||
func (s *SessionStore) makeSessionCookie(req *http.Request, value string, now time.Time) []*http.Cookie {
|
||||
if value != "" {
|
||||
value = encryption.SignedValue(s.CookieOptions.Secret, s.CookieOptions.Name, value, now)
|
||||
value = encryption.SignedValue(s.CookieOptions.Secret, s.CookieOptions.Name, []byte(value), now)
|
||||
}
|
||||
c := s.makeCookie(req, s.CookieOptions.Name, value, s.CookieOptions.Expire, now)
|
||||
if len(c.Value) > 4096-len(s.CookieOptions.Name) {
|
||||
|
@ -32,7 +32,7 @@ type TicketData struct {
|
||||
// SessionStore is an implementation of the sessions.SessionStore
|
||||
// interface that stores sessions in redis
|
||||
type SessionStore struct {
|
||||
CookieCipher *encryption.Cipher
|
||||
CookieCipher encryption.Cipher
|
||||
CookieOptions *options.CookieOptions
|
||||
Client Client
|
||||
}
|
||||
@ -175,7 +175,7 @@ func (store *SessionStore) Load(req *http.Request) (*sessions.SessionState, erro
|
||||
return nil, fmt.Errorf("cookie signature not valid")
|
||||
}
|
||||
ctx := req.Context()
|
||||
session, err := store.loadSessionFromString(ctx, val)
|
||||
session, err := store.loadSessionFromString(ctx, string(val))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error loading session: %s", err)
|
||||
}
|
||||
@ -237,7 +237,7 @@ func (store *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) erro
|
||||
|
||||
// We only return an error if we had an issue with redis
|
||||
// If there's an issue decoding the ticket, ignore it
|
||||
ticket, _ := decodeTicket(store.CookieOptions.Name, val)
|
||||
ticket, _ := decodeTicket(store.CookieOptions.Name, string(val))
|
||||
if ticket != nil {
|
||||
ctx := req.Context()
|
||||
err := store.Client.Del(ctx, ticket.asHandle(store.CookieOptions.Name))
|
||||
@ -251,7 +251,7 @@ func (store *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) erro
|
||||
// makeCookie makes a cookie, signing the value if present
|
||||
func (store *SessionStore) makeCookie(req *http.Request, value string, expires time.Duration, now time.Time) *http.Cookie {
|
||||
if value != "" {
|
||||
value = encryption.SignedValue(store.CookieOptions.Secret, store.CookieOptions.Name, value, now)
|
||||
value = encryption.SignedValue(store.CookieOptions.Secret, store.CookieOptions.Name, []byte(value), now)
|
||||
}
|
||||
return cookies.MakeCookieFromOptions(
|
||||
req,
|
||||
@ -302,7 +302,7 @@ func (store *SessionStore) getTicket(requestCookie *http.Cookie) (*TicketData, e
|
||||
}
|
||||
|
||||
// Valid cookie, decode the ticket
|
||||
ticket, err := decodeTicket(store.CookieOptions.Name, val)
|
||||
ticket, err := decodeTicket(store.CookieOptions.Name, string(val))
|
||||
if err != nil {
|
||||
// If we can't decode the ticket we have to create a new one
|
||||
return newTicket()
|
||||
|
@ -170,7 +170,7 @@ var _ = Describe("NewSessionStore", func() {
|
||||
BeforeEach(func() {
|
||||
By("Using a valid cookie with a different providers session encoding")
|
||||
broken := "BrokenSessionFromADifferentSessionImplementation"
|
||||
value := encryption.SignedValue(cookieOpts.Secret, cookieOpts.Name, broken, time.Now())
|
||||
value := encryption.SignedValue(cookieOpts.Secret, cookieOpts.Name, []byte(broken), time.Now())
|
||||
cookie := cookiesapi.MakeCookieFromOptions(request, cookieOpts.Name, value, cookieOpts, cookieOpts.Expire, time.Now())
|
||||
request.AddCookie(cookie)
|
||||
|
||||
@ -367,7 +367,7 @@ var _ = Describe("NewSessionStore", func() {
|
||||
_, err := rand.Read(secret)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cookieOpts.Secret = base64.URLEncoding.EncodeToString(secret)
|
||||
cipher, err := encryption.NewCipher(encryption.SecretBytes(cookieOpts.Secret))
|
||||
cipher, err := encryption.NewBase64Cipher(encryption.NewCFBCipher, encryption.SecretBytes(cookieOpts.Secret))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cipher).ToNot(BeNil())
|
||||
opts.Cipher = cipher
|
||||
|
@ -38,7 +38,7 @@ func Validate(o *options.Options) error {
|
||||
|
||||
msgs := make([]string, 0)
|
||||
|
||||
var cipher *encryption.Cipher
|
||||
var cipher encryption.Cipher
|
||||
if o.Cookie.Secret == "" {
|
||||
msgs = append(msgs, "missing setting: cookie-secret")
|
||||
} else {
|
||||
@ -62,7 +62,7 @@ func Validate(o *options.Options) error {
|
||||
len(encryption.SecretBytes(o.Cookie.Secret)), suffix))
|
||||
} else {
|
||||
var err error
|
||||
cipher, err = encryption.NewCipher(encryption.SecretBytes(o.Cookie.Secret))
|
||||
cipher, err = encryption.NewBase64Cipher(encryption.NewCFBCipher, encryption.SecretBytes(o.Cookie.Secret))
|
||||
if err != nil {
|
||||
msgs = append(msgs, fmt.Sprintf("cookie-secret error: %v", err))
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user