You've already forked oauth2-proxy
							
							
				mirror of
				https://github.com/oauth2-proxy/oauth2-proxy.git
				synced 2025-10-30 23:47:52 +02:00 
			
		
		
		
	Add EncryptInto/DecryptInto Unit Tests
This commit is contained in:
		| @@ -55,6 +55,7 @@ | ||||
|  | ||||
| ## Changes since v5.1.1 | ||||
|  | ||||
| - [#539](https://github.com/oauth2-proxy/oauth2-proxy/pull/539) Refactor encryption ciphers and add AES-GCM support (@NickMeves) | ||||
| - [#601](https://github.com/oauth2-proxy/oauth2-proxy/pull/601) Ensure decrypted user/email are valid UTF8 (@JoelSpeed) | ||||
| - [#560](https://github.com/oauth2-proxy/oauth2-proxy/pull/560) Fallback to UserInfo is User ID claim not present (@JoelSpeed) | ||||
| - [#598](https://github.com/oauth2-proxy/oauth2-proxy/pull/598) acr_values no longer sent to IdP when empty (@ScottGuymer) | ||||
|   | ||||
| @@ -104,24 +104,18 @@ func DecodeSessionState(v string, c encryption.Cipher) (*SessionState, error) { | ||||
| 			PreferredUsername: ss.PreferredUsername, | ||||
| 		} | ||||
| 	} else { | ||||
| 		// Backward compatibility with using unencrypted Email | ||||
| 		if ss.Email != "" { | ||||
| 			decryptedEmail, errEmail := stringDecrypt(ss.Email, c) | ||||
| 			if errEmail == nil { | ||||
| 				if !utf8.ValidString(decryptedEmail) { | ||||
| 					return nil, errors.New("invalid value for decrypted email") | ||||
| 				} | ||||
| 				ss.Email = decryptedEmail | ||||
| 		// Backward compatibility with using unencrypted Email or User | ||||
| 		// Decryption errors will leave original string | ||||
| 		err = c.DecryptInto(&ss.Email) | ||||
| 		if err == nil { | ||||
| 			if !utf8.ValidString(ss.Email) { | ||||
| 				return nil, errors.New("invalid value for decrypted email") | ||||
| 			} | ||||
| 		} | ||||
| 		// Backward compatibility with using unencrypted User | ||||
| 		if ss.User != "" { | ||||
| 			decryptedUser, errUser := stringDecrypt(ss.User, c) | ||||
| 			if errUser == nil { | ||||
| 				if !utf8.ValidString(decryptedUser) { | ||||
| 					return nil, errors.New("invalid value for decrypted user") | ||||
| 				} | ||||
| 				ss.User = decryptedUser | ||||
| 		err = c.DecryptInto(&ss.User) | ||||
| 		if err == nil { | ||||
| 			if !utf8.ValidString(ss.User) { | ||||
| 				return nil, errors.New("invalid value for decrypted user") | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| @@ -139,12 +133,3 @@ func DecodeSessionState(v string, c encryption.Cipher) (*SessionState, error) { | ||||
| 	} | ||||
| 	return &ss, nil | ||||
| } | ||||
|  | ||||
| // stringDecrypt wraps a Base64Cipher to make it string => string | ||||
| func stringDecrypt(ciphertext string, c encryption.Cipher) (string, error) { | ||||
| 	value, err := c.Decrypt([]byte(ciphertext)) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	return string(value), nil | ||||
| } | ||||
|   | ||||
| @@ -17,14 +17,14 @@ func timePtr(t time.Time) *time.Time { | ||||
| 	return &t | ||||
| } | ||||
|  | ||||
| func NewCipher(secret []byte) (encryption.Cipher, error) { | ||||
| func newTestCipher(secret []byte) (encryption.Cipher, error) { | ||||
| 	return encryption.NewBase64Cipher(encryption.NewCFBCipher, secret) | ||||
| } | ||||
|  | ||||
| func TestSessionStateSerialization(t *testing.T) { | ||||
| 	c, err := NewCipher([]byte(secret)) | ||||
| 	c, err := newTestCipher([]byte(secret)) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	c2, err := NewCipher([]byte(altSecret)) | ||||
| 	c2, err := newTestCipher([]byte(altSecret)) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	s := &sessions.SessionState{ | ||||
| 		Email:             "user@domain.com", | ||||
| @@ -57,9 +57,9 @@ func TestSessionStateSerialization(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestSessionStateSerializationWithUser(t *testing.T) { | ||||
| 	c, err := NewCipher([]byte(secret)) | ||||
| 	c, err := newTestCipher([]byte(secret)) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	c2, err := NewCipher([]byte(altSecret)) | ||||
| 	c2, err := newTestCipher([]byte(altSecret)) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	s := &sessions.SessionState{ | ||||
| 		User:              "just-user", | ||||
| @@ -205,7 +205,7 @@ func TestDecodeSessionState(t *testing.T) { | ||||
| 	eJSON, _ := e.MarshalJSON() | ||||
| 	eString := string(eJSON) | ||||
|  | ||||
| 	c, err := NewCipher([]byte(secret)) | ||||
| 	c, err := newTestCipher([]byte(secret)) | ||||
| 	assert.NoError(t, err) | ||||
|  | ||||
| 	testCases := []testCase{ | ||||
|   | ||||
| @@ -13,30 +13,11 @@ import ( | ||||
| type Cipher interface { | ||||
| 	Encrypt(value []byte) ([]byte, error) | ||||
| 	Decrypt(ciphertext []byte) ([]byte, error) | ||||
|     EncryptInto(s *string) error | ||||
| 	EncryptInto(s *string) error | ||||
| 	DecryptInto(s *string) error | ||||
| } | ||||
|  | ||||
| type DefaultCipher struct {} | ||||
|  | ||||
| // Encrypt is a dummy method for CommonCipher.EncryptInto support | ||||
| func (c *DefaultCipher) Encrypt(value []byte) ([]byte, error) { return value, nil } | ||||
|  | ||||
| // Decrypt is a dummy method for CommonCipher.DecryptInto support | ||||
| func (c *DefaultCipher) Decrypt(ciphertext []byte) ([]byte, error) { return ciphertext, nil } | ||||
|  | ||||
| // EncryptInto encrypts the value and stores it back in the string pointer | ||||
| func (c *DefaultCipher) EncryptInto(s *string) error { | ||||
| 	return into(c.Encrypt, s) | ||||
| } | ||||
|  | ||||
| // DecryptInto decrypts the value and stores it back in the string pointer | ||||
| func (c *DefaultCipher) DecryptInto(s *string) error { | ||||
| 	return into(c.Decrypt, s) | ||||
| } | ||||
|  | ||||
| type Base64Cipher struct { | ||||
| 	DefaultCipher | ||||
| 	Cipher Cipher | ||||
| } | ||||
|  | ||||
| @@ -52,7 +33,7 @@ func NewBase64Cipher(initCipher func([]byte) (Cipher, error), secret []byte) (Ci | ||||
|  | ||||
| // Encrypt encrypts a value with the embedded Cipher & Base64 encodes it | ||||
| func (c *Base64Cipher) Encrypt(value []byte) ([]byte, error) { | ||||
| 	encrypted, err := c.Cipher.Encrypt([]byte(value)) | ||||
| 	encrypted, err := c.Cipher.Encrypt(value) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @@ -70,8 +51,17 @@ func (c *Base64Cipher) Decrypt(ciphertext []byte) ([]byte, error) { | ||||
| 	return c.Cipher.Decrypt(encrypted) | ||||
| } | ||||
|  | ||||
| // EncryptInto encrypts the value and stores it back in the string pointer | ||||
| func (c *Base64Cipher) EncryptInto(s *string) error { | ||||
| 	return into(c.Encrypt, s) | ||||
| } | ||||
|  | ||||
| // DecryptInto decrypts the value and stores it back in the string pointer | ||||
| func (c *Base64Cipher) DecryptInto(s *string) error { | ||||
| 	return into(c.Decrypt, s) | ||||
| } | ||||
|  | ||||
| type CFBCipher struct { | ||||
| 	DefaultCipher | ||||
| 	cipher.Block | ||||
| } | ||||
|  | ||||
| @@ -111,8 +101,17 @@ func (c *CFBCipher) Decrypt(ciphertext []byte) ([]byte, error) { | ||||
| 	return plaintext, nil | ||||
| } | ||||
|  | ||||
| // EncryptInto returns an error since the encrypted data is a []byte that isn't string cast-able | ||||
| func (c *CFBCipher) EncryptInto(s *string) error { | ||||
| 	return fmt.Errorf("CFBCipher is not a string->string compatible cipher") | ||||
| } | ||||
|  | ||||
| // EncryptInto returns an error since the encrypted data needs to be a []byte | ||||
| func (c *CFBCipher) DecryptInto(s *string) error { | ||||
| 	return fmt.Errorf("CFBCipher is not a string->string compatible cipher") | ||||
| } | ||||
|  | ||||
| type GCMCipher struct { | ||||
| 	DefaultCipher | ||||
| 	cipher.Block | ||||
| } | ||||
|  | ||||
| @@ -158,6 +157,16 @@ func (c *GCMCipher) Decrypt(ciphertext []byte) ([]byte, error) { | ||||
| 	return plaintext, nil | ||||
| } | ||||
|  | ||||
| // EncryptInto returns an error since the encrypted data is a []byte that isn't string cast-able | ||||
| func (c *GCMCipher) EncryptInto(s *string) error { | ||||
| 	return fmt.Errorf("CFBCipher is not a string->string compatible cipher") | ||||
| } | ||||
|  | ||||
| // EncryptInto returns an error since the encrypted data needs to be a []byte | ||||
| func (c *GCMCipher) DecryptInto(s *string) error { | ||||
| 	return fmt.Errorf("CFBCipher is not a string->string compatible cipher") | ||||
| } | ||||
|  | ||||
| // codecFunc is a function that takes a string and encodes/decodes it | ||||
| type codecFunc func([]byte) ([]byte, error) | ||||
|  | ||||
|   | ||||
| @@ -5,6 +5,7 @@ import ( | ||||
| 	"encoding/base64" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	mathrand "math/rand" | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| @@ -117,6 +118,80 @@ func runEncryptAndDecrypt(t *testing.T, c Cipher, dataSize int) { | ||||
| 	assert.NotEqual(t, encrypted, decrypted) | ||||
| } | ||||
|  | ||||
| func TestEncryptIntoAndDecryptInto(t *testing.T) { | ||||
| 	// Test our 2 cipher types | ||||
| 	cipherInits := map[string]func([]byte) (Cipher, error){ | ||||
| 		"CFB": NewCFBCipher, | ||||
| 		"GCM": NewGCMCipher, | ||||
| 	} | ||||
| 	for name, initCipher := range cipherInits { | ||||
| 		t.Run(name, func(t *testing.T) { | ||||
| 			// Test all 3 valid AES sizes | ||||
| 			for _, secretSize := range []int{16, 24, 32} { | ||||
| 				t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) { | ||||
| 					secret := make([]byte, secretSize) | ||||
| 					_, err := io.ReadFull(rand.Reader, secret) | ||||
| 					assert.Equal(t, nil, err) | ||||
|  | ||||
| 					// Test Standard & Base64 wrapped | ||||
| 					cstd, err := initCipher(secret) | ||||
| 					assert.Equal(t, nil, err) | ||||
|  | ||||
| 					cb64, err := NewBase64Cipher(initCipher, secret) | ||||
| 					assert.Equal(t, nil, err) | ||||
|  | ||||
| 					ciphers := map[string]Cipher{ | ||||
| 						"Standard": cstd, | ||||
| 						"Base64":   cb64, | ||||
| 					} | ||||
|  | ||||
| 					for cName, c := range ciphers { | ||||
| 						// Check no errors with empty or nil strings | ||||
| 						if cName == "Base64" { | ||||
| 							empty := "" | ||||
| 							assert.Equal(t, nil, c.EncryptInto(&empty)) | ||||
| 							assert.Equal(t, nil, c.DecryptInto(&empty)) | ||||
| 							assert.Equal(t, nil, c.EncryptInto(nil)) | ||||
| 							assert.Equal(t, nil, c.DecryptInto(nil)) | ||||
| 						} | ||||
|  | ||||
| 						t.Run(cName, func(t *testing.T) { | ||||
| 							// Test various sizes sessions might be | ||||
| 							for _, dataSize := range []int{10, 100, 1000, 5000, 10000} { | ||||
| 								t.Run(fmt.Sprintf("%d", dataSize), func(t *testing.T) { | ||||
| 									runEncryptIntoAndDecryptInto(t, c, cName, dataSize) | ||||
| 								}) | ||||
| 							} | ||||
| 						}) | ||||
| 					} | ||||
| 				}) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func runEncryptIntoAndDecryptInto(t *testing.T, c Cipher, cipherType string, dataSize int) { | ||||
| 	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" | ||||
| 	b := make([]byte, dataSize) | ||||
| 	for i := range b { | ||||
| 		b[i] = charset[mathrand.Intn(len(charset))] | ||||
| 	} | ||||
| 	data := string(b) | ||||
| 	originalData := data | ||||
|  | ||||
| 	// Base64 is the only cipher that supports string->string Encrypt/Decrypt Into methods | ||||
| 	if cipherType == "Base64" { | ||||
| 		assert.Equal(t, nil, c.EncryptInto(&data)) | ||||
| 		assert.NotEqual(t, originalData, data) | ||||
|  | ||||
| 		assert.Equal(t, nil, c.DecryptInto(&data)) | ||||
| 		assert.Equal(t, originalData, data) | ||||
| 	} else { | ||||
| 		assert.NotEqual(t, nil, c.EncryptInto(&data)) | ||||
| 		assert.NotEqual(t, nil, c.DecryptInto(&data)) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestDecryptCFBWrongSecret(t *testing.T) { | ||||
| 	secret1 := []byte("0123456789abcdefghijklmnopqrstuv") | ||||
| 	secret2 := []byte("9876543210abcdefghijklmnopqrstuv") | ||||
| @@ -228,25 +303,3 @@ func TestCFBtoGCMErrors(t *testing.T) { | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestEncodeIntoAndDecodeIntoAccessToken(t *testing.T) { | ||||
| 	const secret = "0123456789abcdefghijklmnopqrstuv" | ||||
| 	c, err := NewCipher([]byte(secret)) | ||||
| 	assert.Equal(t, nil, err) | ||||
|  | ||||
| 	token := "my access token" | ||||
| 	originalToken := token | ||||
|  | ||||
| 	assert.Equal(t, nil, c.EncryptInto(&token)) | ||||
| 	assert.NotEqual(t, originalToken, token) | ||||
|  | ||||
| 	assert.Equal(t, nil, c.DecryptInto(&token)) | ||||
| 	assert.Equal(t, originalToken, token) | ||||
|  | ||||
| 	// Check no errors with empty or nil strings | ||||
| 	empty := "" | ||||
| 	assert.Equal(t, nil, c.EncryptInto(&empty)) | ||||
| 	assert.Equal(t, nil, c.DecryptInto(&empty)) | ||||
| 	assert.Equal(t, nil, c.EncryptInto(nil)) | ||||
| 	assert.Equal(t, nil, c.DecryptInto(nil)) | ||||
| } | ||||
|   | ||||
| @@ -104,4 +104,3 @@ func checkHmac(input, expected string) bool { | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user