You've already forked oauth2-proxy
							
							
				mirror of
				https://github.com/oauth2-proxy/oauth2-proxy.git
				synced 2025-10-30 23:47:52 +02:00 
			
		
		
		
	Move Encrypt/Decrypt Into helper to session_state.go
This helper method is only applicable for Base64 wrapped encryption since it operated on string -> string primarily. It wouldn't be used for pure CFB/GCM ciphers. After a messagePack session refactor, this method would further only be used for legacy session compatibility - making its placement in cipher.go not ideal.
This commit is contained in:
		| @@ -77,7 +77,7 @@ func (s *SessionState) EncodeSessionState(c encryption.Cipher) (string, error) { | ||||
| 			&ss.IDToken, | ||||
| 			&ss.RefreshToken, | ||||
| 		} { | ||||
| 			err := c.EncryptInto(s) | ||||
| 			err := into(s, c.Encrypt) | ||||
| 			if err != nil { | ||||
| 				return "", err | ||||
| 			} | ||||
| @@ -106,13 +106,13 @@ func DecodeSessionState(v string, c encryption.Cipher) (*SessionState, error) { | ||||
| 	} else { | ||||
| 		// Backward compatibility with using unencrypted Email or User | ||||
| 		// Decryption errors will leave original string | ||||
| 		err = c.DecryptInto(&ss.Email) | ||||
| 		err = into(&ss.Email, c.Decrypt) | ||||
| 		if err == nil { | ||||
| 			if !utf8.ValidString(ss.Email) { | ||||
| 				return nil, errors.New("invalid value for decrypted email") | ||||
| 			} | ||||
| 		} | ||||
| 		err = c.DecryptInto(&ss.User) | ||||
| 		err = into(&ss.User, c.Decrypt) | ||||
| 		if err == nil { | ||||
| 			if !utf8.ValidString(ss.User) { | ||||
| 				return nil, errors.New("invalid value for decrypted user") | ||||
| @@ -125,7 +125,7 @@ func DecodeSessionState(v string, c encryption.Cipher) (*SessionState, error) { | ||||
| 			&ss.IDToken, | ||||
| 			&ss.RefreshToken, | ||||
| 		} { | ||||
| 			err := c.DecryptInto(s) | ||||
| 			err := into(s, c.Decrypt) | ||||
| 			if err != nil { | ||||
| 				return nil, err | ||||
| 			} | ||||
| @@ -133,3 +133,20 @@ func DecodeSessionState(v string, c encryption.Cipher) (*SessionState, error) { | ||||
| 	} | ||||
| 	return &ss, nil | ||||
| } | ||||
|  | ||||
| // codecFunc is a function that takes a []byte and encodes/decodes it | ||||
| type codecFunc func([]byte) ([]byte, error) | ||||
|  | ||||
| func into(s *string, f codecFunc) error { | ||||
| 	// Do not encrypt/decrypt nil or empty strings | ||||
| 	if s == nil || *s == "" { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	d, err := f([]byte(*s)) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	*s = string(d) | ||||
| 	return nil | ||||
| } | ||||
|   | ||||
| @@ -1,11 +1,13 @@ | ||||
| package sessions_test | ||||
| package sessions | ||||
|  | ||||
| import ( | ||||
| 	"crypto/rand" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	mathrand "math/rand" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/encryption" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
| @@ -26,7 +28,7 @@ func TestSessionStateSerialization(t *testing.T) { | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	c2, err := newTestCipher([]byte(altSecret)) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	s := &sessions.SessionState{ | ||||
| 	s := &SessionState{ | ||||
| 		Email:             "user@domain.com", | ||||
| 		PreferredUsername: "user", | ||||
| 		AccessToken:       "token1234", | ||||
| @@ -38,7 +40,7 @@ func TestSessionStateSerialization(t *testing.T) { | ||||
| 	encoded, err := s.EncodeSessionState(c) | ||||
| 	assert.Equal(t, nil, err) | ||||
|  | ||||
| 	ss, err := sessions.DecodeSessionState(encoded, c) | ||||
| 	ss, err := DecodeSessionState(encoded, c) | ||||
| 	t.Logf("%#v", ss) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, "", ss.User) | ||||
| @@ -51,7 +53,7 @@ func TestSessionStateSerialization(t *testing.T) { | ||||
| 	assert.Equal(t, s.RefreshToken, ss.RefreshToken) | ||||
|  | ||||
| 	// ensure a different cipher can't decode properly (ie: it gets gibberish) | ||||
| 	ss, err = sessions.DecodeSessionState(encoded, c2) | ||||
| 	ss, err = DecodeSessionState(encoded, c2) | ||||
| 	t.Logf("%#v", ss) | ||||
| 	assert.NotEqual(t, nil, err) | ||||
| } | ||||
| @@ -61,7 +63,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	c2, err := newTestCipher([]byte(altSecret)) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	s := &sessions.SessionState{ | ||||
| 	s := &SessionState{ | ||||
| 		User:              "just-user", | ||||
| 		PreferredUsername: "ju", | ||||
| 		Email:             "user@domain.com", | ||||
| @@ -73,7 +75,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { | ||||
| 	encoded, err := s.EncodeSessionState(c) | ||||
| 	assert.Equal(t, nil, err) | ||||
|  | ||||
| 	ss, err := sessions.DecodeSessionState(encoded, c) | ||||
| 	ss, err := DecodeSessionState(encoded, c) | ||||
| 	t.Logf("%#v", ss) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, s.User, ss.User) | ||||
| @@ -85,13 +87,13 @@ func TestSessionStateSerializationWithUser(t *testing.T) { | ||||
| 	assert.Equal(t, s.RefreshToken, ss.RefreshToken) | ||||
|  | ||||
| 	// ensure a different cipher can't decode properly (ie: it gets gibberish) | ||||
| 	ss, err = sessions.DecodeSessionState(encoded, c2) | ||||
| 	ss, err = DecodeSessionState(encoded, c2) | ||||
| 	t.Logf("%#v", ss) | ||||
| 	assert.NotEqual(t, nil, err) | ||||
| } | ||||
|  | ||||
| func TestSessionStateSerializationNoCipher(t *testing.T) { | ||||
| 	s := &sessions.SessionState{ | ||||
| 	s := &SessionState{ | ||||
| 		Email:             "user@domain.com", | ||||
| 		PreferredUsername: "user", | ||||
| 		AccessToken:       "token1234", | ||||
| @@ -103,7 +105,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) { | ||||
| 	assert.Equal(t, nil, err) | ||||
|  | ||||
| 	// only email should have been serialized | ||||
| 	ss, err := sessions.DecodeSessionState(encoded, nil) | ||||
| 	ss, err := DecodeSessionState(encoded, nil) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, "", ss.User) | ||||
| 	assert.Equal(t, s.Email, ss.Email) | ||||
| @@ -113,7 +115,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { | ||||
| 	s := &sessions.SessionState{ | ||||
| 	s := &SessionState{ | ||||
| 		User:              "just-user", | ||||
| 		Email:             "user@domain.com", | ||||
| 		PreferredUsername: "user", | ||||
| @@ -126,7 +128,7 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { | ||||
| 	assert.Equal(t, nil, err) | ||||
|  | ||||
| 	// only email should have been serialized | ||||
| 	ss, err := sessions.DecodeSessionState(encoded, nil) | ||||
| 	ss, err := DecodeSessionState(encoded, nil) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, s.User, ss.User) | ||||
| 	assert.Equal(t, s.Email, ss.Email) | ||||
| @@ -136,18 +138,18 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestExpired(t *testing.T) { | ||||
| 	s := &sessions.SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(-1) * time.Minute))} | ||||
| 	s := &SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(-1) * time.Minute))} | ||||
| 	assert.Equal(t, true, s.IsExpired()) | ||||
|  | ||||
| 	s = &sessions.SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Minute))} | ||||
| 	s = &SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Minute))} | ||||
| 	assert.Equal(t, false, s.IsExpired()) | ||||
|  | ||||
| 	s = &sessions.SessionState{} | ||||
| 	s = &SessionState{} | ||||
| 	assert.Equal(t, false, s.IsExpired()) | ||||
| } | ||||
|  | ||||
| type testCase struct { | ||||
| 	sessions.SessionState | ||||
| 	SessionState | ||||
| 	Encoded string | ||||
| 	Cipher  encryption.Cipher | ||||
| 	Error   bool | ||||
| @@ -163,14 +165,14 @@ func TestEncodeSessionState(t *testing.T) { | ||||
|  | ||||
| 	testCases := []testCase{ | ||||
| 		{ | ||||
| 			SessionState: sessions.SessionState{ | ||||
| 			SessionState: SessionState{ | ||||
| 				Email: "user@domain.com", | ||||
| 				User:  "just-user", | ||||
| 			}, | ||||
| 			Encoded: `{"Email":"user@domain.com","User":"just-user"}`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			SessionState: sessions.SessionState{ | ||||
| 			SessionState: SessionState{ | ||||
| 				Email:        "user@domain.com", | ||||
| 				User:         "just-user", | ||||
| 				AccessToken:  "token1234", | ||||
| @@ -185,7 +187,7 @@ func TestEncodeSessionState(t *testing.T) { | ||||
|  | ||||
| 	for i, tc := range testCases { | ||||
| 		encoded, err := tc.EncodeSessionState(tc.Cipher) | ||||
| 		t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err) | ||||
| 		t.Logf("i:%d Encoded:%#vSessionState:%#v Error:%#v", i, encoded, tc.SessionState, err) | ||||
| 		if tc.Error { | ||||
| 			assert.Error(t, err) | ||||
| 			assert.Empty(t, encoded) | ||||
| @@ -210,34 +212,34 @@ func TestDecodeSessionState(t *testing.T) { | ||||
|  | ||||
| 	testCases := []testCase{ | ||||
| 		{ | ||||
| 			SessionState: sessions.SessionState{ | ||||
| 			SessionState: SessionState{ | ||||
| 				Email: "user@domain.com", | ||||
| 				User:  "just-user", | ||||
| 			}, | ||||
| 			Encoded: `{"Email":"user@domain.com","User":"just-user"}`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			SessionState: sessions.SessionState{ | ||||
| 			SessionState: SessionState{ | ||||
| 				Email: "user@domain.com", | ||||
| 				User:  "", | ||||
| 			}, | ||||
| 			Encoded: `{"Email":"user@domain.com"}`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			SessionState: sessions.SessionState{ | ||||
| 			SessionState: SessionState{ | ||||
| 				User: "just-user", | ||||
| 			}, | ||||
| 			Encoded: `{"User":"just-user"}`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			SessionState: sessions.SessionState{ | ||||
| 			SessionState: SessionState{ | ||||
| 				Email: "user@domain.com", | ||||
| 				User:  "just-user", | ||||
| 			}, | ||||
| 			Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString), | ||||
| 		}, | ||||
| 		{ | ||||
| 			SessionState: sessions.SessionState{ | ||||
| 			SessionState: SessionState{ | ||||
| 				Email:        "user@domain.com", | ||||
| 				User:         "just-user", | ||||
| 				AccessToken:  "token1234", | ||||
| @@ -250,7 +252,7 @@ func TestDecodeSessionState(t *testing.T) { | ||||
| 			Cipher:  c, | ||||
| 		}, | ||||
| 		{ | ||||
| 			SessionState: sessions.SessionState{ | ||||
| 			SessionState: SessionState{ | ||||
| 				Email: "user@domain.com", | ||||
| 				User:  "just-user", | ||||
| 			}, | ||||
| @@ -268,7 +270,7 @@ func TestDecodeSessionState(t *testing.T) { | ||||
| 			Error:   true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			SessionState: sessions.SessionState{ | ||||
| 			SessionState: SessionState{ | ||||
| 				Email: "user@domain.com", | ||||
| 				User:  "YmFzZTY0LWVuY29kZWQtdXNlcgo=", // Base64 encoding of base64-encoded-user | ||||
| 			}, | ||||
| @@ -278,8 +280,8 @@ func TestDecodeSessionState(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	for i, tc := range testCases { | ||||
| 		ss, err := sessions.DecodeSessionState(tc.Encoded, tc.Cipher) | ||||
| 		t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, tc.Encoded, ss, err) | ||||
| 		ss, err := DecodeSessionState(tc.Encoded, tc.Cipher) | ||||
| 		t.Logf("i:%d Encoded:%#vSessionState:%#v Error:%#v", i, tc.Encoded, ss, err) | ||||
| 		if tc.Error { | ||||
| 			assert.Error(t, err) | ||||
| 			assert.Nil(t, ss) | ||||
| @@ -301,7 +303,7 @@ func TestDecodeSessionState(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestSessionStateAge(t *testing.T) { | ||||
| 	ss := &sessions.SessionState{} | ||||
| 	ss := &SessionState{} | ||||
|  | ||||
| 	// Created at unset so should be 0 | ||||
| 	assert.Equal(t, time.Duration(0), ss.Age()) | ||||
| @@ -310,3 +312,44 @@ func TestSessionStateAge(t *testing.T) { | ||||
| 	ss.CreatedAt = timePtr(time.Now().Add(-1 * time.Hour)) | ||||
| 	assert.Equal(t, time.Hour, ss.Age().Round(time.Minute)) | ||||
| } | ||||
|  | ||||
| func TestIntoEncryptAndIntoDecrypt(t *testing.T) { | ||||
| 	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" | ||||
|  | ||||
| 	// Test all 3 valid AES sizes | ||||
| 	for _, secretSize := range []int{16, 24, 32} { | ||||
| 		t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) { | ||||
| 			secret := make([]byte, secretSize) | ||||
| 			_, err := io.ReadFull(rand.Reader, secret) | ||||
| 			assert.Equal(t, nil, err) | ||||
|  | ||||
| 			c, err := newTestCipher(secret) | ||||
| 			assert.NoError(t, err) | ||||
|  | ||||
| 			// Check no errors with empty or nil strings | ||||
| 			empty := "" | ||||
| 			assert.Equal(t, nil, into(&empty, c.Encrypt)) | ||||
| 			assert.Equal(t, nil, into(&empty, c.Decrypt)) | ||||
| 			assert.Equal(t, nil, into(nil, c.Encrypt)) | ||||
| 			assert.Equal(t, nil, into(nil, c.Decrypt)) | ||||
|  | ||||
| 			// Test various sizes tokens might be | ||||
| 			for _, dataSize := range []int{10, 100, 1000, 5000, 10000} { | ||||
| 				t.Run(fmt.Sprintf("%d", dataSize), func(t *testing.T) { | ||||
| 					b := make([]byte, dataSize) | ||||
| 					for i := range b { | ||||
| 						b[i] = charset[mathrand.Intn(len(charset))] | ||||
| 					} | ||||
| 					data := string(b) | ||||
| 					originalData := data | ||||
|  | ||||
| 					assert.Equal(t, nil, into(&data, c.Encrypt)) | ||||
| 					assert.NotEqual(t, originalData, data) | ||||
|  | ||||
| 					assert.Equal(t, nil, into(&data, c.Decrypt)) | ||||
| 					assert.Equal(t, originalData, data) | ||||
| 				}) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -13,11 +13,9 @@ import ( | ||||
| type Cipher interface { | ||||
| 	Encrypt(value []byte) ([]byte, error) | ||||
| 	Decrypt(ciphertext []byte) ([]byte, error) | ||||
| 	EncryptInto(s *string) error | ||||
| 	DecryptInto(s *string) error | ||||
| } | ||||
|  | ||||
| type Base64Cipher struct { | ||||
| type base64Cipher struct { | ||||
| 	Cipher Cipher | ||||
| } | ||||
|  | ||||
| @@ -28,11 +26,11 @@ func NewBase64Cipher(initCipher func([]byte) (Cipher, error), secret []byte) (Ci | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &Base64Cipher{Cipher: c}, nil | ||||
| 	return &base64Cipher{Cipher: c}, nil | ||||
| } | ||||
|  | ||||
| // Encrypt encrypts a value with the embedded Cipher & Base64 encodes it | ||||
| func (c *Base64Cipher) Encrypt(value []byte) ([]byte, error) { | ||||
| func (c *base64Cipher) Encrypt(value []byte) ([]byte, error) { | ||||
| 	encrypted, err := c.Cipher.Encrypt(value) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| @@ -42,7 +40,7 @@ func (c *Base64Cipher) Encrypt(value []byte) ([]byte, error) { | ||||
| } | ||||
|  | ||||
| // Decrypt Base64 decodes a value & decrypts it with the embedded Cipher | ||||
| func (c *Base64Cipher) Decrypt(ciphertext []byte) ([]byte, error) { | ||||
| func (c *base64Cipher) Decrypt(ciphertext []byte) ([]byte, error) { | ||||
| 	encrypted, err := base64.StdEncoding.DecodeString(string(ciphertext)) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to base64 decode value %s", err) | ||||
| @@ -51,17 +49,7 @@ func (c *Base64Cipher) Decrypt(ciphertext []byte) ([]byte, error) { | ||||
| 	return c.Cipher.Decrypt(encrypted) | ||||
| } | ||||
|  | ||||
| // EncryptInto encrypts the value and stores it back in the string pointer | ||||
| func (c *Base64Cipher) EncryptInto(s *string) error { | ||||
| 	return into(c.Encrypt, s) | ||||
| } | ||||
|  | ||||
| // DecryptInto decrypts the value and stores it back in the string pointer | ||||
| func (c *Base64Cipher) DecryptInto(s *string) error { | ||||
| 	return into(c.Decrypt, s) | ||||
| } | ||||
|  | ||||
| type CFBCipher struct { | ||||
| type cfbCipher struct { | ||||
| 	cipher.Block | ||||
| } | ||||
|  | ||||
| @@ -71,11 +59,11 @@ func NewCFBCipher(secret []byte) (Cipher, error) { | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &CFBCipher{Block: c}, err | ||||
| 	return &cfbCipher{Block: c}, err | ||||
| } | ||||
|  | ||||
| // Encrypt with AES CFB | ||||
| func (c *CFBCipher) Encrypt(value []byte) ([]byte, error) { | ||||
| func (c *cfbCipher) Encrypt(value []byte) ([]byte, error) { | ||||
| 	ciphertext := make([]byte, aes.BlockSize+len(value)) | ||||
| 	iv := ciphertext[:aes.BlockSize] | ||||
| 	if _, err := io.ReadFull(rand.Reader, iv); err != nil { | ||||
| @@ -88,7 +76,7 @@ func (c *CFBCipher) Encrypt(value []byte) ([]byte, error) { | ||||
| } | ||||
|  | ||||
| // Decrypt an AES CFB ciphertext | ||||
| func (c *CFBCipher) Decrypt(ciphertext []byte) ([]byte, error) { | ||||
| func (c *cfbCipher) Decrypt(ciphertext []byte) ([]byte, error) { | ||||
| 	if len(ciphertext) < aes.BlockSize { | ||||
| 		return nil, fmt.Errorf("encrypted value should be at least %d bytes, but is only %d bytes", aes.BlockSize, len(ciphertext)) | ||||
| 	} | ||||
| @@ -101,17 +89,7 @@ func (c *CFBCipher) Decrypt(ciphertext []byte) ([]byte, error) { | ||||
| 	return plaintext, nil | ||||
| } | ||||
|  | ||||
| // EncryptInto returns an error since the encrypted data is a []byte that isn't string cast-able | ||||
| func (c *CFBCipher) EncryptInto(s *string) error { | ||||
| 	return fmt.Errorf("CFBCipher is not a string->string compatible cipher") | ||||
| } | ||||
|  | ||||
| // EncryptInto returns an error since the encrypted data needs to be a []byte | ||||
| func (c *CFBCipher) DecryptInto(s *string) error { | ||||
| 	return fmt.Errorf("CFBCipher is not a string->string compatible cipher") | ||||
| } | ||||
|  | ||||
| type GCMCipher struct { | ||||
| type gcmCipher struct { | ||||
| 	cipher.Block | ||||
| } | ||||
|  | ||||
| @@ -121,11 +99,11 @@ func NewGCMCipher(secret []byte) (Cipher, error) { | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &GCMCipher{Block: c}, err | ||||
| 	return &gcmCipher{Block: c}, err | ||||
| } | ||||
|  | ||||
| // Encrypt with AES GCM on raw bytes | ||||
| func (c *GCMCipher) Encrypt(value []byte) ([]byte, error) { | ||||
| func (c *gcmCipher) Encrypt(value []byte) ([]byte, error) { | ||||
| 	gcm, err := cipher.NewGCM(c.Block) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| @@ -141,7 +119,7 @@ func (c *GCMCipher) Encrypt(value []byte) ([]byte, error) { | ||||
| } | ||||
|  | ||||
| // Decrypt an AES GCM ciphertext | ||||
| func (c *GCMCipher) Decrypt(ciphertext []byte) ([]byte, error) { | ||||
| func (c *gcmCipher) Decrypt(ciphertext []byte) ([]byte, error) { | ||||
| 	gcm, err := cipher.NewGCM(c.Block) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| @@ -156,30 +134,3 @@ func (c *GCMCipher) Decrypt(ciphertext []byte) ([]byte, error) { | ||||
| 	} | ||||
| 	return plaintext, nil | ||||
| } | ||||
|  | ||||
| // EncryptInto returns an error since the encrypted data is a []byte that isn't string cast-able | ||||
| func (c *GCMCipher) EncryptInto(s *string) error { | ||||
| 	return fmt.Errorf("CFBCipher is not a string->string compatible cipher") | ||||
| } | ||||
|  | ||||
| // EncryptInto returns an error since the encrypted data needs to be a []byte | ||||
| func (c *GCMCipher) DecryptInto(s *string) error { | ||||
| 	return fmt.Errorf("CFBCipher is not a string->string compatible cipher") | ||||
| } | ||||
|  | ||||
| // codecFunc is a function that takes a string and encodes/decodes it | ||||
| type codecFunc func([]byte) ([]byte, error) | ||||
|  | ||||
| func into(f codecFunc, s *string) error { | ||||
| 	// Do not encrypt/decrypt nil or empty strings | ||||
| 	if s == nil || *s == "" { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	d, err := f([]byte(*s)) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	*s = string(d) | ||||
| 	return nil | ||||
| } | ||||
|   | ||||
| @@ -5,7 +5,6 @@ import ( | ||||
| 	"encoding/base64" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	mathrand "math/rand" | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| @@ -118,80 +117,6 @@ func runEncryptAndDecrypt(t *testing.T, c Cipher, dataSize int) { | ||||
| 	assert.NotEqual(t, encrypted, decrypted) | ||||
| } | ||||
|  | ||||
| func TestEncryptIntoAndDecryptInto(t *testing.T) { | ||||
| 	// Test our 2 cipher types | ||||
| 	cipherInits := map[string]func([]byte) (Cipher, error){ | ||||
| 		"CFB": NewCFBCipher, | ||||
| 		"GCM": NewGCMCipher, | ||||
| 	} | ||||
| 	for name, initCipher := range cipherInits { | ||||
| 		t.Run(name, func(t *testing.T) { | ||||
| 			// Test all 3 valid AES sizes | ||||
| 			for _, secretSize := range []int{16, 24, 32} { | ||||
| 				t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) { | ||||
| 					secret := make([]byte, secretSize) | ||||
| 					_, err := io.ReadFull(rand.Reader, secret) | ||||
| 					assert.Equal(t, nil, err) | ||||
|  | ||||
| 					// Test Standard & Base64 wrapped | ||||
| 					cstd, err := initCipher(secret) | ||||
| 					assert.Equal(t, nil, err) | ||||
|  | ||||
| 					cb64, err := NewBase64Cipher(initCipher, secret) | ||||
| 					assert.Equal(t, nil, err) | ||||
|  | ||||
| 					ciphers := map[string]Cipher{ | ||||
| 						"Standard": cstd, | ||||
| 						"Base64":   cb64, | ||||
| 					} | ||||
|  | ||||
| 					for cName, c := range ciphers { | ||||
| 						// Check no errors with empty or nil strings | ||||
| 						if cName == "Base64" { | ||||
| 							empty := "" | ||||
| 							assert.Equal(t, nil, c.EncryptInto(&empty)) | ||||
| 							assert.Equal(t, nil, c.DecryptInto(&empty)) | ||||
| 							assert.Equal(t, nil, c.EncryptInto(nil)) | ||||
| 							assert.Equal(t, nil, c.DecryptInto(nil)) | ||||
| 						} | ||||
|  | ||||
| 						t.Run(cName, func(t *testing.T) { | ||||
| 							// Test various sizes sessions might be | ||||
| 							for _, dataSize := range []int{10, 100, 1000, 5000, 10000} { | ||||
| 								t.Run(fmt.Sprintf("%d", dataSize), func(t *testing.T) { | ||||
| 									runEncryptIntoAndDecryptInto(t, c, cName, dataSize) | ||||
| 								}) | ||||
| 							} | ||||
| 						}) | ||||
| 					} | ||||
| 				}) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func runEncryptIntoAndDecryptInto(t *testing.T, c Cipher, cipherType string, dataSize int) { | ||||
| 	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" | ||||
| 	b := make([]byte, dataSize) | ||||
| 	for i := range b { | ||||
| 		b[i] = charset[mathrand.Intn(len(charset))] | ||||
| 	} | ||||
| 	data := string(b) | ||||
| 	originalData := data | ||||
|  | ||||
| 	// Base64 is the only cipher that supports string->string Encrypt/Decrypt Into methods | ||||
| 	if cipherType == "Base64" { | ||||
| 		assert.Equal(t, nil, c.EncryptInto(&data)) | ||||
| 		assert.NotEqual(t, originalData, data) | ||||
|  | ||||
| 		assert.Equal(t, nil, c.DecryptInto(&data)) | ||||
| 		assert.Equal(t, originalData, data) | ||||
| 	} else { | ||||
| 		assert.NotEqual(t, nil, c.EncryptInto(&data)) | ||||
| 		assert.NotEqual(t, nil, c.DecryptInto(&data)) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestDecryptCFBWrongSecret(t *testing.T) { | ||||
| 	secret1 := []byte("0123456789abcdefghijklmnopqrstuv") | ||||
| 	secret2 := []byte("9876543210abcdefghijklmnopqrstuv") | ||||
|   | ||||
		Reference in New Issue
	
	Block a user