You've already forked oauth2-proxy
mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-12-01 22:51:45 +02:00
Reduce SessionState size better with MessagePack + LZ4 (#632)
* Encode sessions with MsgPack + LZ4 Assumes ciphers are now mandatory per #414. Cookie & Redis sessions can fallback to V5 style JSON in error cases. TODO: session_state.go unit tests & new unit tests for Legacy fallback scenarios. * Only compress encoded sessions with Cookie Store * Cleanup msgpack + lz4 error handling * Change NewBase64Cipher to take in an existing Cipher * Add msgpack & lz4 session state tests * Add required options for oauthproxy tests More aggressively assert.NoError on all validation.Validate(opts) calls to enforce legal options in all our tests. Add additional NoError checks wherever error return values were ignored. * Remove support for uncompressed session state fields * Improve error verbosity & add session state tests * Ensure all marshalled sessions are valid Invalid CFB decryptions can result in garbage data that 1/100 times might cause message pack unmarshal to not fail and instead return an empty session. This adds more rigor to make sure legacy sessions cause appropriate errors. * Add tests for legacy V5 session decoding Refactor common legacy JSON test cases to a legacy helpers area under session store tests. * Make ValidateSession a struct method & add CHANGELOG entry * Improve SessionState error & comments verbosity * Move legacy session test helpers to sessions pkg Placing these helpers under the sessions pkg removed all the circular import uses in housing it under the session store area. * Improve SignatureAuthenticator test helper formatting * Make redis.legacyV5DecodeSession internal * Make LegacyV5TestCase test table public for linter
This commit is contained in:
87
pkg/apis/sessions/legacy_v5_tester.go
Normal file
87
pkg/apis/sessions/legacy_v5_tester.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package sessions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/encryption"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// LegacyV5TestCase provides V5 JSON based test cases for legacy fallback code
|
||||
type LegacyV5TestCase struct {
|
||||
Input string
|
||||
Error bool
|
||||
Output *SessionState
|
||||
}
|
||||
|
||||
// CreateLegacyV5TestCases makes various V5 JSON sessions as test cases
|
||||
//
|
||||
// Used for `apis/sessions/session_state_test.go` & `sessions/redis/redis_store_test.go`
|
||||
//
|
||||
// TODO: Remove when this is deprecated (likely V7)
|
||||
func CreateLegacyV5TestCases(t *testing.T) (map[string]LegacyV5TestCase, encryption.Cipher, encryption.Cipher) {
|
||||
const secret = "0123456789abcdefghijklmnopqrstuv"
|
||||
|
||||
created := time.Now()
|
||||
createdJSON, err := created.MarshalJSON()
|
||||
assert.NoError(t, err)
|
||||
createdString := string(createdJSON)
|
||||
e := time.Now().Add(time.Duration(1) * time.Hour)
|
||||
eJSON, err := e.MarshalJSON()
|
||||
assert.NoError(t, err)
|
||||
eString := string(eJSON)
|
||||
|
||||
cfbCipher, err := encryption.NewCFBCipher([]byte(secret))
|
||||
assert.NoError(t, err)
|
||||
legacyCipher := encryption.NewBase64Cipher(cfbCipher)
|
||||
|
||||
testCases := map[string]LegacyV5TestCase{
|
||||
"User & email unencrypted": {
|
||||
Input: `{"Email":"user@domain.com","User":"just-user"}`,
|
||||
Error: true,
|
||||
},
|
||||
"Only email unencrypted": {
|
||||
Input: `{"Email":"user@domain.com"}`,
|
||||
Error: true,
|
||||
},
|
||||
"Just user unencrypted": {
|
||||
Input: `{"User":"just-user"}`,
|
||||
Error: true,
|
||||
},
|
||||
"User and Email unencrypted while rest is encrypted": {
|
||||
Input: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString),
|
||||
Error: true,
|
||||
},
|
||||
"Full session with cipher": {
|
||||
Input: fmt.Sprintf(`{"Email":"FsKKYrTWZWrxSOAqA/fTNAUZS5QWCqOBjuAbBlbVOw==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw==","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString),
|
||||
Output: &SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
AccessToken: "token1234",
|
||||
IDToken: "rawtoken1234",
|
||||
CreatedAt: &created,
|
||||
ExpiresOn: &e,
|
||||
RefreshToken: "refresh4321",
|
||||
},
|
||||
},
|
||||
"Minimal session encrypted with cipher": {
|
||||
Input: `{"Email":"EGTllJcOFC16b7LBYzLekaHAC5SMMSPdyUrg8hd25g==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw=="}`,
|
||||
Output: &SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
},
|
||||
},
|
||||
"Unencrypted User, Email and AccessToken": {
|
||||
Input: `{"Email":"user@domain.com","User":"just-user","AccessToken":"X"}`,
|
||||
Error: true,
|
||||
},
|
||||
"Unencrypted User, Email and IDToken": {
|
||||
Input: `{"Email":"user@domain.com","User":"just-user","IDToken":"XXXX"}`,
|
||||
Error: true,
|
||||
},
|
||||
}
|
||||
|
||||
return testCases, cfbCipher, legacyCipher
|
||||
}
|
||||
@@ -1,25 +1,30 @@
|
||||
package sessions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/encryption"
|
||||
"github.com/pierrec/lz4"
|
||||
"github.com/vmihailenco/msgpack/v4"
|
||||
)
|
||||
|
||||
// SessionState is used to store information about the currently authenticated user session
|
||||
type SessionState struct {
|
||||
AccessToken string `json:",omitempty"`
|
||||
IDToken string `json:",omitempty"`
|
||||
CreatedAt *time.Time `json:",omitempty"`
|
||||
ExpiresOn *time.Time `json:",omitempty"`
|
||||
RefreshToken string `json:",omitempty"`
|
||||
Email string `json:",omitempty"`
|
||||
User string `json:",omitempty"`
|
||||
PreferredUsername string `json:",omitempty"`
|
||||
AccessToken string `json:",omitempty" msgpack:"at,omitempty"`
|
||||
IDToken string `json:",omitempty" msgpack:"it,omitempty"`
|
||||
CreatedAt *time.Time `json:",omitempty" msgpack:"ca,omitempty"`
|
||||
ExpiresOn *time.Time `json:",omitempty" msgpack:"eo,omitempty"`
|
||||
RefreshToken string `json:",omitempty" msgpack:"rt,omitempty"`
|
||||
Email string `json:",omitempty" msgpack:"e,omitempty"`
|
||||
User string `json:",omitempty" msgpack:"u,omitempty"`
|
||||
PreferredUsername string `json:",omitempty" msgpack:"pu,omitempty"`
|
||||
}
|
||||
|
||||
// IsExpired checks whether the session has expired
|
||||
@@ -59,78 +64,79 @@ func (s *SessionState) String() string {
|
||||
return o + "}"
|
||||
}
|
||||
|
||||
// EncodeSessionState returns string representation of the current session
|
||||
func (s *SessionState) EncodeSessionState(c encryption.Cipher) (string, error) {
|
||||
var ss SessionState
|
||||
if c == nil {
|
||||
// Store only Email and User when cipher is unavailable
|
||||
ss.Email = s.Email
|
||||
ss.User = s.User
|
||||
ss.PreferredUsername = s.PreferredUsername
|
||||
} else {
|
||||
ss = *s
|
||||
for _, s := range []*string{
|
||||
&ss.Email,
|
||||
&ss.User,
|
||||
&ss.PreferredUsername,
|
||||
&ss.AccessToken,
|
||||
&ss.IDToken,
|
||||
&ss.RefreshToken,
|
||||
} {
|
||||
err := into(s, c.Encrypt)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// EncodeSessionState returns an encrypted, lz4 compressed, MessagePack encoded session
|
||||
func (s *SessionState) EncodeSessionState(c encryption.Cipher, compress bool) ([]byte, error) {
|
||||
packed, err := msgpack.Marshal(s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshalling session state to msgpack: %w", err)
|
||||
}
|
||||
|
||||
if !compress {
|
||||
return c.Encrypt(packed)
|
||||
}
|
||||
|
||||
compressed, err := lz4Compress(packed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.Encrypt(compressed)
|
||||
}
|
||||
|
||||
// DecodeSessionState decodes a LZ4 compressed MessagePack into a Session State
|
||||
func DecodeSessionState(data []byte, c encryption.Cipher, compressed bool) (*SessionState, error) {
|
||||
decrypted, err := c.Decrypt(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error decrypting the session state: %w", err)
|
||||
}
|
||||
|
||||
packed := decrypted
|
||||
if compressed {
|
||||
packed, err = lz4Decompress(decrypted)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
b, err := json.Marshal(ss)
|
||||
return string(b), err
|
||||
var ss SessionState
|
||||
err = msgpack.Unmarshal(packed, &ss)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error unmarshalling data to session state: %w", err)
|
||||
}
|
||||
|
||||
err = ss.validate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ss, nil
|
||||
}
|
||||
|
||||
// DecodeSessionState decodes the session cookie string into a SessionState
|
||||
func DecodeSessionState(v string, c encryption.Cipher) (*SessionState, error) {
|
||||
// LegacyV5DecodeSessionState decodes a legacy JSON session cookie string into a SessionState
|
||||
func LegacyV5DecodeSessionState(v string, c encryption.Cipher) (*SessionState, error) {
|
||||
var ss SessionState
|
||||
err := json.Unmarshal([]byte(v), &ss)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error unmarshalling session: %w", err)
|
||||
}
|
||||
|
||||
if c == nil {
|
||||
// Load only Email and User when cipher is unavailable
|
||||
ss = SessionState{
|
||||
Email: ss.Email,
|
||||
User: ss.User,
|
||||
PreferredUsername: ss.PreferredUsername,
|
||||
}
|
||||
} else {
|
||||
// Backward compatibility with using unencrypted Email or User
|
||||
// Decryption errors will leave original string
|
||||
err = into(&ss.Email, c.Decrypt)
|
||||
if err == nil {
|
||||
if !utf8.ValidString(ss.Email) {
|
||||
return nil, errors.New("invalid value for decrypted email")
|
||||
}
|
||||
}
|
||||
err = into(&ss.User, c.Decrypt)
|
||||
if err == nil {
|
||||
if !utf8.ValidString(ss.User) {
|
||||
return nil, errors.New("invalid value for decrypted user")
|
||||
}
|
||||
}
|
||||
|
||||
for _, s := range []*string{
|
||||
&ss.PreferredUsername,
|
||||
&ss.AccessToken,
|
||||
&ss.IDToken,
|
||||
&ss.RefreshToken,
|
||||
} {
|
||||
err := into(s, c.Decrypt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, s := range []*string{
|
||||
&ss.User,
|
||||
&ss.Email,
|
||||
&ss.PreferredUsername,
|
||||
&ss.AccessToken,
|
||||
&ss.IDToken,
|
||||
&ss.RefreshToken,
|
||||
} {
|
||||
err := into(s, c.Decrypt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
err = ss.validate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ss, nil
|
||||
}
|
||||
|
||||
@@ -150,3 +156,86 @@ func into(s *string, f codecFunc) error {
|
||||
*s = string(d)
|
||||
return nil
|
||||
}
|
||||
|
||||
// lz4Compress compresses with LZ4
|
||||
//
|
||||
// The Compress:Decompress ratio is 1:Many. LZ4 gives fastest decompress speeds
|
||||
// at the expense of greater compression compared to other compression
|
||||
// algorithms.
|
||||
func lz4Compress(payload []byte) ([]byte, error) {
|
||||
buf := new(bytes.Buffer)
|
||||
zw := lz4.NewWriter(nil)
|
||||
zw.Header = lz4.Header{
|
||||
BlockMaxSize: 65536,
|
||||
CompressionLevel: 0,
|
||||
}
|
||||
zw.Reset(buf)
|
||||
|
||||
reader := bytes.NewReader(payload)
|
||||
_, err := io.Copy(zw, reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error copying lz4 stream to buffer: %w", err)
|
||||
}
|
||||
err = zw.Close()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error closing lz4 writer: %w", err)
|
||||
}
|
||||
|
||||
compressed, err := ioutil.ReadAll(buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading lz4 buffer: %w", err)
|
||||
}
|
||||
|
||||
return compressed, nil
|
||||
}
|
||||
|
||||
// lz4Decompress decompresses with LZ4
|
||||
func lz4Decompress(compressed []byte) ([]byte, error) {
|
||||
reader := bytes.NewReader(compressed)
|
||||
buf := new(bytes.Buffer)
|
||||
zr := lz4.NewReader(nil)
|
||||
zr.Reset(reader)
|
||||
_, err := io.Copy(buf, zr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error copying lz4 stream to buffer: %w", err)
|
||||
}
|
||||
|
||||
payload, err := ioutil.ReadAll(buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading lz4 buffer: %w", err)
|
||||
}
|
||||
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
// validate ensures the decoded session is non-empty and contains valid data
|
||||
//
|
||||
// Non-empty check is needed due to ensure the non-authenticated AES-CFB
|
||||
// decryption doesn't result in garbage data that collides with a valid
|
||||
// MessagePack header bytes (which MessagePack will unmarshal to an empty
|
||||
// default SessionState). <1% chance, but observed with random test data.
|
||||
//
|
||||
// UTF-8 check ensures the strings are valid and not raw bytes overloaded
|
||||
// into Latin-1 encoding. The occurs when legacy unencrypted fields are
|
||||
// decrypted with AES-CFB which results in random bytes.
|
||||
func (s *SessionState) validate() error {
|
||||
for _, field := range []string{
|
||||
s.User,
|
||||
s.Email,
|
||||
s.PreferredUsername,
|
||||
s.AccessToken,
|
||||
s.IDToken,
|
||||
s.RefreshToken,
|
||||
} {
|
||||
if !utf8.ValidString(field) {
|
||||
return errors.New("invalid non-UTF8 field in session")
|
||||
}
|
||||
}
|
||||
|
||||
empty := new(SessionState)
|
||||
if *s == *empty {
|
||||
return errors.New("invalid empty session unmarshalled")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -12,132 +12,11 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const secret = "0123456789abcdefghijklmnopqrstuv"
|
||||
const altSecret = "0000000000abcdefghijklmnopqrstuv"
|
||||
|
||||
func timePtr(t time.Time) *time.Time {
|
||||
return &t
|
||||
}
|
||||
|
||||
func newTestCipher(secret []byte) (encryption.Cipher, error) {
|
||||
return encryption.NewBase64Cipher(encryption.NewCFBCipher, secret)
|
||||
}
|
||||
|
||||
func TestSessionStateSerialization(t *testing.T) {
|
||||
c, err := newTestCipher([]byte(secret))
|
||||
assert.Equal(t, nil, err)
|
||||
c2, err := newTestCipher([]byte(altSecret))
|
||||
assert.Equal(t, nil, err)
|
||||
s := &SessionState{
|
||||
Email: "user@domain.com",
|
||||
PreferredUsername: "user",
|
||||
AccessToken: "token1234",
|
||||
IDToken: "rawtoken1234",
|
||||
CreatedAt: timePtr(time.Now()),
|
||||
ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Hour)),
|
||||
RefreshToken: "refresh4321",
|
||||
}
|
||||
encoded, err := s.EncodeSessionState(c)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
ss, err := DecodeSessionState(encoded, c)
|
||||
t.Logf("%#v", ss)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "", ss.User)
|
||||
assert.Equal(t, s.Email, ss.Email)
|
||||
assert.Equal(t, s.PreferredUsername, ss.PreferredUsername)
|
||||
assert.Equal(t, s.AccessToken, ss.AccessToken)
|
||||
assert.Equal(t, s.IDToken, ss.IDToken)
|
||||
assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix())
|
||||
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
||||
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
||||
|
||||
// ensure a different cipher can't decode properly (ie: it gets gibberish)
|
||||
ss, err = DecodeSessionState(encoded, c2)
|
||||
t.Logf("%#v", ss)
|
||||
assert.NotEqual(t, nil, err)
|
||||
}
|
||||
|
||||
func TestSessionStateSerializationWithUser(t *testing.T) {
|
||||
c, err := newTestCipher([]byte(secret))
|
||||
assert.Equal(t, nil, err)
|
||||
c2, err := newTestCipher([]byte(altSecret))
|
||||
assert.Equal(t, nil, err)
|
||||
s := &SessionState{
|
||||
User: "just-user",
|
||||
PreferredUsername: "ju",
|
||||
Email: "user@domain.com",
|
||||
AccessToken: "token1234",
|
||||
CreatedAt: timePtr(time.Now()),
|
||||
ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Hour)),
|
||||
RefreshToken: "refresh4321",
|
||||
}
|
||||
encoded, err := s.EncodeSessionState(c)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
ss, err := DecodeSessionState(encoded, c)
|
||||
t.Logf("%#v", ss)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, s.User, ss.User)
|
||||
assert.Equal(t, s.Email, ss.Email)
|
||||
assert.Equal(t, s.PreferredUsername, ss.PreferredUsername)
|
||||
assert.Equal(t, s.AccessToken, ss.AccessToken)
|
||||
assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix())
|
||||
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
||||
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
||||
|
||||
// ensure a different cipher can't decode properly (ie: it gets gibberish)
|
||||
ss, err = DecodeSessionState(encoded, c2)
|
||||
t.Logf("%#v", ss)
|
||||
assert.NotEqual(t, nil, err)
|
||||
}
|
||||
|
||||
func TestSessionStateSerializationNoCipher(t *testing.T) {
|
||||
s := &SessionState{
|
||||
Email: "user@domain.com",
|
||||
PreferredUsername: "user",
|
||||
AccessToken: "token1234",
|
||||
CreatedAt: timePtr(time.Now()),
|
||||
ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Hour)),
|
||||
RefreshToken: "refresh4321",
|
||||
}
|
||||
encoded, err := s.EncodeSessionState(nil)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// only email should have been serialized
|
||||
ss, err := DecodeSessionState(encoded, nil)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "", ss.User)
|
||||
assert.Equal(t, s.Email, ss.Email)
|
||||
assert.Equal(t, s.PreferredUsername, ss.PreferredUsername)
|
||||
assert.Equal(t, "", ss.AccessToken)
|
||||
assert.Equal(t, "", ss.RefreshToken)
|
||||
}
|
||||
|
||||
func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
|
||||
s := &SessionState{
|
||||
User: "just-user",
|
||||
Email: "user@domain.com",
|
||||
PreferredUsername: "user",
|
||||
AccessToken: "token1234",
|
||||
CreatedAt: timePtr(time.Now()),
|
||||
ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Hour)),
|
||||
RefreshToken: "refresh4321",
|
||||
}
|
||||
encoded, err := s.EncodeSessionState(nil)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// only email should have been serialized
|
||||
ss, err := DecodeSessionState(encoded, nil)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, s.User, ss.User)
|
||||
assert.Equal(t, s.Email, ss.Email)
|
||||
assert.Equal(t, s.PreferredUsername, ss.PreferredUsername)
|
||||
assert.Equal(t, "", ss.AccessToken)
|
||||
assert.Equal(t, "", ss.RefreshToken)
|
||||
}
|
||||
|
||||
func TestExpired(t *testing.T) {
|
||||
func TestIsExpired(t *testing.T) {
|
||||
s := &SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(-1) * time.Minute))}
|
||||
assert.Equal(t, true, s.IsExpired())
|
||||
|
||||
@@ -148,161 +27,7 @@ func TestExpired(t *testing.T) {
|
||||
assert.Equal(t, false, s.IsExpired())
|
||||
}
|
||||
|
||||
type testCase struct {
|
||||
SessionState
|
||||
Encoded string
|
||||
Cipher encryption.Cipher
|
||||
Error bool
|
||||
}
|
||||
|
||||
// TestEncodeSessionState tests EncodeSessionState with the test vector
|
||||
//
|
||||
// Currently only tests without cipher here because we have no way to mock
|
||||
// the random generator used in EncodeSessionState.
|
||||
func TestEncodeSessionState(t *testing.T) {
|
||||
c := time.Now()
|
||||
e := time.Now().Add(time.Duration(1) * time.Hour)
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
SessionState: SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
},
|
||||
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
AccessToken: "token1234",
|
||||
IDToken: "rawtoken1234",
|
||||
CreatedAt: &c,
|
||||
ExpiresOn: &e,
|
||||
RefreshToken: "refresh4321",
|
||||
},
|
||||
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
encoded, err := tc.EncodeSessionState(tc.Cipher)
|
||||
t.Logf("i:%d Encoded:%#vSessionState:%#v Error:%#v", i, encoded, tc.SessionState, err)
|
||||
if tc.Error {
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, encoded)
|
||||
continue
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.JSONEq(t, tc.Encoded, encoded)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDecodeSessionState testssessions.DecodeSessionState with the test vector
|
||||
func TestDecodeSessionState(t *testing.T) {
|
||||
created := time.Now()
|
||||
createdJSON, _ := created.MarshalJSON()
|
||||
createdString := string(createdJSON)
|
||||
e := time.Now().Add(time.Duration(1) * time.Hour)
|
||||
eJSON, _ := e.MarshalJSON()
|
||||
eString := string(eJSON)
|
||||
|
||||
c, err := newTestCipher([]byte(secret))
|
||||
assert.NoError(t, err)
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
SessionState: SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
},
|
||||
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "",
|
||||
},
|
||||
Encoded: `{"Email":"user@domain.com"}`,
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
User: "just-user",
|
||||
},
|
||||
Encoded: `{"User":"just-user"}`,
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
},
|
||||
Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString),
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
AccessToken: "token1234",
|
||||
IDToken: "rawtoken1234",
|
||||
CreatedAt: &created,
|
||||
ExpiresOn: &e,
|
||||
RefreshToken: "refresh4321",
|
||||
},
|
||||
Encoded: fmt.Sprintf(`{"Email":"FsKKYrTWZWrxSOAqA/fTNAUZS5QWCqOBjuAbBlbVOw==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw==","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString),
|
||||
Cipher: c,
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
},
|
||||
Encoded: `{"Email":"EGTllJcOFC16b7LBYzLekaHAC5SMMSPdyUrg8hd25g==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw=="}`,
|
||||
Cipher: c,
|
||||
},
|
||||
{
|
||||
Encoded: `{"Email":"user@domain.com","User":"just-user","AccessToken":"X"}`,
|
||||
Cipher: c,
|
||||
Error: true,
|
||||
},
|
||||
{
|
||||
Encoded: `{"Email":"user@domain.com","User":"just-user","IDToken":"XXXX"}`,
|
||||
Cipher: c,
|
||||
Error: true,
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "YmFzZTY0LWVuY29kZWQtdXNlcgo=", // Base64 encoding of base64-encoded-user
|
||||
},
|
||||
Error: true,
|
||||
Cipher: c,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
ss, err := DecodeSessionState(tc.Encoded, tc.Cipher)
|
||||
t.Logf("i:%d Encoded:%#vSessionState:%#v Error:%#v", i, tc.Encoded, ss, err)
|
||||
if tc.Error {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, ss)
|
||||
continue
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
if assert.NotNil(t, ss) {
|
||||
assert.Equal(t, tc.User, ss.User)
|
||||
assert.Equal(t, tc.Email, ss.Email)
|
||||
assert.Equal(t, tc.AccessToken, ss.AccessToken)
|
||||
assert.Equal(t, tc.RefreshToken, ss.RefreshToken)
|
||||
assert.Equal(t, tc.IDToken, ss.IDToken)
|
||||
if tc.ExpiresOn != nil {
|
||||
assert.NotEqual(t, nil, ss.ExpiresOn)
|
||||
assert.Equal(t, tc.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStateAge(t *testing.T) {
|
||||
func TestAge(t *testing.T) {
|
||||
ss := &SessionState{}
|
||||
|
||||
// Created at unset so should be 0
|
||||
@@ -313,7 +38,149 @@ func TestSessionStateAge(t *testing.T) {
|
||||
assert.Equal(t, time.Hour, ss.Age().Round(time.Minute))
|
||||
}
|
||||
|
||||
func TestIntoEncryptAndIntoDecrypt(t *testing.T) {
|
||||
// TestEncodeAndDecodeSessionState encodes & decodes various session states
|
||||
// and confirms the operation is 1:1
|
||||
func TestEncodeAndDecodeSessionState(t *testing.T) {
|
||||
created := time.Now()
|
||||
expires := time.Now().Add(time.Duration(1) * time.Hour)
|
||||
|
||||
// Tokens in the test table are purposefully redundant
|
||||
// Otherwise compressing small payloads could result in a compressed value
|
||||
// that is larger (compression dictionary + limited like strings to compress)
|
||||
// which breaks the len(compressed) < len(uncompressed) assertion.
|
||||
testCases := map[string]SessionState{
|
||||
"Full session": {
|
||||
Email: "username@example.com",
|
||||
User: "username",
|
||||
PreferredUsername: "preferred.username",
|
||||
AccessToken: "AccessToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7",
|
||||
IDToken: "IDToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7",
|
||||
CreatedAt: &created,
|
||||
ExpiresOn: &expires,
|
||||
RefreshToken: "RefreshToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7",
|
||||
},
|
||||
"No ExpiresOn": {
|
||||
Email: "username@example.com",
|
||||
User: "username",
|
||||
PreferredUsername: "preferred.username",
|
||||
AccessToken: "AccessToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7",
|
||||
IDToken: "IDToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7",
|
||||
CreatedAt: &created,
|
||||
RefreshToken: "RefreshToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7",
|
||||
},
|
||||
"No PreferredUsername": {
|
||||
Email: "username@example.com",
|
||||
User: "username",
|
||||
AccessToken: "AccessToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7",
|
||||
IDToken: "IDToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7",
|
||||
CreatedAt: &created,
|
||||
ExpiresOn: &expires,
|
||||
RefreshToken: "RefreshToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7",
|
||||
},
|
||||
"Minimal session": {
|
||||
User: "username",
|
||||
IDToken: "IDToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7",
|
||||
CreatedAt: &created,
|
||||
RefreshToken: "RefreshToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7",
|
||||
},
|
||||
"Bearer authorization header created session": {
|
||||
Email: "username",
|
||||
User: "username",
|
||||
AccessToken: "IDToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7",
|
||||
IDToken: "IDToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7",
|
||||
ExpiresOn: &expires,
|
||||
},
|
||||
}
|
||||
|
||||
for _, secretSize := range []int{16, 24, 32} {
|
||||
t.Run(fmt.Sprintf("%d byte secret", secretSize), func(t *testing.T) {
|
||||
secret := make([]byte, secretSize)
|
||||
_, err := io.ReadFull(rand.Reader, secret)
|
||||
assert.NoError(t, err)
|
||||
|
||||
cfb, err := encryption.NewCFBCipher([]byte(secret))
|
||||
assert.NoError(t, err)
|
||||
gcm, err := encryption.NewGCMCipher([]byte(secret))
|
||||
assert.NoError(t, err)
|
||||
|
||||
ciphers := map[string]encryption.Cipher{
|
||||
"CFB cipher": cfb,
|
||||
"GCM cipher": gcm,
|
||||
}
|
||||
|
||||
for cipherName, c := range ciphers {
|
||||
t.Run(cipherName, func(t *testing.T) {
|
||||
for testName, ss := range testCases {
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
encoded, err := ss.EncodeSessionState(c, false)
|
||||
assert.NoError(t, err)
|
||||
encodedCompressed, err := ss.EncodeSessionState(c, true)
|
||||
assert.NoError(t, err)
|
||||
// Make sure compressed version is smaller than if not compressed
|
||||
assert.Greater(t, len(encoded), len(encodedCompressed))
|
||||
|
||||
decoded, err := DecodeSessionState(encoded, c, false)
|
||||
assert.NoError(t, err)
|
||||
decodedCompressed, err := DecodeSessionState(encodedCompressed, c, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
compareSessionStates(t, decoded, decodedCompressed)
|
||||
compareSessionStates(t, decoded, &ss)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("Mixed cipher types cause errors", func(t *testing.T) {
|
||||
for testName, ss := range testCases {
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
cfbEncoded, err := ss.EncodeSessionState(cfb, false)
|
||||
assert.NoError(t, err)
|
||||
_, err = DecodeSessionState(cfbEncoded, gcm, false)
|
||||
assert.Error(t, err)
|
||||
|
||||
gcmEncoded, err := ss.EncodeSessionState(gcm, false)
|
||||
assert.NoError(t, err)
|
||||
_, err = DecodeSessionState(gcmEncoded, cfb, false)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLegacyV5DecodeSessionState confirms V5 JSON sessions decode
|
||||
//
|
||||
// TODO: Remove when this is deprecated (likely V7)
|
||||
func TestLegacyV5DecodeSessionState(t *testing.T) {
|
||||
testCases, cipher, legacyCipher := CreateLegacyV5TestCases(t)
|
||||
|
||||
for testName, tc := range testCases {
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
// Legacy sessions fail in DecodeSessionState which results in
|
||||
// the fallback to LegacyV5DecodeSessionState
|
||||
_, err := DecodeSessionState([]byte(tc.Input), cipher, false)
|
||||
assert.Error(t, err)
|
||||
_, err = DecodeSessionState([]byte(tc.Input), cipher, true)
|
||||
assert.Error(t, err)
|
||||
|
||||
ss, err := LegacyV5DecodeSessionState(tc.Input, legacyCipher)
|
||||
if tc.Error {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, ss)
|
||||
return
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
compareSessionStates(t, tc.Output, ss)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test_into tests the into helper function used in LegacyV5DecodeSessionState
|
||||
//
|
||||
// TODO: Remove when this is deprecated (likely V7)
|
||||
func Test_into(t *testing.T) {
|
||||
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
|
||||
// Test all 3 valid AES sizes
|
||||
@@ -323,8 +190,9 @@ func TestIntoEncryptAndIntoDecrypt(t *testing.T) {
|
||||
_, err := io.ReadFull(rand.Reader, secret)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
c, err := newTestCipher(secret)
|
||||
cfb, err := encryption.NewCFBCipher(secret)
|
||||
assert.NoError(t, err)
|
||||
c := encryption.NewBase64Cipher(cfb)
|
||||
|
||||
// Check no errors with empty or nil strings
|
||||
empty := ""
|
||||
@@ -353,3 +221,27 @@ func TestIntoEncryptAndIntoDecrypt(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func compareSessionStates(t *testing.T, expected *SessionState, actual *SessionState) {
|
||||
if expected.CreatedAt != nil {
|
||||
assert.NotNil(t, actual.CreatedAt)
|
||||
assert.Equal(t, true, expected.CreatedAt.Equal(*actual.CreatedAt))
|
||||
} else {
|
||||
assert.Nil(t, actual.CreatedAt)
|
||||
}
|
||||
if expected.ExpiresOn != nil {
|
||||
assert.NotNil(t, actual.ExpiresOn)
|
||||
assert.Equal(t, true, expected.ExpiresOn.Equal(*actual.ExpiresOn))
|
||||
} else {
|
||||
assert.Nil(t, actual.ExpiresOn)
|
||||
}
|
||||
|
||||
// Compare sessions without *time.Time fields
|
||||
exp := *expected
|
||||
exp.CreatedAt = nil
|
||||
exp.ExpiresOn = nil
|
||||
act := *actual
|
||||
act.CreatedAt = nil
|
||||
act.ExpiresOn = nil
|
||||
assert.Equal(t, exp, act)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user