1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-03-25 22:00:56 +02:00
oauth2-proxy/pkg/apis/sessions/session_state.go
2021-06-21 21:54:52 -07:00

293 lines
7.3 KiB
Go

package sessions
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"io/ioutil"
"reflect"
"time"
"unicode/utf8"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
"github.com/pierrec/lz4"
"github.com/vmihailenco/msgpack/v4"
)
// SessionState holds everything tracked for one authenticated user session.
//
// All fields except Clock and Lock are serialised with MessagePack under the
// short keys given in their struct tags; empty values are omitted.
type SessionState struct {
	// Session timestamps. Both are optional pointers; the time-based
	// methods treat nil (or a zero time) as "not set".
	CreatedAt *time.Time `msgpack:"ca,omitempty"`
	ExpiresOn *time.Time `msgpack:"eo,omitempty"`

	// Tokens held for the session.
	AccessToken  string `msgpack:"at,omitempty"`
	IDToken      string `msgpack:"it,omitempty"`
	RefreshToken string `msgpack:"rt,omitempty"`

	// Nonce is the value CheckNonce compares a hash against.
	Nonce []byte `msgpack:"n,omitempty"`

	// Identity information about the user.
	Email             string   `msgpack:"e,omitempty"`
	User              string   `msgpack:"u,omitempty"`
	Groups            []string `msgpack:"g,omitempty"`
	PreferredUsername string   `msgpack:"pu,omitempty"`

	// Runtime-only helpers, never serialised. Clock supplies the current
	// time; Lock backs the *Lock methods, which fall back to a NoOpLock
	// when it is nil.
	Clock clock.Clock `msgpack:"-"`
	Lock  Lock        `msgpack:"-"`
}
// ObtainLock calls Obtain on the session's Lock with the given context and
// expiration and returns its error. If no Lock has been set, a NoOpLock is
// installed on the session first.
func (s *SessionState) ObtainLock(ctx context.Context, expiration time.Duration) error {
	lock := s.Lock
	if lock == nil {
		lock = &NoOpLock{}
		s.Lock = lock
	}
	return lock.Obtain(ctx, expiration)
}
// RefreshLock calls Refresh on the session's Lock with the given context and
// expiration and returns its error. If no Lock has been set, a NoOpLock is
// installed on the session first.
func (s *SessionState) RefreshLock(ctx context.Context, expiration time.Duration) error {
	lock := s.Lock
	if lock == nil {
		lock = &NoOpLock{}
		s.Lock = lock
	}
	return lock.Refresh(ctx, expiration)
}
// ReleaseLock calls Release on the session's Lock and returns its error.
// If no Lock has been set, a NoOpLock is installed on the session first.
func (s *SessionState) ReleaseLock(ctx context.Context) error {
	lock := s.Lock
	if lock == nil {
		lock = &NoOpLock{}
		s.Lock = lock
	}
	return lock.Release(ctx)
}
// PeekLock calls Peek on the session's Lock and returns its result unchanged.
// If no Lock has been set, a NoOpLock is installed on the session first.
func (s *SessionState) PeekLock(ctx context.Context) (bool, error) {
	lock := s.Lock
	if lock == nil {
		lock = &NoOpLock{}
		s.Lock = lock
	}
	return lock.Peek(ctx)
}
// CreatedAtNow stamps the session's CreatedAt with the current time as
// reported by the session's Clock.
func (s *SessionState) CreatedAtNow() {
	created := s.Clock.Now()
	s.CreatedAt = &created
}
// SetExpiresOn records exp as the moment the session expires.
func (s *SessionState) SetExpiresOn(exp time.Time) {
	expiry := exp
	s.ExpiresOn = &expiry
}
// ExpiresIn sets the session to expire d after CreatedAt.
// When CreatedAt is unset it is first initialised via CreatedAtNow, i.e. to
// the current time of the session's Clock.
func (s *SessionState) ExpiresIn(d time.Duration) {
	if s.CreatedAt == nil {
		s.CreatedAtNow()
	}
	expiry := s.CreatedAt.Add(d)
	s.ExpiresOn = &expiry
}
// IsExpired reports whether ExpiresOn lies before the session Clock's current
// time. A session whose ExpiresOn is nil or the zero time never expires.
func (s *SessionState) IsExpired() bool {
	if s.ExpiresOn == nil || s.ExpiresOn.IsZero() {
		return false
	}
	return s.ExpiresOn.Before(s.Clock.Now())
}
// Age returns how long ago the session was created: the session Clock's
// current time, truncated to whole seconds, minus CreatedAt. It returns 0
// when CreatedAt is nil or the zero time.
func (s *SessionState) Age() time.Duration {
	if s.CreatedAt == nil || s.CreatedAt.IsZero() {
		return 0
	}
	now := s.Clock.Now().Truncate(time.Second)
	return now.Sub(*s.CreatedAt)
}
// String renders a one-line summary of the session.
//
// Email, user and preferred username are always printed. Tokens are never
// printed; their presence is only flagged with "<name>:true". Timestamps and
// groups are appended only when set / non-empty.
func (s *SessionState) String() string {
	var out bytes.Buffer

	fmt.Fprintf(&out, "Session{email:%s user:%s PreferredUsername:%s", s.Email, s.User, s.PreferredUsername)

	if s.AccessToken != "" {
		out.WriteString(" token:true")
	}
	if s.IDToken != "" {
		out.WriteString(" id_token:true")
	}
	if s.CreatedAt != nil && !s.CreatedAt.IsZero() {
		fmt.Fprintf(&out, " created:%s", s.CreatedAt)
	}
	if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() {
		fmt.Fprintf(&out, " expires:%s", s.ExpiresOn)
	}
	if s.RefreshToken != "" {
		out.WriteString(" refresh_token:true")
	}
	if len(s.Groups) > 0 {
		fmt.Fprintf(&out, " groups:%v", s.Groups)
	}

	out.WriteString("}")
	return out.String()
}
// GetClaim returns the value(s) of the named session claim as a string slice.
//
// Recognised claims are "access_token", "id_token", "created_at",
// "expires_on", "refresh_token", "email", "user", "groups" and
// "preferred_username". Single-valued claims yield a one-element slice;
// "groups" yields a copy of the session's groups so callers cannot mutate
// the session through the result.
//
// An empty, non-nil slice is returned when the receiver is nil, when the
// claim name is not recognised, or when "created_at" / "expires_on" is
// requested but the corresponding timestamp has not been set.
func (s *SessionState) GetClaim(claim string) []string {
	if s == nil {
		return []string{}
	}
	switch claim {
	case "access_token":
		return []string{s.AccessToken}
	case "id_token":
		return []string{s.IDToken}
	case "created_at":
		// CreatedAt is an optional pointer. time.Time.String has a value
		// receiver, so calling it through a nil pointer would panic.
		if s.CreatedAt == nil {
			return []string{}
		}
		return []string{s.CreatedAt.String()}
	case "expires_on":
		// Same nil guard as for CreatedAt above.
		if s.ExpiresOn == nil {
			return []string{}
		}
		return []string{s.ExpiresOn.String()}
	case "refresh_token":
		return []string{s.RefreshToken}
	case "email":
		return []string{s.Email}
	case "user":
		return []string{s.User}
	case "groups":
		groups := make([]string, len(s.Groups))
		copy(groups, s.Groups)
		return groups
	case "preferred_username":
		return []string{s.PreferredUsername}
	default:
		return []string{}
	}
}
// CheckNonce reports whether hashed matches the session's Nonce, as decided
// by encryption.CheckNonce.
func (s *SessionState) CheckNonce(hashed string) bool {
	matches := encryption.CheckNonce(s.Nonce, hashed)
	return matches
}
// EncodeSessionState serialises the session to MessagePack, optionally
// LZ4-compresses the result when compress is true, and finally encrypts it
// with c.
//
// Marshalling failures are wrapped; compression and encryption errors are
// returned as-is.
func (s *SessionState) EncodeSessionState(c encryption.Cipher, compress bool) ([]byte, error) {
	payload, err := msgpack.Marshal(s)
	if err != nil {
		return nil, fmt.Errorf("error marshalling session state to msgpack: %w", err)
	}

	if compress {
		payload, err = lz4Compress(payload)
		if err != nil {
			return nil, err
		}
	}

	return c.Encrypt(payload)
}
// DecodeSessionState reverses EncodeSessionState: it decrypts data with c,
// LZ4-decompresses it when compressed is true, unmarshals the MessagePack
// payload into a SessionState and validates the result.
//
// A nil session is returned together with the error if any step fails.
func DecodeSessionState(data []byte, c encryption.Cipher, compressed bool) (*SessionState, error) {
	payload, err := c.Decrypt(data)
	if err != nil {
		return nil, fmt.Errorf("error decrypting the session state: %w", err)
	}

	if compressed {
		if payload, err = lz4Decompress(payload); err != nil {
			return nil, err
		}
	}

	session := new(SessionState)
	if err := msgpack.Unmarshal(payload, session); err != nil {
		return nil, fmt.Errorf("error unmarshalling data to session state: %w", err)
	}

	// Reject sessions that decoded without error but carry no usable data.
	if err := session.validate(); err != nil {
		return nil, err
	}
	return session, nil
}
// lz4Compress returns payload compressed as an LZ4 stream.
//
// Sessions are compressed once but decompressed many times, so LZ4 is used:
// it trades a lower compression ratio than other algorithms for the fastest
// decompression speed.
func lz4Compress(payload []byte) ([]byte, error) {
	var out bytes.Buffer

	// Configure the header before pointing the writer at the output buffer.
	w := lz4.NewWriter(nil)
	w.Header = lz4.Header{
		BlockMaxSize:     65536,
		CompressionLevel: 0,
	}
	w.Reset(&out)

	if _, err := io.Copy(w, bytes.NewReader(payload)); err != nil {
		return nil, fmt.Errorf("error copying lz4 stream to buffer: %w", err)
	}
	// Close must succeed before the buffer is read back.
	if err := w.Close(); err != nil {
		return nil, fmt.Errorf("error closing lz4 writer: %w", err)
	}

	compressed, err := ioutil.ReadAll(&out)
	if err != nil {
		return nil, fmt.Errorf("error reading lz4 buffer: %w", err)
	}
	return compressed, nil
}
// lz4Decompress inflates an LZ4 stream and returns the original payload.
// Stream and buffer errors are wrapped with context.
func lz4Decompress(compressed []byte) ([]byte, error) {
	var out bytes.Buffer

	r := lz4.NewReader(nil)
	r.Reset(bytes.NewReader(compressed))

	if _, err := io.Copy(&out, r); err != nil {
		return nil, fmt.Errorf("error copying lz4 stream to buffer: %w", err)
	}

	payload, err := ioutil.ReadAll(&out)
	if err != nil {
		return nil, fmt.Errorf("error reading lz4 buffer: %w", err)
	}
	return payload, nil
}
// validate checks that a decoded session contains valid, non-empty data.
//
// The UTF-8 check ensures the string fields are real text rather than raw
// bytes. Random bytes occur when legacy unencrypted fields are run through
// AES-CFB decryption.
//
// The non-empty check is needed because the non-authenticated AES-CFB
// decryption can produce garbage that happens to collide with a valid
// MessagePack header, which MessagePack unmarshals into an empty default
// SessionState. The chance is <1%, but it was observed with random test data.
func (s *SessionState) validate() error {
	textFields := [...]string{
		s.User,
		s.Email,
		s.PreferredUsername,
		s.AccessToken,
		s.IDToken,
		s.RefreshToken,
	}
	for i := range textFields {
		if !utf8.ValidString(textFields[i]) {
			return errors.New("invalid non-UTF8 field in session")
		}
	}

	// A session identical to the zero value carries no information at all.
	if reflect.DeepEqual(*s, SessionState{}) {
		return errors.New("invalid empty session unmarshalled")
	}
	return nil
}