From 19796275347540d561a1585d5fadab89f8220fa8 Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Thu, 4 Jun 2020 14:39:31 -0700 Subject: [PATCH] Move Encrypt/Decrypt Into helper to session_state.go This helper method is only applicable for Base64 wrapped encryption since it operated on string -> string primarily. It wouldn't be used for pure CFB/GCM ciphers. After a messagePack session refactor, this method would further only be used for legacy session compatibility - making its placement in cipher.go not ideal. --- pkg/apis/sessions/session_state.go | 25 +++++- pkg/apis/sessions/session_state_test.go | 101 +++++++++++++++++------- pkg/encryption/cipher.go | 73 +++-------------- pkg/encryption/cipher_test.go | 75 ------------------ 4 files changed, 105 insertions(+), 169 deletions(-) diff --git a/pkg/apis/sessions/session_state.go b/pkg/apis/sessions/session_state.go index 24377c4a..44b91bd2 100644 --- a/pkg/apis/sessions/session_state.go +++ b/pkg/apis/sessions/session_state.go @@ -77,7 +77,7 @@ func (s *SessionState) EncodeSessionState(c encryption.Cipher) (string, error) { &ss.IDToken, &ss.RefreshToken, } { - err := c.EncryptInto(s) + err := into(s, c.Encrypt) if err != nil { return "", err } @@ -106,13 +106,13 @@ func DecodeSessionState(v string, c encryption.Cipher) (*SessionState, error) { } else { // Backward compatibility with using unencrypted Email or User // Decryption errors will leave original string - err = c.DecryptInto(&ss.Email) + err = into(&ss.Email, c.Decrypt) if err == nil { if !utf8.ValidString(ss.Email) { return nil, errors.New("invalid value for decrypted email") } } - err = c.DecryptInto(&ss.User) + err = into(&ss.User, c.Decrypt) if err == nil { if !utf8.ValidString(ss.User) { return nil, errors.New("invalid value for decrypted user") @@ -125,7 +125,7 @@ func DecodeSessionState(v string, c encryption.Cipher) (*SessionState, error) { &ss.IDToken, &ss.RefreshToken, } { - err := c.DecryptInto(s) + err := into(s, c.Decrypt) if err != nil { return nil, err } @@ -133,3 +133,20 @@ func DecodeSessionState(v string, c encryption.Cipher) (*SessionState, error) { } return &ss, nil } + +// codecFunc is a function that takes a []byte and encodes/decodes it +type codecFunc func([]byte) ([]byte, error) + +func into(s *string, f codecFunc) error { + // Do not encrypt/decrypt nil or empty strings + if s == nil || *s == "" { + return nil + } + + d, err := f([]byte(*s)) + if err != nil { + return err + } + *s = string(d) + return nil +} diff --git a/pkg/apis/sessions/session_state_test.go b/pkg/apis/sessions/session_state_test.go index d48ec502..3e9554c5 100644 --- a/pkg/apis/sessions/session_state_test.go +++ b/pkg/apis/sessions/session_state_test.go @@ -1,11 +1,13 @@ -package sessions_test +package sessions import ( + "crypto/rand" "fmt" + "io" + mathrand "math/rand" "testing" "time" - "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/pkg/encryption" "github.com/stretchr/testify/assert" ) @@ -26,7 +28,7 @@ func TestSessionStateSerialization(t *testing.T) { assert.Equal(t, nil, err) c2, err := newTestCipher([]byte(altSecret)) assert.Equal(t, nil, err) - s := &sessions.SessionState{ + s := &SessionState{ Email: "user@domain.com", PreferredUsername: "user", AccessToken: "token1234", @@ -38,7 +40,7 @@ func TestSessionStateSerialization(t *testing.T) { encoded, err := s.EncodeSessionState(c) assert.Equal(t, nil, err) - ss, err := sessions.DecodeSessionState(encoded, c) + ss, err := DecodeSessionState(encoded, c) t.Logf("%#v", ss) assert.Equal(t, nil, err) assert.Equal(t, "", ss.User) @@ -51,7 +53,7 @@ func TestSessionStateSerialization(t *testing.T) { assert.Equal(t, s.RefreshToken, ss.RefreshToken) // ensure a different cipher can't decode properly (ie: it gets gibberish) - ss, err = sessions.DecodeSessionState(encoded, c2) + ss, err = DecodeSessionState(encoded, c2) t.Logf("%#v", ss) assert.NotEqual(t, nil, err) } @@ -61,7 +63,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { assert.Equal(t, nil, err) c2, err := newTestCipher([]byte(altSecret)) assert.Equal(t, nil, err) - s := &sessions.SessionState{ + s := &SessionState{ User: "just-user", PreferredUsername: "ju", Email: "user@domain.com", @@ -73,7 +75,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { encoded, err := s.EncodeSessionState(c) assert.Equal(t, nil, err) - ss, err := sessions.DecodeSessionState(encoded, c) + ss, err := DecodeSessionState(encoded, c) t.Logf("%#v", ss) assert.Equal(t, nil, err) assert.Equal(t, s.User, ss.User) @@ -85,13 +87,13 @@ func TestSessionStateSerializationWithUser(t *testing.T) { assert.Equal(t, s.RefreshToken, ss.RefreshToken) // ensure a different cipher can't decode properly (ie: it gets gibberish) - ss, err = sessions.DecodeSessionState(encoded, c2) + ss, err = DecodeSessionState(encoded, c2) t.Logf("%#v", ss) assert.NotEqual(t, nil, err) } func TestSessionStateSerializationNoCipher(t *testing.T) { - s := &sessions.SessionState{ + s := &SessionState{ Email: "user@domain.com", PreferredUsername: "user", AccessToken: "token1234", @@ -103,7 +105,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) { assert.Equal(t, nil, err) // only email should have been serialized - ss, err := sessions.DecodeSessionState(encoded, nil) + ss, err := DecodeSessionState(encoded, nil) assert.Equal(t, nil, err) assert.Equal(t, "", ss.User) assert.Equal(t, s.Email, ss.Email) @@ -113,7 +115,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) { } func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { - s := &sessions.SessionState{ + s := &SessionState{ User: "just-user", Email: "user@domain.com", PreferredUsername: "user", @@ -126,7 +128,7 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { assert.Equal(t, nil, err) // only email should have been serialized - ss, err := sessions.DecodeSessionState(encoded, nil) + ss, err := DecodeSessionState(encoded, nil) assert.Equal(t, nil, err) assert.Equal(t, s.User, ss.User) assert.Equal(t, s.Email, ss.Email) @@ -136,18 +138,18 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { } func TestExpired(t *testing.T) { - s := &sessions.SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(-1) * time.Minute))} + s := &SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(-1) * time.Minute))} assert.Equal(t, true, s.IsExpired()) - s = &sessions.SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Minute))} + s = &SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Minute))} assert.Equal(t, false, s.IsExpired()) - s = &sessions.SessionState{} + s = &SessionState{} assert.Equal(t, false, s.IsExpired()) } type testCase struct { - sessions.SessionState + SessionState Encoded string Cipher encryption.Cipher Error bool @@ -163,14 +165,14 @@ func TestEncodeSessionState(t *testing.T) { testCases := []testCase{ { - SessionState: sessions.SessionState{ + SessionState: SessionState{ Email: "user@domain.com", User: "just-user", }, Encoded: `{"Email":"user@domain.com","User":"just-user"}`, }, { - SessionState: sessions.SessionState{ + SessionState: SessionState{ Email: "user@domain.com", User: "just-user", AccessToken: "token1234", @@ -185,7 +187,7 @@ func TestEncodeSessionState(t *testing.T) { for i, tc := range testCases { encoded, err := tc.EncodeSessionState(tc.Cipher) - t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err) + t.Logf("i:%d Encoded:%#vSessionState:%#v Error:%#v", i, encoded, tc.SessionState, err) if tc.Error { assert.Error(t, err) assert.Empty(t, encoded) @@ -210,34 +212,34 @@ func TestDecodeSessionState(t *testing.T) { testCases := []testCase{ { - SessionState: sessions.SessionState{ + SessionState: SessionState{ Email: "user@domain.com", User: "just-user", }, Encoded: `{"Email":"user@domain.com","User":"just-user"}`, }, { - SessionState: sessions.SessionState{ + SessionState: SessionState{ Email: "user@domain.com", User: "", }, Encoded: `{"Email":"user@domain.com"}`, }, { - SessionState: sessions.SessionState{ + SessionState: SessionState{ User: "just-user", }, Encoded: `{"User":"just-user"}`, }, { - SessionState: sessions.SessionState{ + SessionState: SessionState{ Email: "user@domain.com", User: "just-user", }, Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString), }, { - SessionState: sessions.SessionState{ + SessionState: SessionState{ Email: "user@domain.com", User: "just-user", AccessToken: "token1234", @@ -250,7 +252,7 @@ func TestDecodeSessionState(t *testing.T) { Cipher: c, }, { - SessionState: sessions.SessionState{ + SessionState: SessionState{ Email: "user@domain.com", User: "just-user", }, @@ -268,7 +270,7 @@ func TestDecodeSessionState(t *testing.T) { Error: true, }, { - SessionState: sessions.SessionState{ + SessionState: SessionState{ Email: "user@domain.com", User: "YmFzZTY0LWVuY29kZWQtdXNlcgo=", // Base64 encoding of base64-encoded-user }, @@ -278,8 +280,8 @@ func TestDecodeSessionState(t *testing.T) { } for i, tc := range testCases { - ss, err := sessions.DecodeSessionState(tc.Encoded, tc.Cipher) - t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, tc.Encoded, ss, err) + ss, err := DecodeSessionState(tc.Encoded, tc.Cipher) + t.Logf("i:%d Encoded:%#vSessionState:%#v Error:%#v", i, tc.Encoded, ss, err) if tc.Error { assert.Error(t, err) assert.Nil(t, ss) @@ -301,7 +303,7 @@ func TestDecodeSessionState(t *testing.T) { } func TestSessionStateAge(t *testing.T) { - ss := &sessions.SessionState{} + ss := &SessionState{} // Created at unset so should be 0 assert.Equal(t, time.Duration(0), ss.Age()) @@ -310,3 +312,44 @@ func TestSessionStateAge(t *testing.T) { ss.CreatedAt = timePtr(time.Now().Add(-1 * time.Hour)) assert.Equal(t, time.Hour, ss.Age().Round(time.Minute)) } + +func TestIntoEncryptAndIntoDecrypt(t *testing.T) { + const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + + // Test all 3 valid AES sizes + for _, secretSize := range []int{16, 24, 32} { + t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) { + secret := make([]byte, secretSize) + _, err := io.ReadFull(rand.Reader, secret) + assert.Equal(t, nil, err) + + c, err := newTestCipher(secret) + assert.NoError(t, err) + + // Check no errors with empty or nil strings + empty := "" + assert.Equal(t, nil, into(&empty, c.Encrypt)) + assert.Equal(t, nil, into(&empty, c.Decrypt)) + assert.Equal(t, nil, into(nil, c.Encrypt)) + assert.Equal(t, nil, into(nil, c.Decrypt)) + + // Test various sizes tokens might be + for _, dataSize := range []int{10, 100, 1000, 5000, 10000} { + t.Run(fmt.Sprintf("%d", dataSize), func(t *testing.T) { + b := make([]byte, dataSize) + for i := range b { + b[i] = charset[mathrand.Intn(len(charset))] + } + data := string(b) + originalData := data + + assert.Equal(t, nil, into(&data, c.Encrypt)) + assert.NotEqual(t, originalData, data) + + assert.Equal(t, nil, into(&data, c.Decrypt)) + assert.Equal(t, originalData, data) + }) + } + }) + } +} diff --git a/pkg/encryption/cipher.go b/pkg/encryption/cipher.go index 34499ba6..c1158b5c 100644 --- a/pkg/encryption/cipher.go +++ b/pkg/encryption/cipher.go @@ -13,11 +13,9 @@ import ( type Cipher interface { Encrypt(value []byte) ([]byte, error) Decrypt(ciphertext []byte) ([]byte, error) - EncryptInto(s *string) error - DecryptInto(s *string) error } -type Base64Cipher struct { +type base64Cipher struct { Cipher Cipher } @@ -28,11 +26,11 @@ func NewBase64Cipher(initCipher func([]byte) (Cipher, error), secret []byte) (Ci if err != nil { return nil, err } - return &Base64Cipher{Cipher: c}, nil + return &base64Cipher{Cipher: c}, nil } // Encrypt encrypts a value with the embedded Cipher & Base64 encodes it -func (c *Base64Cipher) Encrypt(value []byte) ([]byte, error) { +func (c *base64Cipher) Encrypt(value []byte) ([]byte, error) { encrypted, err := c.Cipher.Encrypt(value) if err != nil { return nil, err @@ -42,7 +40,7 @@ func (c *Base64Cipher) Encrypt(value []byte) ([]byte, error) { } // Decrypt Base64 decodes a value & decrypts it with the embedded Cipher -func (c *Base64Cipher) Decrypt(ciphertext []byte) ([]byte, error) { +func (c *base64Cipher) Decrypt(ciphertext []byte) ([]byte, error) { encrypted, err := base64.StdEncoding.DecodeString(string(ciphertext)) if err != nil { return nil, fmt.Errorf("failed to base64 decode value %s", err) @@ -51,17 +49,7 @@ func (c *Base64Cipher) Decrypt(ciphertext []byte) ([]byte, error) { return c.Cipher.Decrypt(encrypted) } -// EncryptInto encrypts the value and stores it back in the string pointer -func (c *Base64Cipher) EncryptInto(s *string) error { - return into(c.Encrypt, s) -} - -// DecryptInto decrypts the value and stores it back in the string pointer -func (c *Base64Cipher) DecryptInto(s *string) error { - return into(c.Decrypt, s) -} - -type CFBCipher struct { +type cfbCipher struct { cipher.Block } @@ -71,11 +59,11 @@ func NewCFBCipher(secret []byte) (Cipher, error) { if err != nil { return nil, err } - return &CFBCipher{Block: c}, err + return &cfbCipher{Block: c}, err } // Encrypt with AES CFB -func (c *CFBCipher) Encrypt(value []byte) ([]byte, error) { +func (c *cfbCipher) Encrypt(value []byte) ([]byte, error) { ciphertext := make([]byte, aes.BlockSize+len(value)) iv := ciphertext[:aes.BlockSize] if _, err := io.ReadFull(rand.Reader, iv); err != nil { @@ -88,7 +76,7 @@ func (c *CFBCipher) Encrypt(value []byte) ([]byte, error) { } // Decrypt an AES CFB ciphertext -func (c *CFBCipher) Decrypt(ciphertext []byte) ([]byte, error) { +func (c *cfbCipher) Decrypt(ciphertext []byte) ([]byte, error) { if len(ciphertext) < aes.BlockSize { return nil, fmt.Errorf("encrypted value should be at least %d bytes, but is only %d bytes", aes.BlockSize, len(ciphertext)) } @@ -101,17 +89,7 @@ func (c *CFBCipher) Decrypt(ciphertext []byte) ([]byte, error) { return plaintext, nil } -// EncryptInto returns an error since the encrypted data is a []byte that isn't string cast-able -func (c *CFBCipher) EncryptInto(s *string) error { - return fmt.Errorf("CFBCipher is not a string->string compatible cipher") -} - -// EncryptInto returns an error since the encrypted data needs to be a []byte -func (c *CFBCipher) DecryptInto(s *string) error { - return fmt.Errorf("CFBCipher is not a string->string compatible cipher") -} - -type GCMCipher struct { +type gcmCipher struct { cipher.Block } @@ -121,11 +99,11 @@ func NewGCMCipher(secret []byte) (Cipher, error) { if err != nil { return nil, err } - return &GCMCipher{Block: c}, err + return &gcmCipher{Block: c}, err } // Encrypt with AES GCM on raw bytes -func (c *GCMCipher) Encrypt(value []byte) ([]byte, error) { +func (c *gcmCipher) Encrypt(value []byte) ([]byte, error) { gcm, err := cipher.NewGCM(c.Block) if err != nil { return nil, err @@ -141,7 +119,7 @@ func (c *GCMCipher) Encrypt(value []byte) ([]byte, error) { } // Decrypt an AES GCM ciphertext -func (c *GCMCipher) Decrypt(ciphertext []byte) ([]byte, error) { +func (c *gcmCipher) Decrypt(ciphertext []byte) ([]byte, error) { gcm, err := cipher.NewGCM(c.Block) if err != nil { return nil, err @@ -156,30 +134,3 @@ func (c *GCMCipher) Decrypt(ciphertext []byte) ([]byte, error) { } return plaintext, nil } - -// EncryptInto returns an error since the encrypted data is a []byte that isn't string cast-able -func (c *GCMCipher) EncryptInto(s *string) error { - return fmt.Errorf("CFBCipher is not a string->string compatible cipher") -} - -// EncryptInto returns an error since the encrypted data needs to be a []byte -func (c *GCMCipher) DecryptInto(s *string) error { - return fmt.Errorf("CFBCipher is not a string->string compatible cipher") -} - -// codecFunc is a function that takes a string and encodes/decodes it -type codecFunc func([]byte) ([]byte, error) - -func into(f codecFunc, s *string) error { - // Do not encrypt/decrypt nil or empty strings - if s == nil || *s == "" { - return nil - } - - d, err := f([]byte(*s)) - if err != nil { - return err - } - *s = string(d) - return nil -} diff --git a/pkg/encryption/cipher_test.go b/pkg/encryption/cipher_test.go index e80986d5..b552e70c 100644 --- a/pkg/encryption/cipher_test.go +++ b/pkg/encryption/cipher_test.go @@ -5,7 +5,6 @@ import ( "encoding/base64" "fmt" "io" - mathrand "math/rand" "testing" "github.com/stretchr/testify/assert" @@ -118,80 +117,6 @@ func runEncryptAndDecrypt(t *testing.T, c Cipher, dataSize int) { assert.NotEqual(t, encrypted, decrypted) } -func TestEncryptIntoAndDecryptInto(t *testing.T) { - // Test our 2 cipher types - cipherInits := map[string]func([]byte) (Cipher, error){ - "CFB": NewCFBCipher, - "GCM": NewGCMCipher, - } - for name, initCipher := range cipherInits { - t.Run(name, func(t *testing.T) { - // Test all 3 valid AES sizes - for _, secretSize := range []int{16, 24, 32} { - t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) { - secret := make([]byte, secretSize) - _, err := io.ReadFull(rand.Reader, secret) - assert.Equal(t, nil, err) - - // Test Standard & Base64 wrapped - cstd, err := initCipher(secret) - assert.Equal(t, nil, err) - - cb64, err := NewBase64Cipher(initCipher, secret) - assert.Equal(t, nil, err) - - ciphers := map[string]Cipher{ - "Standard": cstd, - "Base64": cb64, - } - - for cName, c := range ciphers { - // Check no errors with empty or nil strings - if cName == "Base64" { - empty := "" - assert.Equal(t, nil, c.EncryptInto(&empty)) - assert.Equal(t, nil, c.DecryptInto(&empty)) - assert.Equal(t, nil, c.EncryptInto(nil)) - assert.Equal(t, nil, c.DecryptInto(nil)) - } - - t.Run(cName, func(t *testing.T) { - // Test various sizes sessions might be - for _, dataSize := range []int{10, 100, 1000, 5000, 10000} { - t.Run(fmt.Sprintf("%d", dataSize), func(t *testing.T) { - runEncryptIntoAndDecryptInto(t, c, cName, dataSize) - }) - } - }) - } - }) - } - }) - } -} - -func runEncryptIntoAndDecryptInto(t *testing.T, c Cipher, cipherType string, dataSize int) { - const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - b := make([]byte, dataSize) - for i := range b { - b[i] = charset[mathrand.Intn(len(charset))] - } - data := string(b) - originalData := data - - // Base64 is the only cipher that supports string->string Encrypt/Decrypt Into methods - if cipherType == "Base64" { - assert.Equal(t, nil, c.EncryptInto(&data)) - assert.NotEqual(t, originalData, data) - - assert.Equal(t, nil, c.DecryptInto(&data)) - assert.Equal(t, originalData, data) - } else { - assert.NotEqual(t, nil, c.EncryptInto(&data)) - assert.NotEqual(t, nil, c.DecryptInto(&data)) - } -} - func TestDecryptCFBWrongSecret(t *testing.T) { secret1 := []byte("0123456789abcdefghijklmnopqrstuv") secret2 := []byte("9876543210abcdefghijklmnopqrstuv")