mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-01-26 05:27:28 +02:00
baf6cf3816
They will only be used in tests, but it doesn't play nice with copy operations many tests use. The linter was not happy. While the global clock needs mutexes for parallelism, local Clocks only used it for Set/Add and didn't even use the mutex for actual time functions.
254 lines
6.3 KiB
Go
254 lines
6.3 KiB
Go
package sessions
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"time"
|
|
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock"
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
|
|
"github.com/pierrec/lz4"
|
|
"github.com/vmihailenco/msgpack/v4"
|
|
)
|
|
|
|
// SessionState is used to store information about the currently authenticated user session
//
// Fields with a msgpack key tag are what EncodeSessionState serializes; the
// short keys in the tags ("ca", "eo", ...) are the on-the-wire names, so
// renaming or retagging a field changes the stored format. Fields tagged
// `msgpack:"-"` are never serialized.
type SessionState struct {
	// CreatedAt and ExpiresOn are optional timestamps; either may be nil.
	CreatedAt *time.Time `msgpack:"ca,omitempty"`
	ExpiresOn *time.Time `msgpack:"eo,omitempty"`

	AccessToken  string `msgpack:"at,omitempty"`
	IDToken      string `msgpack:"it,omitempty"`
	RefreshToken string `msgpack:"rt,omitempty"`

	// Nonce is compared against a hashed value by CheckNonce.
	Nonce []byte `msgpack:"n,omitempty"`

	Email             string   `msgpack:"e,omitempty"`
	User              string   `msgpack:"u,omitempty"`
	Groups            []string `msgpack:"g,omitempty"`
	PreferredUsername string   `msgpack:"pu,omitempty"`

	// Internal helpers, not serialized
	// Clock supplies the current time for CreatedAtNow, IsExpired and Age.
	Clock clock.Clock `msgpack:"-"`
	// Lock backs ObtainLock/RefreshLock/ReleaseLock/PeekLock; those methods
	// replace a nil Lock with a NoOpLock on first use.
	Lock Lock `msgpack:"-"`
}
|
|
|
|
func (s *SessionState) ObtainLock(ctx context.Context, expiration time.Duration) error {
|
|
if s.Lock == nil {
|
|
s.Lock = &NoOpLock{}
|
|
}
|
|
return s.Lock.Obtain(ctx, expiration)
|
|
}
|
|
|
|
func (s *SessionState) RefreshLock(ctx context.Context, expiration time.Duration) error {
|
|
if s.Lock == nil {
|
|
s.Lock = &NoOpLock{}
|
|
}
|
|
return s.Lock.Refresh(ctx, expiration)
|
|
}
|
|
|
|
func (s *SessionState) ReleaseLock(ctx context.Context) error {
|
|
if s.Lock == nil {
|
|
s.Lock = &NoOpLock{}
|
|
}
|
|
return s.Lock.Release(ctx)
|
|
}
|
|
|
|
func (s *SessionState) PeekLock(ctx context.Context) (bool, error) {
|
|
if s.Lock == nil {
|
|
s.Lock = &NoOpLock{}
|
|
}
|
|
return s.Lock.Peek(ctx)
|
|
}
|
|
|
|
// CreatedAtNow sets a SessionState's CreatedAt to now
|
|
func (s *SessionState) CreatedAtNow() {
|
|
now := s.Clock.Now()
|
|
s.CreatedAt = &now
|
|
}
|
|
|
|
// SetExpiresOn sets an expiration
|
|
func (s *SessionState) SetExpiresOn(exp time.Time) {
|
|
s.ExpiresOn = &exp
|
|
}
|
|
|
|
// ExpiresIn sets an expiration a certain duration from CreatedAt.
|
|
// CreatedAt will be set to time.Now if it is unset.
|
|
func (s *SessionState) ExpiresIn(d time.Duration) {
|
|
if s.CreatedAt == nil {
|
|
s.CreatedAtNow()
|
|
}
|
|
exp := s.CreatedAt.Add(d)
|
|
s.ExpiresOn = &exp
|
|
}
|
|
|
|
// IsExpired checks whether the session has expired
|
|
func (s *SessionState) IsExpired() bool {
|
|
if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(s.Clock.Now()) {
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
// Age returns the age of a session
|
|
func (s *SessionState) Age() time.Duration {
|
|
if s.CreatedAt != nil && !s.CreatedAt.IsZero() {
|
|
return s.Clock.Now().Truncate(time.Second).Sub(*s.CreatedAt)
|
|
}
|
|
return 0
|
|
}
|
|
|
|
// String constructs a summary of the session state
|
|
func (s *SessionState) String() string {
|
|
o := fmt.Sprintf("Session{email:%s user:%s PreferredUsername:%s", s.Email, s.User, s.PreferredUsername)
|
|
if s.AccessToken != "" {
|
|
o += " token:true"
|
|
}
|
|
if s.IDToken != "" {
|
|
o += " id_token:true"
|
|
}
|
|
if s.CreatedAt != nil && !s.CreatedAt.IsZero() {
|
|
o += fmt.Sprintf(" created:%s", s.CreatedAt)
|
|
}
|
|
if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() {
|
|
o += fmt.Sprintf(" expires:%s", s.ExpiresOn)
|
|
}
|
|
if s.RefreshToken != "" {
|
|
o += " refresh_token:true"
|
|
}
|
|
if len(s.Groups) > 0 {
|
|
o += fmt.Sprintf(" groups:%v", s.Groups)
|
|
}
|
|
return o + "}"
|
|
}
|
|
|
|
func (s *SessionState) GetClaim(claim string) []string {
|
|
if s == nil {
|
|
return []string{}
|
|
}
|
|
switch claim {
|
|
case "access_token":
|
|
return []string{s.AccessToken}
|
|
case "id_token":
|
|
return []string{s.IDToken}
|
|
case "created_at":
|
|
return []string{s.CreatedAt.String()}
|
|
case "expires_on":
|
|
return []string{s.ExpiresOn.String()}
|
|
case "refresh_token":
|
|
return []string{s.RefreshToken}
|
|
case "email":
|
|
return []string{s.Email}
|
|
case "user":
|
|
return []string{s.User}
|
|
case "groups":
|
|
groups := make([]string, len(s.Groups))
|
|
copy(groups, s.Groups)
|
|
return groups
|
|
case "preferred_username":
|
|
return []string{s.PreferredUsername}
|
|
default:
|
|
return []string{}
|
|
}
|
|
}
|
|
|
|
// CheckNonce compares the Nonce against a potential hash of it
|
|
func (s *SessionState) CheckNonce(hashed string) bool {
|
|
return encryption.CheckNonce(s.Nonce, hashed)
|
|
}
|
|
|
|
// EncodeSessionState returns an encrypted, lz4 compressed, MessagePack encoded session
|
|
func (s *SessionState) EncodeSessionState(c encryption.Cipher, compress bool) ([]byte, error) {
|
|
packed, err := msgpack.Marshal(s)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error marshalling session state to msgpack: %w", err)
|
|
}
|
|
|
|
if !compress {
|
|
return c.Encrypt(packed)
|
|
}
|
|
|
|
compressed, err := lz4Compress(packed)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return c.Encrypt(compressed)
|
|
}
|
|
|
|
// DecodeSessionState decodes a LZ4 compressed MessagePack into a Session State
|
|
func DecodeSessionState(data []byte, c encryption.Cipher, compressed bool) (*SessionState, error) {
|
|
decrypted, err := c.Decrypt(data)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error decrypting the session state: %w", err)
|
|
}
|
|
|
|
packed := decrypted
|
|
if compressed {
|
|
packed, err = lz4Decompress(decrypted)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
var ss SessionState
|
|
err = msgpack.Unmarshal(packed, &ss)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error unmarshalling data to session state: %w", err)
|
|
}
|
|
|
|
return &ss, nil
|
|
}
|
|
|
|
// lz4Compress compresses with LZ4
|
|
//
|
|
// The Compress:Decompress ratio is 1:Many. LZ4 gives fastest decompress speeds
|
|
// at the expense of greater compression compared to other compression
|
|
// algorithms.
|
|
func lz4Compress(payload []byte) ([]byte, error) {
|
|
buf := new(bytes.Buffer)
|
|
zw := lz4.NewWriter(nil)
|
|
zw.Header = lz4.Header{
|
|
BlockMaxSize: 65536,
|
|
CompressionLevel: 0,
|
|
}
|
|
zw.Reset(buf)
|
|
|
|
reader := bytes.NewReader(payload)
|
|
_, err := io.Copy(zw, reader)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error copying lz4 stream to buffer: %w", err)
|
|
}
|
|
err = zw.Close()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error closing lz4 writer: %w", err)
|
|
}
|
|
|
|
compressed, err := ioutil.ReadAll(buf)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error reading lz4 buffer: %w", err)
|
|
}
|
|
|
|
return compressed, nil
|
|
}
|
|
|
|
// lz4Decompress decompresses with LZ4
|
|
func lz4Decompress(compressed []byte) ([]byte, error) {
|
|
reader := bytes.NewReader(compressed)
|
|
buf := new(bytes.Buffer)
|
|
zr := lz4.NewReader(nil)
|
|
zr.Reset(reader)
|
|
_, err := io.Copy(buf, zr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error copying lz4 stream to buffer: %w", err)
|
|
}
|
|
|
|
payload, err := ioutil.ReadAll(buf)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error reading lz4 buffer: %w", err)
|
|
}
|
|
|
|
return payload, nil
|
|
}
|