mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-05-21 22:33:38 +02:00
* Add sensible logging flag to default setup for logger * Add Redis lock * Fix default value flag for sensitive logging * Split RefreshSessionIfNeeded in two methods and use Redis lock * Small adjustments to doc and code * Remove sensible logging * Fix method names in ticket.go * Revert "Fix method names in ticket.go" This reverts commit 408ba1a1a5c55a3cad507a0be8634af1977769cb. * Fix methods name in ticket.go * Remove block in Redis client get * Increase lock time to 1 second * Perform retries, if session store is locked * Reverse if condition, because it should return if session does not have to be refreshed * Update go.sum * Update MockStore * Return error if loading session fails * Fix and update tests * Change validSession to session in docs and strings * Change validSession to session in docs and strings * Fix docs * Fix wrong field name * Fix linting * Fix imports for linting * Revert changes except from locking functionality * Add lock feature on session state * Update from master * Remove errors package, because it is not used * Only pass context instead of request to lock * Use lock key * By default use NoOpLock * Remove debug output * Update ticket_test.go * Map internal error to sessions error * Add ErrLockNotObtained * Enable lock peek for all redis clients * Use lock key prefix consistent * Fix imports * Use exists method for peek lock * Fix imports * Fix imports * Fix imports * Remove own Dockerfile * Fix imports * Fix tests for ticket and session store * Fix session store test * Update pkg/apis/sessions/interfaces.go Co-authored-by: Joel Speed <Joel.speed@hotmail.co.uk> * Do not wrap lock method Co-authored-by: Joel Speed <Joel.speed@hotmail.co.uk> * Use errors package for lock constants * Use better naming for initLock function * Add comments * Add session store lock test * Fix tests * Fix tests * Fix tests * Fix tests * Add cookies after saving session * Add mock lock * Fix imports for mock_lock.go * Store mock lock for key * Apply elapsed time on 
mock lock * Check if lock is initially applied * Reuse existing lock * Test all lock methods * Update CHANGELOG.md * Use redis client methods in redis.lock for release and refresh * Use lock key suffix instead of prefix for lock key * Add comments for Lock interface * Update comment for Lock interface * Update CHANGELOG.md * Change LockSuffix to const * Check lock on already loaded session * Use global var for loadedSession in lock tests * Use lock instance for refreshing and releasing of lock * Update possible error type for Refresh Co-authored-by: Joel Speed <Joel.speed@hotmail.co.uk>
270 lines
6.7 KiB
Go
270 lines
6.7 KiB
Go
package sessions
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"reflect"
|
|
"time"
|
|
"unicode/utf8"
|
|
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
|
|
"github.com/pierrec/lz4"
|
|
"github.com/vmihailenco/msgpack/v4"
|
|
)
|
|
|
|
// SessionState is used to store information about the currently authenticated user session
//
// The struct is serialized with MessagePack (see EncodeSessionState). Each
// field uses a short msgpack key, and omitempty so unset values are not
// written to the encoded form.
type SessionState struct {
	// CreatedAt and ExpiresOn are optional: nil means "unset". The methods in
	// this file also treat a zero time.Time as unset.
	CreatedAt *time.Time `msgpack:"ca,omitempty"`
	ExpiresOn *time.Time `msgpack:"eo,omitempty"`

	AccessToken  string `msgpack:"at,omitempty"`
	IDToken      string `msgpack:"it,omitempty"`
	RefreshToken string `msgpack:"rt,omitempty"`

	// Nonce is compared against a hashed candidate by CheckNonce.
	Nonce []byte `msgpack:"n,omitempty"`

	Email             string   `msgpack:"e,omitempty"`
	User              string   `msgpack:"u,omitempty"`
	Groups            []string `msgpack:"g,omitempty"`
	PreferredUsername string   `msgpack:"pu,omitempty"`

	// Lock is never serialized (msgpack:"-"). When it is nil, the
	// ObtainLock/RefreshLock/ReleaseLock/PeekLock methods install a NoOpLock
	// before delegating to it.
	Lock Lock `msgpack:"-"`
}
|
|
|
|
func (s *SessionState) ObtainLock(ctx context.Context, expiration time.Duration) error {
|
|
if s.Lock == nil {
|
|
s.Lock = &NoOpLock{}
|
|
}
|
|
return s.Lock.Obtain(ctx, expiration)
|
|
}
|
|
|
|
func (s *SessionState) RefreshLock(ctx context.Context, expiration time.Duration) error {
|
|
if s.Lock == nil {
|
|
s.Lock = &NoOpLock{}
|
|
}
|
|
return s.Lock.Refresh(ctx, expiration)
|
|
}
|
|
|
|
func (s *SessionState) ReleaseLock(ctx context.Context) error {
|
|
if s.Lock == nil {
|
|
s.Lock = &NoOpLock{}
|
|
}
|
|
return s.Lock.Release(ctx)
|
|
}
|
|
|
|
func (s *SessionState) PeekLock(ctx context.Context) (bool, error) {
|
|
if s.Lock == nil {
|
|
s.Lock = &NoOpLock{}
|
|
}
|
|
return s.Lock.Peek(ctx)
|
|
}
|
|
|
|
// IsExpired checks whether the session has expired
|
|
func (s *SessionState) IsExpired() bool {
|
|
if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) {
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
// Age returns the age of a session
|
|
func (s *SessionState) Age() time.Duration {
|
|
if s.CreatedAt != nil && !s.CreatedAt.IsZero() {
|
|
return time.Now().Truncate(time.Second).Sub(*s.CreatedAt)
|
|
}
|
|
return 0
|
|
}
|
|
|
|
// String constructs a summary of the session state
|
|
func (s *SessionState) String() string {
|
|
o := fmt.Sprintf("Session{email:%s user:%s PreferredUsername:%s", s.Email, s.User, s.PreferredUsername)
|
|
if s.AccessToken != "" {
|
|
o += " token:true"
|
|
}
|
|
if s.IDToken != "" {
|
|
o += " id_token:true"
|
|
}
|
|
if s.CreatedAt != nil && !s.CreatedAt.IsZero() {
|
|
o += fmt.Sprintf(" created:%s", s.CreatedAt)
|
|
}
|
|
if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() {
|
|
o += fmt.Sprintf(" expires:%s", s.ExpiresOn)
|
|
}
|
|
if s.RefreshToken != "" {
|
|
o += " refresh_token:true"
|
|
}
|
|
if len(s.Groups) > 0 {
|
|
o += fmt.Sprintf(" groups:%v", s.Groups)
|
|
}
|
|
return o + "}"
|
|
}
|
|
|
|
func (s *SessionState) GetClaim(claim string) []string {
|
|
if s == nil {
|
|
return []string{}
|
|
}
|
|
switch claim {
|
|
case "access_token":
|
|
return []string{s.AccessToken}
|
|
case "id_token":
|
|
return []string{s.IDToken}
|
|
case "created_at":
|
|
return []string{s.CreatedAt.String()}
|
|
case "expires_on":
|
|
return []string{s.ExpiresOn.String()}
|
|
case "refresh_token":
|
|
return []string{s.RefreshToken}
|
|
case "email":
|
|
return []string{s.Email}
|
|
case "user":
|
|
return []string{s.User}
|
|
case "groups":
|
|
groups := make([]string, len(s.Groups))
|
|
copy(groups, s.Groups)
|
|
return groups
|
|
case "preferred_username":
|
|
return []string{s.PreferredUsername}
|
|
default:
|
|
return []string{}
|
|
}
|
|
}
|
|
|
|
// CheckNonce compares the Nonce against a potential hash of it
|
|
func (s *SessionState) CheckNonce(hashed string) bool {
|
|
return encryption.CheckNonce(s.Nonce, hashed)
|
|
}
|
|
|
|
// EncodeSessionState returns an encrypted, lz4 compressed, MessagePack encoded session
|
|
func (s *SessionState) EncodeSessionState(c encryption.Cipher, compress bool) ([]byte, error) {
|
|
packed, err := msgpack.Marshal(s)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error marshalling session state to msgpack: %w", err)
|
|
}
|
|
|
|
if !compress {
|
|
return c.Encrypt(packed)
|
|
}
|
|
|
|
compressed, err := lz4Compress(packed)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return c.Encrypt(compressed)
|
|
}
|
|
|
|
// DecodeSessionState decodes a LZ4 compressed MessagePack into a Session State
|
|
func DecodeSessionState(data []byte, c encryption.Cipher, compressed bool) (*SessionState, error) {
|
|
decrypted, err := c.Decrypt(data)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error decrypting the session state: %w", err)
|
|
}
|
|
|
|
packed := decrypted
|
|
if compressed {
|
|
packed, err = lz4Decompress(decrypted)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
var ss SessionState
|
|
err = msgpack.Unmarshal(packed, &ss)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error unmarshalling data to session state: %w", err)
|
|
}
|
|
|
|
err = ss.validate()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &ss, nil
|
|
}
|
|
|
|
// lz4Compress compresses with LZ4
|
|
//
|
|
// The Compress:Decompress ratio is 1:Many. LZ4 gives fastest decompress speeds
|
|
// at the expense of greater compression compared to other compression
|
|
// algorithms.
|
|
func lz4Compress(payload []byte) ([]byte, error) {
|
|
buf := new(bytes.Buffer)
|
|
zw := lz4.NewWriter(nil)
|
|
zw.Header = lz4.Header{
|
|
BlockMaxSize: 65536,
|
|
CompressionLevel: 0,
|
|
}
|
|
zw.Reset(buf)
|
|
|
|
reader := bytes.NewReader(payload)
|
|
_, err := io.Copy(zw, reader)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error copying lz4 stream to buffer: %w", err)
|
|
}
|
|
err = zw.Close()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error closing lz4 writer: %w", err)
|
|
}
|
|
|
|
compressed, err := ioutil.ReadAll(buf)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error reading lz4 buffer: %w", err)
|
|
}
|
|
|
|
return compressed, nil
|
|
}
|
|
|
|
// lz4Decompress decompresses with LZ4
|
|
func lz4Decompress(compressed []byte) ([]byte, error) {
|
|
reader := bytes.NewReader(compressed)
|
|
buf := new(bytes.Buffer)
|
|
zr := lz4.NewReader(nil)
|
|
zr.Reset(reader)
|
|
_, err := io.Copy(buf, zr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error copying lz4 stream to buffer: %w", err)
|
|
}
|
|
|
|
payload, err := ioutil.ReadAll(buf)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error reading lz4 buffer: %w", err)
|
|
}
|
|
|
|
return payload, nil
|
|
}
|
|
|
|
// validate ensures the decoded session is non-empty and contains valid data
|
|
//
|
|
// Non-empty check is needed due to ensure the non-authenticated AES-CFB
|
|
// decryption doesn't result in garbage data that collides with a valid
|
|
// MessagePack header bytes (which MessagePack will unmarshal to an empty
|
|
// default SessionState). <1% chance, but observed with random test data.
|
|
//
|
|
// UTF-8 check ensures the strings are valid and not raw bytes overloaded
|
|
// into Latin-1 encoding. The occurs when legacy unencrypted fields are
|
|
// decrypted with AES-CFB which results in random bytes.
|
|
func (s *SessionState) validate() error {
|
|
for _, field := range []string{
|
|
s.User,
|
|
s.Email,
|
|
s.PreferredUsername,
|
|
s.AccessToken,
|
|
s.IDToken,
|
|
s.RefreshToken,
|
|
} {
|
|
if !utf8.ValidString(field) {
|
|
return errors.New("invalid non-UTF8 field in session")
|
|
}
|
|
}
|
|
|
|
empty := new(SessionState)
|
|
if reflect.DeepEqual(*s, *empty) {
|
|
return errors.New("invalid empty session unmarshalled")
|
|
}
|
|
|
|
return nil
|
|
}
|