You've already forked oauth2-proxy
							
							
				mirror of
				https://github.com/oauth2-proxy/oauth2-proxy.git
				synced 2025-10-30 23:47:52 +02:00 
			
		
		
		
	Refactor encryption.Cipher to be an Encrypt/Decrypt Interface
All Encrypt/Decrypt Cipher implementations will now take and return []byte to set up usage in future binary compatible encoding schemes to fix issues with bloat encrypting to strings (which requires base64ing adding 33% size)
This commit is contained in:
		| @@ -5,7 +5,7 @@ import "github.com/oauth2-proxy/oauth2-proxy/pkg/encryption" | ||||
| // SessionOptions contains configuration options for the SessionStore providers. | ||||
| type SessionOptions struct { | ||||
| 	Type   string             `flag:"session-store-type" cfg:"session_store_type"` | ||||
| 	Cipher *encryption.Cipher `cfg:",internal"` | ||||
| 	Cipher encryption.Cipher `cfg:",internal"` | ||||
| 	Redis  RedisStoreOptions  `cfg:",squash"` | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -60,7 +60,7 @@ func (s *SessionState) String() string { | ||||
| } | ||||
|  | ||||
| // EncodeSessionState returns string representation of the current session | ||||
| func (s *SessionState) EncodeSessionState(c *encryption.Cipher) (string, error) { | ||||
| func (s *SessionState) EncodeSessionState(c encryption.Cipher) (string, error) { | ||||
| 	var ss SessionState | ||||
| 	if c == nil { | ||||
| 		// Store only Email and User when cipher is unavailable | ||||
| @@ -89,7 +89,7 @@ func (s *SessionState) EncodeSessionState(c *encryption.Cipher) (string, error) | ||||
| } | ||||
|  | ||||
| // DecodeSessionState decodes the session cookie string into a SessionState | ||||
| func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) { | ||||
| func DecodeSessionState(v string, c encryption.Cipher) (*SessionState, error) { | ||||
| 	var ss SessionState | ||||
| 	err := json.Unmarshal([]byte(v), &ss) | ||||
| 	if err != nil { | ||||
| @@ -106,7 +106,7 @@ func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) { | ||||
| 	} else { | ||||
| 		// Backward compatibility with using unencrypted Email | ||||
| 		if ss.Email != "" { | ||||
| 			decryptedEmail, errEmail := c.Decrypt(ss.Email) | ||||
| 			decryptedEmail, errEmail := stringDecrypt(ss.Email, c) | ||||
| 			if errEmail == nil { | ||||
| 				if !utf8.ValidString(decryptedEmail) { | ||||
| 					return nil, errors.New("invalid value for decrypted email") | ||||
| @@ -116,7 +116,7 @@ func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) { | ||||
| 		} | ||||
| 		// Backward compatibility with using unencrypted User | ||||
| 		if ss.User != "" { | ||||
| 			decryptedUser, errUser := c.Decrypt(ss.User) | ||||
| 			decryptedUser, errUser := stringDecrypt(ss.User, c) | ||||
| 			if errUser == nil { | ||||
| 				if !utf8.ValidString(decryptedUser) { | ||||
| 					return nil, errors.New("invalid value for decrypted user") | ||||
| @@ -139,3 +139,12 @@ func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) { | ||||
| 	} | ||||
| 	return &ss, nil | ||||
| } | ||||
|  | ||||
| // stringDecrypt wraps a Base64Cipher to make it string => string | ||||
| func stringDecrypt(ciphertext string, c encryption.Cipher) (string, error) { | ||||
| 	value, err := c.Decrypt([]byte(ciphertext)) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	return string(value), nil | ||||
| } | ||||
|   | ||||
| @@ -145,7 +145,7 @@ func TestExpired(t *testing.T) { | ||||
| type testCase struct { | ||||
| 	sessions.SessionState | ||||
| 	Encoded string | ||||
| 	Cipher  *encryption.Cipher | ||||
| 	Cipher  encryption.Cipher | ||||
| 	Error   bool | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -109,47 +109,79 @@ func checkHmac(input, expected string) bool { | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| // Cipher provides methods to encrypt and decrypt cookie values | ||||
| type Cipher struct { | ||||
| 	cipher.Block | ||||
| // Cipher provides methods to encrypt and decrypt | ||||
| type Cipher interface { | ||||
| 	Encrypt(value []byte) ([]byte, error) | ||||
| 	Decrypt(ciphertext []byte) ([]byte, error) | ||||
|     EncryptInto(s *string) error | ||||
| 	DecryptInto(s *string) error | ||||
| } | ||||
|  | ||||
| // NewCipher returns a new aes Cipher for encrypting cookie values | ||||
| func NewCipher(secret []byte) (*Cipher, error) { | ||||
| // This defaults to the Base64 Cipher to align with legacy Encrypt/Decrypt functionality | ||||
| func NewCipher(secret []byte) (*Base64Cipher, error) { | ||||
| 	cfb, err := NewCFBCipher(secret) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return NewBase64Cipher(cfb) | ||||
| } | ||||
|  | ||||
| type Base64Cipher struct { | ||||
| 	Cipher Cipher | ||||
| } | ||||
|  | ||||
| // NewBase64Cipher returns a new AES CFB Cipher for encrypting cookie values | ||||
| // And wrapping them in Base64 -- Supports Legacy encryption scheme | ||||
| func NewBase64Cipher(c Cipher) (*Base64Cipher, error) { | ||||
| 	return &Base64Cipher{Cipher: c}, nil | ||||
| } | ||||
|  | ||||
| // Encrypt encrypts a value with AES CFB & base64 encodes it | ||||
| func (c *Base64Cipher) Encrypt(value []byte) ([]byte, error) { | ||||
| 	encrypted, err := c.Cipher.Encrypt([]byte(value)) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return []byte(base64.StdEncoding.EncodeToString(encrypted)), nil | ||||
| } | ||||
|  | ||||
| // Decrypt Base64 decodes a value & decrypts it with AES CFB | ||||
| func (c *Base64Cipher) Decrypt(ciphertext []byte) ([]byte, error) { | ||||
| 	encrypted, err := base64.StdEncoding.DecodeString(string(ciphertext)) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to decrypt cookie value %s", err) | ||||
| 	} | ||||
|  | ||||
| 	return c.Cipher.Decrypt(encrypted) | ||||
| } | ||||
|  | ||||
| // EncryptInto encrypts the value and stores it back in the string pointer | ||||
| func (c *Base64Cipher) EncryptInto(s *string) error { | ||||
| 	return into(c.Encrypt, s) | ||||
| } | ||||
|  | ||||
| // DecryptInto decrypts the value and stores it back in the string pointer | ||||
| func (c *Base64Cipher) DecryptInto(s *string) error { | ||||
| 	return into(c.Decrypt, s) | ||||
| } | ||||
|  | ||||
| type CFBCipher struct { | ||||
| 	cipher.Block | ||||
| } | ||||
|  | ||||
| // NewCFBCipher returns a new AES CFB Cipher | ||||
| func NewCFBCipher(secret []byte) (*CFBCipher, error) { | ||||
| 	c, err := aes.NewCipher(secret) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &Cipher{Block: c}, err | ||||
| 	return &CFBCipher{Block: c}, err | ||||
| } | ||||
|  | ||||
| // Encrypt a value for use in a cookie | ||||
| func (c *Cipher) Encrypt(value string) (string, error) { | ||||
| 	encrypted, err := c.EncryptCFB([]byte(value)) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
|  | ||||
| 	return base64.StdEncoding.EncodeToString(encrypted), nil | ||||
| } | ||||
|  | ||||
| // Decrypt a value from a cookie to it's original string | ||||
| func (c *Cipher) Decrypt(s string) (string, error) { | ||||
| 	encrypted, err := base64.StdEncoding.DecodeString(s) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("failed to decrypt cookie value %s", err) | ||||
| 	} | ||||
|  | ||||
| 	decrypted, err := c.DecryptCFB(encrypted) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
|  | ||||
| 	return string(decrypted), nil | ||||
| } | ||||
|  | ||||
| // Encrypt with AES CFB on raw bytes | ||||
| func (c *Cipher) EncryptCFB(value []byte) ([]byte, error) { | ||||
| // Encrypt with AES CFB | ||||
| func (c *CFBCipher) Encrypt(value []byte) ([]byte, error) { | ||||
| 	ciphertext := make([]byte, aes.BlockSize+len(value)) | ||||
| 	iv := ciphertext[:aes.BlockSize] | ||||
| 	if _, err := io.ReadFull(rand.Reader, iv); err != nil { | ||||
| @@ -162,7 +194,7 @@ func (c *Cipher) EncryptCFB(value []byte) ([]byte, error) { | ||||
| } | ||||
|  | ||||
| // Decrypt a AES CFB ciphertext | ||||
| func (c *Cipher) DecryptCFB(ciphertext []byte) ([]byte, error) { | ||||
| func (c *CFBCipher) Decrypt(ciphertext []byte) ([]byte, error) { | ||||
| 	if len(ciphertext) < aes.BlockSize { | ||||
| 		return nil, fmt.Errorf("encrypted value should be "+ | ||||
| 			"at least %d bytes, but is only %d bytes", | ||||
| @@ -178,17 +210,18 @@ func (c *Cipher) DecryptCFB(ciphertext []byte) ([]byte, error) { | ||||
| } | ||||
|  | ||||
| // EncryptInto encrypts the value and stores it back in the string pointer | ||||
| func (c *Cipher) EncryptInto(s *string) error { | ||||
| func (c *CFBCipher) EncryptInto(s *string) error { | ||||
| 	return into(c.Encrypt, s) | ||||
| } | ||||
|  | ||||
| // DecryptInto decrypts the value and stores it back in the string pointer | ||||
| func (c *Cipher) DecryptInto(s *string) error { | ||||
| func (c *CFBCipher) DecryptInto(s *string) error { | ||||
| 	return into(c.Decrypt, s) | ||||
| } | ||||
|  | ||||
| // codecFunc is a function that takes a string and encodes/decodes it | ||||
| type codecFunc func(string) (string, error) | ||||
| type codecFunc func([]byte) ([]byte, error) | ||||
|  | ||||
|  | ||||
| func into(f codecFunc, s *string) error { | ||||
| 	// Do not encrypt/decrypt nil or empty strings | ||||
| @@ -196,10 +229,10 @@ func into(f codecFunc, s *string) error { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	d, err := f(*s) | ||||
| 	d, err := f([]byte(*s)) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	*s = d | ||||
| 	*s = string(d) | ||||
| 	return nil | ||||
| } | ||||
|   | ||||
| @@ -105,14 +105,14 @@ func TestEncodeAndDecodeAccessToken(t *testing.T) { | ||||
| 	c, err := NewCipher([]byte(secret)) | ||||
| 	assert.Equal(t, nil, err) | ||||
|  | ||||
| 	encoded, err := c.Encrypt(token) | ||||
| 	encoded, err := c.Encrypt([]byte(token)) | ||||
| 	assert.Equal(t, nil, err) | ||||
|  | ||||
| 	decoded, err := c.Decrypt(encoded) | ||||
| 	assert.Equal(t, nil, err) | ||||
|  | ||||
| 	assert.NotEqual(t, token, encoded) | ||||
| 	assert.Equal(t, token, decoded) | ||||
| 	assert.NotEqual(t, []byte(token), encoded) | ||||
| 	assert.Equal(t, []byte(token), decoded) | ||||
| } | ||||
|  | ||||
| func TestEncodeAndDecodeAccessTokenB64(t *testing.T) { | ||||
| @@ -124,14 +124,115 @@ func TestEncodeAndDecodeAccessTokenB64(t *testing.T) { | ||||
| 	c, err := NewCipher([]byte(secret)) | ||||
| 	assert.Equal(t, nil, err) | ||||
|  | ||||
| 	encoded, err := c.Encrypt(token) | ||||
| 	encoded, err := c.Encrypt([]byte(token)) | ||||
| 	assert.Equal(t, nil, err) | ||||
|  | ||||
| 	decoded, err := c.Decrypt(encoded) | ||||
| 	assert.Equal(t, nil, err) | ||||
|  | ||||
| 	assert.NotEqual(t, token, encoded) | ||||
| 	assert.Equal(t, token, decoded) | ||||
| 	assert.NotEqual(t, []byte(token), encoded) | ||||
| 	assert.Equal(t, []byte(token), decoded) | ||||
| } | ||||
|  | ||||
| func TestEncryptAndDecryptBase64(t *testing.T) { | ||||
| 	var err error | ||||
|  | ||||
| 	// Test all 3 valid AES sizes | ||||
| 	for _, secretSize := range []int{16, 24, 32} { | ||||
| 		secret := make([]byte, secretSize) | ||||
| 		_, err = io.ReadFull(rand.Reader, secret) | ||||
| 		assert.Equal(t, nil, err) | ||||
|  | ||||
| 		// NewCipher creates a Base64 wrapper of CFBCipher | ||||
| 		c, err := NewCipher(secret) | ||||
| 		assert.Equal(t, nil, err) | ||||
|  | ||||
| 		// Test various sizes sessions might be | ||||
| 		for _, dataSize := range []int{10, 100, 1000, 5000, 10000} { | ||||
| 			data := make([]byte, dataSize) | ||||
| 			_, err := io.ReadFull(rand.Reader, data) | ||||
| 			assert.Equal(t, nil, err) | ||||
|  | ||||
| 			encrypted, err := c.Encrypt(data) | ||||
| 			assert.Equal(t, nil, err) | ||||
|  | ||||
| 			decrypted, err := c.Decrypt(encrypted) | ||||
| 			assert.Equal(t, nil, err) | ||||
| 			assert.Equal(t, data, decrypted) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestDecryptBase64WrongSecret(t *testing.T) { | ||||
| 	var err error | ||||
|  | ||||
| 	secret1 := []byte("0123456789abcdefghijklmnopqrstuv") | ||||
| 	secret2 := []byte("9876543210abcdefghijklmnopqrstuv") | ||||
|  | ||||
| 	c1, err := NewCipher(secret1) | ||||
| 	assert.Equal(t, nil, err) | ||||
|  | ||||
| 	c2, err := NewCipher(secret2) | ||||
| 	assert.Equal(t, nil, err) | ||||
|  | ||||
| 	data := []byte("f3928pufm982374dj02y485dsl34890u2t9nd4028s94dm58y2394087dhmsyt29h8df") | ||||
|  | ||||
| 	ciphertext, err := c1.Encrypt(data) | ||||
| 	assert.Equal(t, nil, err) | ||||
|  | ||||
| 	wrongData, err := c2.Decrypt(ciphertext) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.NotEqual(t, data, wrongData) | ||||
| } | ||||
|  | ||||
| func TestEncryptAndDecryptCFB(t *testing.T) { | ||||
| 	var err error | ||||
|  | ||||
| 	// Test all 3 valid AES sizes | ||||
| 	for _, secretSize := range []int{16, 24, 32} { | ||||
| 		secret := make([]byte, secretSize) | ||||
| 		_, err = io.ReadFull(rand.Reader, secret) | ||||
| 		assert.Equal(t, nil, err) | ||||
|  | ||||
| 		c, err := NewCFBCipher(secret) | ||||
| 		assert.Equal(t, nil, err) | ||||
|  | ||||
| 		// Test various sizes sessions might be | ||||
| 		for _, dataSize := range []int{10, 100, 1000, 5000, 10000} { | ||||
| 			data := make([]byte, dataSize) | ||||
| 			_, err := io.ReadFull(rand.Reader, data) | ||||
| 			assert.Equal(t, nil, err) | ||||
|  | ||||
| 			encrypted, err := c.Encrypt(data) | ||||
| 			assert.Equal(t, nil, err) | ||||
|  | ||||
| 			decrypted, err := c.Decrypt(encrypted) | ||||
| 			assert.Equal(t, nil, err) | ||||
| 			assert.Equal(t, data, decrypted) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestDecryptCFBWrongSecret(t *testing.T) { | ||||
| 	var err error | ||||
|  | ||||
| 	secret1 := []byte("0123456789abcdefghijklmnopqrstuv") | ||||
| 	secret2 := []byte("9876543210abcdefghijklmnopqrstuv") | ||||
|  | ||||
| 	c1, err := NewCFBCipher(secret1) | ||||
| 	assert.Equal(t, nil, err) | ||||
|  | ||||
| 	c2, err := NewCFBCipher(secret2) | ||||
| 	assert.Equal(t, nil, err) | ||||
|  | ||||
| 	data := []byte("f3928pufm982374dj02y485dsl34890u2t9nd4028s94dm58y2394087dhmsyt29h8df") | ||||
|  | ||||
| 	ciphertext, err := c1.Encrypt(data) | ||||
| 	assert.Equal(t, nil, err) | ||||
|  | ||||
| 	wrongData, err := c2.Decrypt(ciphertext) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.NotEqual(t, data, wrongData) | ||||
| } | ||||
|  | ||||
| func TestEncodeIntoAndDecodeIntoAccessToken(t *testing.T) { | ||||
|   | ||||
| @@ -28,7 +28,7 @@ var _ sessions.SessionStore = &SessionStore{} | ||||
| // interface that stores sessions in client side cookies | ||||
| type SessionStore struct { | ||||
| 	CookieOptions *options.CookieOptions | ||||
| 	CookieCipher  *encryption.Cipher | ||||
| 	CookieCipher  encryption.Cipher | ||||
| } | ||||
|  | ||||
| // Save takes a sessions.SessionState and stores the information from it | ||||
| @@ -84,12 +84,12 @@ func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error { | ||||
| } | ||||
|  | ||||
| // cookieForSession serializes a session state for storage in a cookie | ||||
| func cookieForSession(s *sessions.SessionState, c *encryption.Cipher) (string, error) { | ||||
| func cookieForSession(s *sessions.SessionState, c encryption.Cipher) (string, error) { | ||||
| 	return s.EncodeSessionState(c) | ||||
| } | ||||
|  | ||||
| // sessionFromCookie deserializes a session from a cookie value | ||||
| func sessionFromCookie(v string, c *encryption.Cipher) (s *sessions.SessionState, err error) { | ||||
| func sessionFromCookie(v string, c encryption.Cipher) (s *sessions.SessionState, err error) { | ||||
| 	return sessions.DecodeSessionState(v, c) | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -32,7 +32,7 @@ type TicketData struct { | ||||
| // SessionStore is an implementation of the sessions.SessionStore | ||||
| // interface that stores sessions in redis | ||||
| type SessionStore struct { | ||||
| 	CookieCipher  *encryption.Cipher | ||||
| 	CookieCipher  encryption.Cipher | ||||
| 	CookieOptions *options.CookieOptions | ||||
| 	Client        Client | ||||
| } | ||||
|   | ||||
| @@ -38,7 +38,7 @@ func Validate(o *options.Options) error { | ||||
|  | ||||
| 	msgs := make([]string, 0) | ||||
|  | ||||
| 	var cipher *encryption.Cipher | ||||
| 	var cipher encryption.Cipher | ||||
| 	if o.Cookie.Secret == "" { | ||||
| 		msgs = append(msgs, "missing setting: cookie-secret") | ||||
| 	} else { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user