1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-01-24 05:26:55 +02:00
oauth2-proxy/pkg/apis/sessions/session_state.go
Nick Meves a09eecc6a2
Reduce SessionState size better with MessagePack + LZ4 (#632)
* Encode sessions with MsgPack + LZ4

Assumes ciphers are now mandatory per #414. Cookie & Redis sessions
can fall back to V5-style JSON in error cases. TODO: session_state.go
unit tests & new unit tests for Legacy fallback scenarios.

* Only compress encoded sessions with Cookie Store

* Cleanup msgpack + lz4 error handling

* Change NewBase64Cipher to take in an existing Cipher

* Add msgpack & lz4 session state tests

* Add required options for oauthproxy tests

More aggressively assert.NoError on all
validation.Validate(opts) calls to enforce legal
options in all our tests.
Add additional NoError checks wherever error return
values were ignored.

* Remove support for uncompressed session state fields

* Improve error verbosity & add session state tests

* Ensure all marshalled sessions are valid

Invalid CFB decryptions can produce garbage data
that, roughly 1 time in 100, causes the MessagePack
unmarshal to succeed and return an empty session
instead of failing. This adds more rigor to ensure
legacy sessions produce appropriate errors.

* Add tests for legacy V5 session decoding

Refactor common legacy JSON test cases to a
legacy helpers area under session store tests.

* Make ValidateSession a struct method & add CHANGELOG entry

* Improve SessionState error & comments verbosity

* Move legacy session test helpers to sessions pkg

Placing these helpers under the sessions pkg removes
the circular imports that were caused by housing them
under the session store area.

* Improve SignatureAuthenticator test helper formatting

* Make redis.legacyV5DecodeSession internal

* Make LegacyV5TestCase test table public for linter
2020-07-13 20:56:05 +01:00

242 lines
6.1 KiB
Go

package sessions
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"time"
"unicode/utf8"
"github.com/oauth2-proxy/oauth2-proxy/pkg/encryption"
"github.com/pierrec/lz4"
"github.com/vmihailenco/msgpack/v4"
)
// SessionState is used to store information about the currently authenticated user session.
//
// The struct tags define the on-the-wire field names for both encodings used
// in this file: `json` for the legacy V5 format read by
// LegacyV5DecodeSessionState, and the short `msgpack` keys used by
// EncodeSessionState/DecodeSessionState to keep encoded sessions small.
// Changing a tag changes which key is read, so existing encoded sessions
// would no longer decode into that field.
//
// CreatedAt and ExpiresOn are pointers and may be nil (they are omitted from
// the encoding when unset).
type SessionState struct {
	AccessToken       string     `json:",omitempty" msgpack:"at,omitempty"`
	IDToken           string     `json:",omitempty" msgpack:"it,omitempty"`
	CreatedAt         *time.Time `json:",omitempty" msgpack:"ca,omitempty"`
	ExpiresOn         *time.Time `json:",omitempty" msgpack:"eo,omitempty"`
	RefreshToken      string     `json:",omitempty" msgpack:"rt,omitempty"`
	Email             string     `json:",omitempty" msgpack:"e,omitempty"`
	User              string     `json:",omitempty" msgpack:"u,omitempty"`
	PreferredUsername string     `json:",omitempty" msgpack:"pu,omitempty"`
}
// IsExpired reports whether the session carries an expiry time that is
// already in the past. A nil or zero ExpiresOn means the session never
// expires.
func (s *SessionState) IsExpired() bool {
	expiry := s.ExpiresOn
	return expiry != nil && !expiry.IsZero() && expiry.Before(time.Now())
}
// Age returns how long ago the session was created, measured against the
// current time truncated to whole seconds. It returns 0 when CreatedAt is
// nil or the zero time.
func (s *SessionState) Age() time.Duration {
	created := s.CreatedAt
	if created == nil || created.IsZero() {
		return 0
	}
	now := time.Now().Truncate(time.Second)
	return now.Sub(*created)
}
// String constructs a human-readable summary of the session state.
//
// Token values are never included; only their presence is reported.
// CreatedAt and ExpiresOn are optional pointers, so each is nil-checked
// before use. Calling IsZero through a nil *time.Time would panic.
func (s *SessionState) String() string {
	o := fmt.Sprintf("Session{email:%s user:%s PreferredUsername:%s", s.Email, s.User, s.PreferredUsername)
	if s.AccessToken != "" {
		o += " token:true"
	}
	if s.IDToken != "" {
		o += " id_token:true"
	}
	if s.CreatedAt != nil && !s.CreatedAt.IsZero() {
		o += fmt.Sprintf(" created:%s", s.CreatedAt)
	}
	if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() {
		o += fmt.Sprintf(" expires:%s", s.ExpiresOn)
	}
	if s.RefreshToken != "" {
		o += " refresh_token:true"
	}
	return o + "}"
}
// EncodeSessionState serializes the session with MessagePack, LZ4-compresses
// the packed bytes when compress is true, and returns the result encrypted
// with c. DecodeSessionState reverses this and must be called with the same
// compress setting.
func (s *SessionState) EncodeSessionState(c encryption.Cipher, compress bool) ([]byte, error) {
	payload, err := msgpack.Marshal(s)
	if err != nil {
		return nil, fmt.Errorf("error marshalling session state to msgpack: %w", err)
	}
	if compress {
		if payload, err = lz4Compress(payload); err != nil {
			return nil, err
		}
	}
	return c.Encrypt(payload)
}
// DecodeSessionState reverses EncodeSessionState. It decrypts data with c,
// LZ4-decompresses the plaintext when compressed is true, unmarshals the
// MessagePack payload into a SessionState, and rejects it if validate fails.
func DecodeSessionState(data []byte, c encryption.Cipher, compressed bool) (*SessionState, error) {
	packed, err := c.Decrypt(data)
	if err != nil {
		return nil, fmt.Errorf("error decrypting the session state: %w", err)
	}
	if compressed {
		if packed, err = lz4Decompress(packed); err != nil {
			return nil, err
		}
	}

	ss := &SessionState{}
	if err := msgpack.Unmarshal(packed, ss); err != nil {
		return nil, fmt.Errorf("error unmarshalling data to session state: %w", err)
	}
	// Garbage plaintext can occasionally unmarshal "successfully", so the
	// decoded result is checked before it is trusted.
	if err := ss.validate(); err != nil {
		return nil, err
	}
	return ss, nil
}
// LegacyV5DecodeSessionState decodes a legacy (V5) JSON session cookie value
// into a SessionState.
//
// In this format the session is plain JSON whose string fields are each
// encrypted individually. After unmarshalling, every non-empty string field
// is decrypted in place with c, and the result is checked with validate.
func LegacyV5DecodeSessionState(v string, c encryption.Cipher) (*SessionState, error) {
	ss := new(SessionState)
	if err := json.Unmarshal([]byte(v), ss); err != nil {
		return nil, fmt.Errorf("error unmarshalling session: %w", err)
	}

	// Decrypt the fields in a fixed order; the first failure aborts decoding.
	encrypted := []*string{
		&ss.User,
		&ss.Email,
		&ss.PreferredUsername,
		&ss.AccessToken,
		&ss.IDToken,
		&ss.RefreshToken,
	}
	for i := range encrypted {
		if err := into(encrypted[i], c.Decrypt); err != nil {
			return nil, err
		}
	}

	if err := ss.validate(); err != nil {
		return nil, err
	}
	return ss, nil
}
// codecFunc transforms a byte slice (for example by encrypting or
// decrypting it) and returns the transformed bytes.
type codecFunc func([]byte) ([]byte, error)

// into applies f to the string that s points at and stores the result back
// into *s. A nil pointer or an empty string is left alone and f is not
// called. If f fails, its error is returned and *s is left unchanged.
func into(s *string, f codecFunc) error {
	if s != nil && *s != "" {
		transformed, err := f([]byte(*s))
		if err != nil {
			return err
		}
		*s = string(transformed)
	}
	return nil
}
// lz4Compress compresses payload into a single LZ4 frame.
//
// Sessions are expected to be decompressed far more often than they are
// compressed (a 1:Many ratio). LZ4 is chosen for its very fast decompression,
// at the cost of a lower compression ratio than slower algorithms. The frame
// uses 64 KiB blocks and compression level 0 (the fast mode in pierrec/lz4).
func lz4Compress(payload []byte) ([]byte, error) {
	buf := new(bytes.Buffer)
	zw := lz4.NewWriter(nil)
	zw.Header = lz4.Header{
		BlockMaxSize:     65536,
		CompressionLevel: 0,
	}
	zw.Reset(buf)

	if _, err := io.Copy(zw, bytes.NewReader(payload)); err != nil {
		return nil, fmt.Errorf("error copying lz4 stream to buffer: %w", err)
	}
	// Close flushes any buffered data and finishes the frame. The output in
	// buf is incomplete until Close returns successfully.
	if err := zw.Close(); err != nil {
		return nil, fmt.Errorf("error closing lz4 writer: %w", err)
	}
	// buf is private to this call, so its contents can be returned directly
	// instead of being copied out a second time with ReadAll.
	return buf.Bytes(), nil
}
// lz4Decompress decompresses an LZ4 frame (as written by lz4Compress) and
// returns the original payload. A malformed or truncated frame is reported
// as an error.
func lz4Decompress(compressed []byte) ([]byte, error) {
	zr := lz4.NewReader(nil)
	zr.Reset(bytes.NewReader(compressed))

	// Read the decompressed stream straight into the result slice instead of
	// staging it in an intermediate bytes.Buffer and copying it out again.
	payload, err := ioutil.ReadAll(zr)
	if err != nil {
		return nil, fmt.Errorf("error copying lz4 stream to buffer: %w", err)
	}
	return payload, nil
}
// validate rejects decoded sessions that cannot be genuine.
//
// The non-empty check is needed because unauthenticated AES-CFB decryption
// can produce garbage data that happens to collide with valid MessagePack
// header bytes, which MessagePack then unmarshals into a default, empty
// SessionState. This is rare (<1%) but has been observed with random test
// data.
//
// The UTF-8 check ensures string fields hold valid text rather than raw
// bytes. Raw bytes appear when legacy unencrypted fields are run through
// AES-CFB decryption, which turns them into random bytes.
func (s *SessionState) validate() error {
	textFields := [...]string{
		s.User,
		s.Email,
		s.PreferredUsername,
		s.AccessToken,
		s.IDToken,
		s.RefreshToken,
	}
	for i := range textFields {
		if !utf8.ValidString(textFields[i]) {
			return errors.New("invalid non-UTF8 field in session")
		}
	}
	if *s == (SessionState{}) {
		return errors.New("invalid empty session unmarshalled")
	}
	return nil
}