package sessions import ( "bytes" "context" "fmt" "io" "io/ioutil" "time" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption" "github.com/pierrec/lz4" "github.com/vmihailenco/msgpack/v4" ) // SessionState is used to store information about the currently authenticated user session type SessionState struct { CreatedAt *time.Time `msgpack:"ca,omitempty"` ExpiresOn *time.Time `msgpack:"eo,omitempty"` AccessToken string `msgpack:"at,omitempty"` IDToken string `msgpack:"it,omitempty"` RefreshToken string `msgpack:"rt,omitempty"` Nonce []byte `msgpack:"n,omitempty"` Email string `msgpack:"e,omitempty"` User string `msgpack:"u,omitempty"` Groups []string `msgpack:"g,omitempty"` PreferredUsername string `msgpack:"pu,omitempty"` // Internal helpers, not serialized Clock clock.Clock `msgpack:"-"` Lock Lock `msgpack:"-"` } func (s *SessionState) ObtainLock(ctx context.Context, expiration time.Duration) error { if s.Lock == nil { s.Lock = &NoOpLock{} } return s.Lock.Obtain(ctx, expiration) } func (s *SessionState) RefreshLock(ctx context.Context, expiration time.Duration) error { if s.Lock == nil { s.Lock = &NoOpLock{} } return s.Lock.Refresh(ctx, expiration) } func (s *SessionState) ReleaseLock(ctx context.Context) error { if s.Lock == nil { s.Lock = &NoOpLock{} } return s.Lock.Release(ctx) } func (s *SessionState) PeekLock(ctx context.Context) (bool, error) { if s.Lock == nil { s.Lock = &NoOpLock{} } return s.Lock.Peek(ctx) } // CreatedAtNow sets a SessionState's CreatedAt to now func (s *SessionState) CreatedAtNow() { now := s.Clock.Now() s.CreatedAt = &now } // SetExpiresOn sets an expiration func (s *SessionState) SetExpiresOn(exp time.Time) { s.ExpiresOn = &exp } // ExpiresIn sets an expiration a certain duration from CreatedAt. // CreatedAt will be set to time.Now if it is unset. func (s *SessionState) ExpiresIn(d time.Duration) { if s.CreatedAt == nil { s.CreatedAtNow() } exp := s.CreatedAt.Add(d) s.ExpiresOn = &exp } // IsExpired checks whether the session has expired func (s *SessionState) IsExpired() bool { if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(s.Clock.Now()) { return true } return false } // Age returns the age of a session func (s *SessionState) Age() time.Duration { if s.CreatedAt != nil && !s.CreatedAt.IsZero() { return s.Clock.Now().Truncate(time.Second).Sub(*s.CreatedAt) } return 0 } // String constructs a summary of the session state func (s *SessionState) String() string { o := fmt.Sprintf("Session{email:%s user:%s PreferredUsername:%s", s.Email, s.User, s.PreferredUsername) if s.AccessToken != "" { o += " token:true" } if s.IDToken != "" { o += " id_token:true" } if s.CreatedAt != nil && !s.CreatedAt.IsZero() { o += fmt.Sprintf(" created:%s", s.CreatedAt) } if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() { o += fmt.Sprintf(" expires:%s", s.ExpiresOn) } if s.RefreshToken != "" { o += " refresh_token:true" } if len(s.Groups) > 0 { o += fmt.Sprintf(" groups:%v", s.Groups) } return o + "}" } func (s *SessionState) GetClaim(claim string) []string { if s == nil { return []string{} } switch claim { case "access_token": return []string{s.AccessToken} case "id_token": return []string{s.IDToken} case "created_at": return []string{s.CreatedAt.String()} case "expires_on": return []string{s.ExpiresOn.String()} case "refresh_token": return []string{s.RefreshToken} case "email": return []string{s.Email} case "user": return []string{s.User} case "groups": groups := make([]string, len(s.Groups)) copy(groups, s.Groups) return groups case "preferred_username": return []string{s.PreferredUsername} default: return []string{} } } // CheckNonce compares the Nonce against a potential hash of it func (s *SessionState) CheckNonce(hashed string) bool { return encryption.CheckNonce(s.Nonce, hashed) } // EncodeSessionState returns an encrypted, lz4 compressed, MessagePack encoded session func (s *SessionState) EncodeSessionState(c encryption.Cipher, compress bool) ([]byte, error) { packed, err := msgpack.Marshal(s) if err != nil { return nil, fmt.Errorf("error marshalling session state to msgpack: %w", err) } if !compress { return c.Encrypt(packed) } compressed, err := lz4Compress(packed) if err != nil { return nil, err } return c.Encrypt(compressed) } // DecodeSessionState decodes a LZ4 compressed MessagePack into a Session State func DecodeSessionState(data []byte, c encryption.Cipher, compressed bool) (*SessionState, error) { decrypted, err := c.Decrypt(data) if err != nil { return nil, fmt.Errorf("error decrypting the session state: %w", err) } packed := decrypted if compressed { packed, err = lz4Decompress(decrypted) if err != nil { return nil, err } } var ss SessionState err = msgpack.Unmarshal(packed, &ss) if err != nil { return nil, fmt.Errorf("error unmarshalling data to session state: %w", err) } return &ss, nil } // lz4Compress compresses with LZ4 // // The Compress:Decompress ratio is 1:Many. LZ4 gives fastest decompress speeds // at the expense of greater compression compared to other compression // algorithms. func lz4Compress(payload []byte) ([]byte, error) { buf := new(bytes.Buffer) zw := lz4.NewWriter(nil) zw.Header = lz4.Header{ BlockMaxSize: 65536, CompressionLevel: 0, } zw.Reset(buf) reader := bytes.NewReader(payload) _, err := io.Copy(zw, reader) if err != nil { return nil, fmt.Errorf("error copying lz4 stream to buffer: %w", err) } err = zw.Close() if err != nil { return nil, fmt.Errorf("error closing lz4 writer: %w", err) } compressed, err := ioutil.ReadAll(buf) if err != nil { return nil, fmt.Errorf("error reading lz4 buffer: %w", err) } return compressed, nil } // lz4Decompress decompresses with LZ4 func lz4Decompress(compressed []byte) ([]byte, error) { reader := bytes.NewReader(compressed) buf := new(bytes.Buffer) zr := lz4.NewReader(nil) zr.Reset(reader) _, err := io.Copy(buf, zr) if err != nil { return nil, fmt.Errorf("error copying lz4 stream to buffer: %w", err) } payload, err := ioutil.ReadAll(buf) if err != nil { return nil, fmt.Errorf("error reading lz4 buffer: %w", err) } return payload, nil }