You've already forked oauth2-proxy
							
							
				mirror of
				https://github.com/oauth2-proxy/oauth2-proxy.git
				synced 2025-10-30 23:47:52 +02:00 
			
		
		
		
	Add EncryptInto/DecryptInto Unit Tests
This commit is contained in:
		| @@ -55,6 +55,7 @@ | |||||||
|  |  | ||||||
| ## Changes since v5.1.1 | ## Changes since v5.1.1 | ||||||
|  |  | ||||||
|  | - [#539](https://github.com/oauth2-proxy/oauth2-proxy/pull/539) Refactor encryption ciphers and add AES-GCM support (@NickMeves) | ||||||
| - [#601](https://github.com/oauth2-proxy/oauth2-proxy/pull/601) Ensure decrypted user/email are valid UTF8 (@JoelSpeed) | - [#601](https://github.com/oauth2-proxy/oauth2-proxy/pull/601) Ensure decrypted user/email are valid UTF8 (@JoelSpeed) | ||||||
| - [#560](https://github.com/oauth2-proxy/oauth2-proxy/pull/560) Fallback to UserInfo is User ID claim not present (@JoelSpeed) | - [#560](https://github.com/oauth2-proxy/oauth2-proxy/pull/560) Fallback to UserInfo is User ID claim not present (@JoelSpeed) | ||||||
| - [#598](https://github.com/oauth2-proxy/oauth2-proxy/pull/598) acr_values no longer sent to IdP when empty (@ScottGuymer) | - [#598](https://github.com/oauth2-proxy/oauth2-proxy/pull/598) acr_values no longer sent to IdP when empty (@ScottGuymer) | ||||||
|   | |||||||
| @@ -104,24 +104,18 @@ func DecodeSessionState(v string, c encryption.Cipher) (*SessionState, error) { | |||||||
| 			PreferredUsername: ss.PreferredUsername, | 			PreferredUsername: ss.PreferredUsername, | ||||||
| 		} | 		} | ||||||
| 	} else { | 	} else { | ||||||
| 		// Backward compatibility with using unencrypted Email | 		// Backward compatibility with using unencrypted Email or User | ||||||
| 		if ss.Email != "" { | 		// Decryption errors will leave original string | ||||||
| 			decryptedEmail, errEmail := stringDecrypt(ss.Email, c) | 		err = c.DecryptInto(&ss.Email) | ||||||
| 			if errEmail == nil { | 		if err == nil { | ||||||
| 				if !utf8.ValidString(decryptedEmail) { | 			if !utf8.ValidString(ss.Email) { | ||||||
| 					return nil, errors.New("invalid value for decrypted email") | 				return nil, errors.New("invalid value for decrypted email") | ||||||
| 				} |  | ||||||
| 				ss.Email = decryptedEmail |  | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 		// Backward compatibility with using unencrypted User | 		err = c.DecryptInto(&ss.User) | ||||||
| 		if ss.User != "" { | 		if err == nil { | ||||||
| 			decryptedUser, errUser := stringDecrypt(ss.User, c) | 			if !utf8.ValidString(ss.User) { | ||||||
| 			if errUser == nil { | 				return nil, errors.New("invalid value for decrypted user") | ||||||
| 				if !utf8.ValidString(decryptedUser) { |  | ||||||
| 					return nil, errors.New("invalid value for decrypted user") |  | ||||||
| 				} |  | ||||||
| 				ss.User = decryptedUser |  | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| @@ -139,12 +133,3 @@ func DecodeSessionState(v string, c encryption.Cipher) (*SessionState, error) { | |||||||
| 	} | 	} | ||||||
| 	return &ss, nil | 	return &ss, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // stringDecrypt wraps a Base64Cipher to make it string => string |  | ||||||
| func stringDecrypt(ciphertext string, c encryption.Cipher) (string, error) { |  | ||||||
| 	value, err := c.Decrypt([]byte(ciphertext)) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return "", err |  | ||||||
| 	} |  | ||||||
| 	return string(value), nil |  | ||||||
| } |  | ||||||
|   | |||||||
| @@ -17,14 +17,14 @@ func timePtr(t time.Time) *time.Time { | |||||||
| 	return &t | 	return &t | ||||||
| } | } | ||||||
|  |  | ||||||
| func NewCipher(secret []byte) (encryption.Cipher, error) { | func newTestCipher(secret []byte) (encryption.Cipher, error) { | ||||||
| 	return encryption.NewBase64Cipher(encryption.NewCFBCipher, secret) | 	return encryption.NewBase64Cipher(encryption.NewCFBCipher, secret) | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestSessionStateSerialization(t *testing.T) { | func TestSessionStateSerialization(t *testing.T) { | ||||||
| 	c, err := NewCipher([]byte(secret)) | 	c, err := newTestCipher([]byte(secret)) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	c2, err := NewCipher([]byte(altSecret)) | 	c2, err := newTestCipher([]byte(altSecret)) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	s := &sessions.SessionState{ | 	s := &sessions.SessionState{ | ||||||
| 		Email:             "user@domain.com", | 		Email:             "user@domain.com", | ||||||
| @@ -57,9 +57,9 @@ func TestSessionStateSerialization(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestSessionStateSerializationWithUser(t *testing.T) { | func TestSessionStateSerializationWithUser(t *testing.T) { | ||||||
| 	c, err := NewCipher([]byte(secret)) | 	c, err := newTestCipher([]byte(secret)) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	c2, err := NewCipher([]byte(altSecret)) | 	c2, err := newTestCipher([]byte(altSecret)) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	s := &sessions.SessionState{ | 	s := &sessions.SessionState{ | ||||||
| 		User:              "just-user", | 		User:              "just-user", | ||||||
| @@ -205,7 +205,7 @@ func TestDecodeSessionState(t *testing.T) { | |||||||
| 	eJSON, _ := e.MarshalJSON() | 	eJSON, _ := e.MarshalJSON() | ||||||
| 	eString := string(eJSON) | 	eString := string(eJSON) | ||||||
|  |  | ||||||
| 	c, err := NewCipher([]byte(secret)) | 	c, err := newTestCipher([]byte(secret)) | ||||||
| 	assert.NoError(t, err) | 	assert.NoError(t, err) | ||||||
|  |  | ||||||
| 	testCases := []testCase{ | 	testCases := []testCase{ | ||||||
|   | |||||||
| @@ -13,30 +13,11 @@ import ( | |||||||
| type Cipher interface { | type Cipher interface { | ||||||
| 	Encrypt(value []byte) ([]byte, error) | 	Encrypt(value []byte) ([]byte, error) | ||||||
| 	Decrypt(ciphertext []byte) ([]byte, error) | 	Decrypt(ciphertext []byte) ([]byte, error) | ||||||
|     EncryptInto(s *string) error | 	EncryptInto(s *string) error | ||||||
| 	DecryptInto(s *string) error | 	DecryptInto(s *string) error | ||||||
| } | } | ||||||
|  |  | ||||||
| type DefaultCipher struct {} |  | ||||||
|  |  | ||||||
| // Encrypt is a dummy method for CommonCipher.EncryptInto support |  | ||||||
| func (c *DefaultCipher) Encrypt(value []byte) ([]byte, error) { return value, nil } |  | ||||||
|  |  | ||||||
| // Decrypt is a dummy method for CommonCipher.DecryptInto support |  | ||||||
| func (c *DefaultCipher) Decrypt(ciphertext []byte) ([]byte, error) { return ciphertext, nil } |  | ||||||
|  |  | ||||||
| // EncryptInto encrypts the value and stores it back in the string pointer |  | ||||||
| func (c *DefaultCipher) EncryptInto(s *string) error { |  | ||||||
| 	return into(c.Encrypt, s) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // DecryptInto decrypts the value and stores it back in the string pointer |  | ||||||
| func (c *DefaultCipher) DecryptInto(s *string) error { |  | ||||||
| 	return into(c.Decrypt, s) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type Base64Cipher struct { | type Base64Cipher struct { | ||||||
| 	DefaultCipher |  | ||||||
| 	Cipher Cipher | 	Cipher Cipher | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -52,7 +33,7 @@ func NewBase64Cipher(initCipher func([]byte) (Cipher, error), secret []byte) (Ci | |||||||
|  |  | ||||||
| // Encrypt encrypts a value with the embedded Cipher & Base64 encodes it | // Encrypt encrypts a value with the embedded Cipher & Base64 encodes it | ||||||
| func (c *Base64Cipher) Encrypt(value []byte) ([]byte, error) { | func (c *Base64Cipher) Encrypt(value []byte) ([]byte, error) { | ||||||
| 	encrypted, err := c.Cipher.Encrypt([]byte(value)) | 	encrypted, err := c.Cipher.Encrypt(value) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| @@ -70,8 +51,17 @@ func (c *Base64Cipher) Decrypt(ciphertext []byte) ([]byte, error) { | |||||||
| 	return c.Cipher.Decrypt(encrypted) | 	return c.Cipher.Decrypt(encrypted) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // EncryptInto encrypts the value and stores it back in the string pointer | ||||||
|  | func (c *Base64Cipher) EncryptInto(s *string) error { | ||||||
|  | 	return into(c.Encrypt, s) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // DecryptInto decrypts the value and stores it back in the string pointer | ||||||
|  | func (c *Base64Cipher) DecryptInto(s *string) error { | ||||||
|  | 	return into(c.Decrypt, s) | ||||||
|  | } | ||||||
|  |  | ||||||
| type CFBCipher struct { | type CFBCipher struct { | ||||||
| 	DefaultCipher |  | ||||||
| 	cipher.Block | 	cipher.Block | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -111,8 +101,17 @@ func (c *CFBCipher) Decrypt(ciphertext []byte) ([]byte, error) { | |||||||
| 	return plaintext, nil | 	return plaintext, nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // EncryptInto returns an error since the encrypted data is a []byte that isn't string cast-able | ||||||
|  | func (c *CFBCipher) EncryptInto(s *string) error { | ||||||
|  | 	return fmt.Errorf("CFBCipher is not a string->string compatible cipher") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // EncryptInto returns an error since the encrypted data needs to be a []byte | ||||||
|  | func (c *CFBCipher) DecryptInto(s *string) error { | ||||||
|  | 	return fmt.Errorf("CFBCipher is not a string->string compatible cipher") | ||||||
|  | } | ||||||
|  |  | ||||||
| type GCMCipher struct { | type GCMCipher struct { | ||||||
| 	DefaultCipher |  | ||||||
| 	cipher.Block | 	cipher.Block | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -158,6 +157,16 @@ func (c *GCMCipher) Decrypt(ciphertext []byte) ([]byte, error) { | |||||||
| 	return plaintext, nil | 	return plaintext, nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // EncryptInto returns an error since the encrypted data is a []byte that isn't string cast-able | ||||||
|  | func (c *GCMCipher) EncryptInto(s *string) error { | ||||||
|  | 	return fmt.Errorf("CFBCipher is not a string->string compatible cipher") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // EncryptInto returns an error since the encrypted data needs to be a []byte | ||||||
|  | func (c *GCMCipher) DecryptInto(s *string) error { | ||||||
|  | 	return fmt.Errorf("CFBCipher is not a string->string compatible cipher") | ||||||
|  | } | ||||||
|  |  | ||||||
| // codecFunc is a function that takes a string and encodes/decodes it | // codecFunc is a function that takes a string and encodes/decodes it | ||||||
| type codecFunc func([]byte) ([]byte, error) | type codecFunc func([]byte) ([]byte, error) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -5,6 +5,7 @@ import ( | |||||||
| 	"encoding/base64" | 	"encoding/base64" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io" | 	"io" | ||||||
|  | 	mathrand "math/rand" | ||||||
| 	"testing" | 	"testing" | ||||||
|  |  | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| @@ -117,6 +118,80 @@ func runEncryptAndDecrypt(t *testing.T, c Cipher, dataSize int) { | |||||||
| 	assert.NotEqual(t, encrypted, decrypted) | 	assert.NotEqual(t, encrypted, decrypted) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestEncryptIntoAndDecryptInto(t *testing.T) { | ||||||
|  | 	// Test our 2 cipher types | ||||||
|  | 	cipherInits := map[string]func([]byte) (Cipher, error){ | ||||||
|  | 		"CFB": NewCFBCipher, | ||||||
|  | 		"GCM": NewGCMCipher, | ||||||
|  | 	} | ||||||
|  | 	for name, initCipher := range cipherInits { | ||||||
|  | 		t.Run(name, func(t *testing.T) { | ||||||
|  | 			// Test all 3 valid AES sizes | ||||||
|  | 			for _, secretSize := range []int{16, 24, 32} { | ||||||
|  | 				t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) { | ||||||
|  | 					secret := make([]byte, secretSize) | ||||||
|  | 					_, err := io.ReadFull(rand.Reader, secret) | ||||||
|  | 					assert.Equal(t, nil, err) | ||||||
|  |  | ||||||
|  | 					// Test Standard & Base64 wrapped | ||||||
|  | 					cstd, err := initCipher(secret) | ||||||
|  | 					assert.Equal(t, nil, err) | ||||||
|  |  | ||||||
|  | 					cb64, err := NewBase64Cipher(initCipher, secret) | ||||||
|  | 					assert.Equal(t, nil, err) | ||||||
|  |  | ||||||
|  | 					ciphers := map[string]Cipher{ | ||||||
|  | 						"Standard": cstd, | ||||||
|  | 						"Base64":   cb64, | ||||||
|  | 					} | ||||||
|  |  | ||||||
|  | 					for cName, c := range ciphers { | ||||||
|  | 						// Check no errors with empty or nil strings | ||||||
|  | 						if cName == "Base64" { | ||||||
|  | 							empty := "" | ||||||
|  | 							assert.Equal(t, nil, c.EncryptInto(&empty)) | ||||||
|  | 							assert.Equal(t, nil, c.DecryptInto(&empty)) | ||||||
|  | 							assert.Equal(t, nil, c.EncryptInto(nil)) | ||||||
|  | 							assert.Equal(t, nil, c.DecryptInto(nil)) | ||||||
|  | 						} | ||||||
|  |  | ||||||
|  | 						t.Run(cName, func(t *testing.T) { | ||||||
|  | 							// Test various sizes sessions might be | ||||||
|  | 							for _, dataSize := range []int{10, 100, 1000, 5000, 10000} { | ||||||
|  | 								t.Run(fmt.Sprintf("%d", dataSize), func(t *testing.T) { | ||||||
|  | 									runEncryptIntoAndDecryptInto(t, c, cName, dataSize) | ||||||
|  | 								}) | ||||||
|  | 							} | ||||||
|  | 						}) | ||||||
|  | 					} | ||||||
|  | 				}) | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func runEncryptIntoAndDecryptInto(t *testing.T, c Cipher, cipherType string, dataSize int) { | ||||||
|  | 	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" | ||||||
|  | 	b := make([]byte, dataSize) | ||||||
|  | 	for i := range b { | ||||||
|  | 		b[i] = charset[mathrand.Intn(len(charset))] | ||||||
|  | 	} | ||||||
|  | 	data := string(b) | ||||||
|  | 	originalData := data | ||||||
|  |  | ||||||
|  | 	// Base64 is the only cipher that supports string->string Encrypt/Decrypt Into methods | ||||||
|  | 	if cipherType == "Base64" { | ||||||
|  | 		assert.Equal(t, nil, c.EncryptInto(&data)) | ||||||
|  | 		assert.NotEqual(t, originalData, data) | ||||||
|  |  | ||||||
|  | 		assert.Equal(t, nil, c.DecryptInto(&data)) | ||||||
|  | 		assert.Equal(t, originalData, data) | ||||||
|  | 	} else { | ||||||
|  | 		assert.NotEqual(t, nil, c.EncryptInto(&data)) | ||||||
|  | 		assert.NotEqual(t, nil, c.DecryptInto(&data)) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestDecryptCFBWrongSecret(t *testing.T) { | func TestDecryptCFBWrongSecret(t *testing.T) { | ||||||
| 	secret1 := []byte("0123456789abcdefghijklmnopqrstuv") | 	secret1 := []byte("0123456789abcdefghijklmnopqrstuv") | ||||||
| 	secret2 := []byte("9876543210abcdefghijklmnopqrstuv") | 	secret2 := []byte("9876543210abcdefghijklmnopqrstuv") | ||||||
| @@ -228,25 +303,3 @@ func TestCFBtoGCMErrors(t *testing.T) { | |||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestEncodeIntoAndDecodeIntoAccessToken(t *testing.T) { |  | ||||||
| 	const secret = "0123456789abcdefghijklmnopqrstuv" |  | ||||||
| 	c, err := NewCipher([]byte(secret)) |  | ||||||
| 	assert.Equal(t, nil, err) |  | ||||||
|  |  | ||||||
| 	token := "my access token" |  | ||||||
| 	originalToken := token |  | ||||||
|  |  | ||||||
| 	assert.Equal(t, nil, c.EncryptInto(&token)) |  | ||||||
| 	assert.NotEqual(t, originalToken, token) |  | ||||||
|  |  | ||||||
| 	assert.Equal(t, nil, c.DecryptInto(&token)) |  | ||||||
| 	assert.Equal(t, originalToken, token) |  | ||||||
|  |  | ||||||
| 	// Check no errors with empty or nil strings |  | ||||||
| 	empty := "" |  | ||||||
| 	assert.Equal(t, nil, c.EncryptInto(&empty)) |  | ||||||
| 	assert.Equal(t, nil, c.DecryptInto(&empty)) |  | ||||||
| 	assert.Equal(t, nil, c.EncryptInto(nil)) |  | ||||||
| 	assert.Equal(t, nil, c.DecryptInto(nil)) |  | ||||||
| } |  | ||||||
|   | |||||||
| @@ -104,4 +104,3 @@ func checkHmac(input, expected string) bool { | |||||||
| 	} | 	} | ||||||
| 	return false | 	return false | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user