mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-05-23 22:40:31 +02:00
Merge pull request #1086 from oauth2-proxy/early-refresh
Convert RefreshSessionIfNeeded into RefreshSession
This commit is contained in:
commit
16a9893a19
@ -4,10 +4,16 @@
|
|||||||
|
|
||||||
## Important Notes
|
## Important Notes
|
||||||
|
|
||||||
|
- [#1086](https://github.com/oauth2-proxy/oauth2-proxy/pull/1086) The extra validation to protect invalid session
|
||||||
|
deserialization from v6.0.0 (only) has been removed to improve performance. If you are on v6.0.0, either upgrade
|
||||||
|
to a version before this first and allow legacy sessions to expire gracefully or change your `cookie-secret`
|
||||||
|
value and force all sessions to reauthenticate.
|
||||||
|
|
||||||
## Breaking Changes
|
## Breaking Changes
|
||||||
|
|
||||||
## Changes since v7.1.3
|
## Changes since v7.1.3
|
||||||
|
|
||||||
|
- [#1086](https://github.com/oauth2-proxy/oauth2-proxy/pull/1086) Refresh sessions before token expiration if configured (@NickMeves)
|
||||||
- [#1226](https://github.com/oauth2-proxy/oauth2-proxy/pull/1226) Move app redirection logic to its own package (@JoelSpeed)
|
- [#1226](https://github.com/oauth2-proxy/oauth2-proxy/pull/1226) Move app redirection logic to its own package (@JoelSpeed)
|
||||||
- [#1128](https://github.com/oauth2-proxy/oauth2-proxy/pull/1128) Use gorilla mux for OAuth Proxy routing (@JoelSpeed)
|
- [#1128](https://github.com/oauth2-proxy/oauth2-proxy/pull/1128) Use gorilla mux for OAuth Proxy routing (@JoelSpeed)
|
||||||
- [#1238](https://github.com/oauth2-proxy/oauth2-proxy/pull/1238) Added ADFS provider (@samirachoadi)
|
- [#1238](https://github.com/oauth2-proxy/oauth2-proxy/pull/1238) Added ADFS provider (@samirachoadi)
|
||||||
|
@ -361,10 +361,10 @@ func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionSt
|
|||||||
}
|
}
|
||||||
|
|
||||||
chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{
|
chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{
|
||||||
SessionStore: sessionStore,
|
SessionStore: sessionStore,
|
||||||
RefreshPeriod: opts.Cookie.Refresh,
|
RefreshPeriod: opts.Cookie.Refresh,
|
||||||
RefreshSessionIfNeeded: opts.GetProvider().RefreshSessionIfNeeded,
|
RefreshSession: opts.GetProvider().RefreshSession,
|
||||||
ValidateSessionState: opts.GetProvider().ValidateSession,
|
ValidateSession: opts.GetProvider().ValidateSession,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
return chain
|
return chain
|
||||||
@ -786,6 +786,15 @@ func (p *OAuthProxy) redeemCode(req *http.Request) (*sessionsapi.SessionState, e
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Force setting these in case the Provider didn't
|
||||||
|
if s.CreatedAt == nil {
|
||||||
|
s.CreatedAtNow()
|
||||||
|
}
|
||||||
|
if s.ExpiresOn == nil {
|
||||||
|
s.ExpiresIn(p.CookieOptions.Expire)
|
||||||
|
}
|
||||||
|
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3,14 +3,12 @@ package sessions
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"reflect"
|
|
||||||
"time"
|
"time"
|
||||||
"unicode/utf8"
|
|
||||||
|
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
|
||||||
"github.com/pierrec/lz4"
|
"github.com/pierrec/lz4"
|
||||||
"github.com/vmihailenco/msgpack/v4"
|
"github.com/vmihailenco/msgpack/v4"
|
||||||
@ -32,7 +30,9 @@ type SessionState struct {
|
|||||||
Groups []string `msgpack:"g,omitempty"`
|
Groups []string `msgpack:"g,omitempty"`
|
||||||
PreferredUsername string `msgpack:"pu,omitempty"`
|
PreferredUsername string `msgpack:"pu,omitempty"`
|
||||||
|
|
||||||
Lock Lock `msgpack:"-"`
|
// Internal helpers, not serialized
|
||||||
|
Clock clock.Clock `msgpack:"-"`
|
||||||
|
Lock Lock `msgpack:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SessionState) ObtainLock(ctx context.Context, expiration time.Duration) error {
|
func (s *SessionState) ObtainLock(ctx context.Context, expiration time.Duration) error {
|
||||||
@ -63,9 +63,30 @@ func (s *SessionState) PeekLock(ctx context.Context) (bool, error) {
|
|||||||
return s.Lock.Peek(ctx)
|
return s.Lock.Peek(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreatedAtNow sets a SessionState's CreatedAt to now
|
||||||
|
func (s *SessionState) CreatedAtNow() {
|
||||||
|
now := s.Clock.Now()
|
||||||
|
s.CreatedAt = &now
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetExpiresOn sets an expiration
|
||||||
|
func (s *SessionState) SetExpiresOn(exp time.Time) {
|
||||||
|
s.ExpiresOn = &exp
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpiresIn sets an expiration a certain duration from CreatedAt.
|
||||||
|
// CreatedAt will be set to time.Now if it is unset.
|
||||||
|
func (s *SessionState) ExpiresIn(d time.Duration) {
|
||||||
|
if s.CreatedAt == nil {
|
||||||
|
s.CreatedAtNow()
|
||||||
|
}
|
||||||
|
exp := s.CreatedAt.Add(d)
|
||||||
|
s.ExpiresOn = &exp
|
||||||
|
}
|
||||||
|
|
||||||
// IsExpired checks whether the session has expired
|
// IsExpired checks whether the session has expired
|
||||||
func (s *SessionState) IsExpired() bool {
|
func (s *SessionState) IsExpired() bool {
|
||||||
if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) {
|
if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(s.Clock.Now()) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
@ -74,7 +95,7 @@ func (s *SessionState) IsExpired() bool {
|
|||||||
// Age returns the age of a session
|
// Age returns the age of a session
|
||||||
func (s *SessionState) Age() time.Duration {
|
func (s *SessionState) Age() time.Duration {
|
||||||
if s.CreatedAt != nil && !s.CreatedAt.IsZero() {
|
if s.CreatedAt != nil && !s.CreatedAt.IsZero() {
|
||||||
return time.Now().Truncate(time.Second).Sub(*s.CreatedAt)
|
return s.Clock.Now().Truncate(time.Second).Sub(*s.CreatedAt)
|
||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
@ -177,11 +198,6 @@ func DecodeSessionState(data []byte, c encryption.Cipher, compressed bool) (*Ses
|
|||||||
return nil, fmt.Errorf("error unmarshalling data to session state: %w", err)
|
return nil, fmt.Errorf("error unmarshalling data to session state: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = ss.validate()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &ss, nil
|
return &ss, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -235,35 +251,3 @@ func lz4Decompress(compressed []byte) ([]byte, error) {
|
|||||||
|
|
||||||
return payload, nil
|
return payload, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// validate ensures the decoded session is non-empty and contains valid data
|
|
||||||
//
|
|
||||||
// Non-empty check is needed due to ensure the non-authenticated AES-CFB
|
|
||||||
// decryption doesn't result in garbage data that collides with a valid
|
|
||||||
// MessagePack header bytes (which MessagePack will unmarshal to an empty
|
|
||||||
// default SessionState). <1% chance, but observed with random test data.
|
|
||||||
//
|
|
||||||
// UTF-8 check ensures the strings are valid and not raw bytes overloaded
|
|
||||||
// into Latin-1 encoding. The occurs when legacy unencrypted fields are
|
|
||||||
// decrypted with AES-CFB which results in random bytes.
|
|
||||||
func (s *SessionState) validate() error {
|
|
||||||
for _, field := range []string{
|
|
||||||
s.User,
|
|
||||||
s.Email,
|
|
||||||
s.PreferredUsername,
|
|
||||||
s.AccessToken,
|
|
||||||
s.IDToken,
|
|
||||||
s.RefreshToken,
|
|
||||||
} {
|
|
||||||
if !utf8.ValidString(field) {
|
|
||||||
return errors.New("invalid non-UTF8 field in session")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
empty := new(SessionState)
|
|
||||||
if reflect.DeepEqual(*s, *empty) {
|
|
||||||
return errors.New("invalid empty session unmarshalled")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
@ -16,6 +16,30 @@ func timePtr(t time.Time) *time.Time {
|
|||||||
return &t
|
return &t
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreatedAtNow(t *testing.T) {
|
||||||
|
g := NewWithT(t)
|
||||||
|
ss := &SessionState{}
|
||||||
|
|
||||||
|
now := time.Unix(1234567890, 0)
|
||||||
|
ss.Clock.Set(now)
|
||||||
|
|
||||||
|
ss.CreatedAtNow()
|
||||||
|
g.Expect(*ss.CreatedAt).To(Equal(now))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExpiresIn(t *testing.T) {
|
||||||
|
g := NewWithT(t)
|
||||||
|
ss := &SessionState{}
|
||||||
|
|
||||||
|
now := time.Unix(1234567890, 0)
|
||||||
|
ss.Clock.Set(now)
|
||||||
|
|
||||||
|
ttl := time.Duration(743) * time.Second
|
||||||
|
ss.ExpiresIn(ttl)
|
||||||
|
|
||||||
|
g.Expect(*ss.ExpiresOn).To(Equal(ss.CreatedAt.Add(ttl)))
|
||||||
|
}
|
||||||
|
|
||||||
func TestString(t *testing.T) {
|
func TestString(t *testing.T) {
|
||||||
g := NewWithT(t)
|
g := NewWithT(t)
|
||||||
created, err := time.Parse(time.RFC3339, "2000-01-01T00:00:00Z")
|
created, err := time.Parse(time.RFC3339, "2000-01-01T00:00:00Z")
|
||||||
|
@ -63,13 +63,10 @@ func Reset() *clockapi.Mock {
|
|||||||
// package.
|
// package.
|
||||||
type Clock struct {
|
type Clock struct {
|
||||||
mock *clockapi.Mock
|
mock *clockapi.Mock
|
||||||
sync.Mutex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set sets the Clock to a clock.Mock at the given time.Time
|
// Set sets the Clock to a clock.Mock at the given time.Time
|
||||||
func (c *Clock) Set(t time.Time) {
|
func (c *Clock) Set(t time.Time) {
|
||||||
c.Lock()
|
|
||||||
defer c.Unlock()
|
|
||||||
if c.mock == nil {
|
if c.mock == nil {
|
||||||
c.mock = clockapi.NewMock()
|
c.mock = clockapi.NewMock()
|
||||||
}
|
}
|
||||||
@ -79,8 +76,6 @@ func (c *Clock) Set(t time.Time) {
|
|||||||
// Add moves clock forward time.Duration if it is mocked. It will error
|
// Add moves clock forward time.Duration if it is mocked. It will error
|
||||||
// if the clock is not mocked.
|
// if the clock is not mocked.
|
||||||
func (c *Clock) Add(d time.Duration) error {
|
func (c *Clock) Add(d time.Duration) error {
|
||||||
c.Lock()
|
|
||||||
defer c.Unlock()
|
|
||||||
if c.mock == nil {
|
if c.mock == nil {
|
||||||
return errors.New("clock not mocked")
|
return errors.New("clock not mocked")
|
||||||
}
|
}
|
||||||
@ -91,8 +86,6 @@ func (c *Clock) Add(d time.Duration) error {
|
|||||||
// Reset removes local clock.Mock. Returns any existing Mock if set in case
|
// Reset removes local clock.Mock. Returns any existing Mock if set in case
|
||||||
// lingering time operations are attached to it.
|
// lingering time operations are attached to it.
|
||||||
func (c *Clock) Reset() *clockapi.Mock {
|
func (c *Clock) Reset() *clockapi.Mock {
|
||||||
c.Lock()
|
|
||||||
defer c.Unlock()
|
|
||||||
existing := c.mock
|
existing := c.mock
|
||||||
c.mock = nil
|
c.mock = nil
|
||||||
return existing
|
return existing
|
||||||
|
@ -11,25 +11,26 @@ import (
|
|||||||
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||||
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/providers"
|
||||||
)
|
)
|
||||||
|
|
||||||
// StoredSessionLoaderOptions cotnains all of the requirements to construct
|
// StoredSessionLoaderOptions contains all of the requirements to construct
|
||||||
// a stored session loader.
|
// a stored session loader.
|
||||||
// All options must be provided.
|
// All options must be provided.
|
||||||
type StoredSessionLoaderOptions struct {
|
type StoredSessionLoaderOptions struct {
|
||||||
// Session storage basckend
|
// Session storage backend
|
||||||
SessionStore sessionsapi.SessionStore
|
SessionStore sessionsapi.SessionStore
|
||||||
|
|
||||||
// How often should sessions be refreshed
|
// How often should sessions be refreshed
|
||||||
RefreshPeriod time.Duration
|
RefreshPeriod time.Duration
|
||||||
|
|
||||||
// Provider based sesssion refreshing
|
// Provider based session refreshing
|
||||||
RefreshSessionIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error)
|
RefreshSession func(context.Context, *sessionsapi.SessionState) (bool, error)
|
||||||
|
|
||||||
// Provider based session validation.
|
// Provider based session validation.
|
||||||
// If the sesssion is older than `RefreshPeriod` but the provider doesn't
|
// If the sesssion is older than `RefreshPeriod` but the provider doesn't
|
||||||
// refresh it, we must re-validate using this validation.
|
// refresh it, we must re-validate using this validation.
|
||||||
ValidateSessionState func(context.Context, *sessionsapi.SessionState) bool
|
ValidateSession func(context.Context, *sessionsapi.SessionState) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewStoredSessionLoader creates a new storedSessionLoader which loads
|
// NewStoredSessionLoader creates a new storedSessionLoader which loads
|
||||||
@ -38,10 +39,10 @@ type StoredSessionLoaderOptions struct {
|
|||||||
// If a session was loader by a previous handler, it will not be replaced.
|
// If a session was loader by a previous handler, it will not be replaced.
|
||||||
func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor {
|
func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor {
|
||||||
ss := &storedSessionLoader{
|
ss := &storedSessionLoader{
|
||||||
store: opts.SessionStore,
|
store: opts.SessionStore,
|
||||||
refreshPeriod: opts.RefreshPeriod,
|
refreshPeriod: opts.RefreshPeriod,
|
||||||
refreshSessionWithProviderIfNeeded: opts.RefreshSessionIfNeeded,
|
sessionRefresher: opts.RefreshSession,
|
||||||
validateSessionState: opts.ValidateSessionState,
|
sessionValidator: opts.ValidateSession,
|
||||||
}
|
}
|
||||||
return ss.loadSession
|
return ss.loadSession
|
||||||
}
|
}
|
||||||
@ -49,10 +50,10 @@ func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor
|
|||||||
// storedSessionLoader is responsible for loading sessions from cookie
|
// storedSessionLoader is responsible for loading sessions from cookie
|
||||||
// identified sessions in the session store.
|
// identified sessions in the session store.
|
||||||
type storedSessionLoader struct {
|
type storedSessionLoader struct {
|
||||||
store sessionsapi.SessionStore
|
store sessionsapi.SessionStore
|
||||||
refreshPeriod time.Duration
|
refreshPeriod time.Duration
|
||||||
refreshSessionWithProviderIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error)
|
sessionRefresher func(context.Context, *sessionsapi.SessionState) (bool, error)
|
||||||
validateSessionState func(context.Context, *sessionsapi.SessionState) bool
|
sessionValidator func(context.Context, *sessionsapi.SessionState) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// loadSession attempts to load a session as identified by the request cookies.
|
// loadSession attempts to load a session as identified by the request cookies.
|
||||||
@ -108,49 +109,59 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h
|
|||||||
|
|
||||||
// refreshSessionIfNeeded will attempt to refresh a session if the session
|
// refreshSessionIfNeeded will attempt to refresh a session if the session
|
||||||
// is older than the refresh period.
|
// is older than the refresh period.
|
||||||
// It is assumed that if the provider refreshes the session, the session is now
|
// Success or fail, we will then validate the session.
|
||||||
// valid.
|
|
||||||
// If the session requires refreshing but the provider does not refresh it,
|
|
||||||
// we must validate the session to ensure that the returned session is still
|
|
||||||
// valid.
|
|
||||||
func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error {
|
func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error {
|
||||||
if s.refreshPeriod <= time.Duration(0) || session.Age() < s.refreshPeriod {
|
if s.refreshPeriod <= time.Duration(0) || session.Age() < s.refreshPeriod {
|
||||||
// Refresh is disabled or the session is not old enough, do nothing
|
// Refresh is disabled or the session is not old enough, do nothing
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod)
|
logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age())
|
||||||
refreshed, err := s.refreshSessionWithProvider(rw, req, session)
|
err := s.refreshSession(rw, req, session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
// If a preemptive refresh fails, we still keep the session
|
||||||
|
// if validateSession succeeds.
|
||||||
|
logger.Errorf("Unable to refresh session: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !refreshed {
|
// Validate all sessions after any Redeem/Refresh operation (fail or success)
|
||||||
// Session wasn't refreshed, so make sure it's still valid
|
return s.validateSession(req.Context(), session)
|
||||||
return s.validateSession(req.Context(), session)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// refreshSessionWithProvider attempts to refresh the sessinon with the provider
|
// refreshSession attempts to refresh the session with the provider
|
||||||
// and will save the session if it was updated.
|
// and will save the session if it was updated.
|
||||||
func (s *storedSessionLoader) refreshSessionWithProvider(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) (bool, error) {
|
func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error {
|
||||||
refreshed, err := s.refreshSessionWithProviderIfNeeded(req.Context(), session)
|
refreshed, err := s.sessionRefresher(req.Context(), session)
|
||||||
if err != nil {
|
if err != nil && !errors.Is(err, providers.ErrNotImplemented) {
|
||||||
return false, fmt.Errorf("error refreshing access token: %v", err)
|
return fmt.Errorf("error refreshing tokens: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !refreshed {
|
// HACK:
|
||||||
return false, nil
|
// Providers that don't implement `RefreshSession` use the default
|
||||||
|
// implementation which returns `ErrNotImplemented`.
|
||||||
|
// Pretend it refreshed to reset the refresh timer so that `ValidateSession`
|
||||||
|
// isn't triggered every subsequent request and is only called once during
|
||||||
|
// this request.
|
||||||
|
if errors.Is(err, providers.ErrNotImplemented) {
|
||||||
|
refreshed = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Session not refreshed, nothing to persist.
|
||||||
|
if !refreshed {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we refreshed, update the `CreatedAt` time to reset the refresh timer
|
||||||
|
// (In case underlying provider implementations forget)
|
||||||
|
session.CreatedAtNow()
|
||||||
|
|
||||||
// Because the session was refreshed, make sure to save it
|
// Because the session was refreshed, make sure to save it
|
||||||
err = s.store.Save(rw, req, session)
|
err = s.store.Save(rw, req, session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err)
|
logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err)
|
||||||
return false, fmt.Errorf("error saving session: %v", err)
|
return fmt.Errorf("error saving session: %v", err)
|
||||||
}
|
}
|
||||||
return true, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateSession checks whether the session has expired and performs
|
// validateSession checks whether the session has expired and performs
|
||||||
@ -161,7 +172,7 @@ func (s *storedSessionLoader) validateSession(ctx context.Context, session *sess
|
|||||||
return errors.New("session is expired")
|
return errors.New("session is expired")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !s.validateSessionState(ctx, session) {
|
if !s.sessionValidator(ctx, session) {
|
||||||
return errors.New("session is invalid")
|
return errors.New("session is invalid")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -10,6 +10,8 @@ import (
|
|||||||
|
|
||||||
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||||
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/providers"
|
||||||
. "github.com/onsi/ginkgo"
|
. "github.com/onsi/ginkgo"
|
||||||
. "github.com/onsi/ginkgo/extensions/table"
|
. "github.com/onsi/ginkgo/extensions/table"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
@ -17,15 +19,17 @@ import (
|
|||||||
|
|
||||||
var _ = Describe("Stored Session Suite", func() {
|
var _ = Describe("Stored Session Suite", func() {
|
||||||
const (
|
const (
|
||||||
refresh = "Refresh"
|
refresh = "Refresh"
|
||||||
noRefresh = "NoRefresh"
|
noRefresh = "NoRefresh"
|
||||||
|
notImplemented = "NotImplemented"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ctx = context.Background()
|
var ctx = context.Background()
|
||||||
|
|
||||||
Context("StoredSessionLoader", func() {
|
Context("StoredSessionLoader", func() {
|
||||||
createdPast := time.Now().Add(-5 * time.Minute)
|
now := time.Now()
|
||||||
createdFuture := time.Now().Add(5 * time.Minute)
|
createdPast := now.Add(-5 * time.Minute)
|
||||||
|
createdFuture := now.Add(5 * time.Minute)
|
||||||
|
|
||||||
var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
|
var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
|
||||||
switch ss.RefreshToken {
|
switch ss.RefreshToken {
|
||||||
@ -85,6 +89,14 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
clock.Set(now)
|
||||||
|
})
|
||||||
|
|
||||||
|
AfterEach(func() {
|
||||||
|
clock.Reset()
|
||||||
|
})
|
||||||
|
|
||||||
type storedSessionLoaderTableInput struct {
|
type storedSessionLoaderTableInput struct {
|
||||||
requestHeaders http.Header
|
requestHeaders http.Header
|
||||||
existingSession *sessionsapi.SessionState
|
existingSession *sessionsapi.SessionState
|
||||||
@ -109,10 +121,10 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
opts := &StoredSessionLoaderOptions{
|
opts := &StoredSessionLoaderOptions{
|
||||||
SessionStore: in.store,
|
SessionStore: in.store,
|
||||||
RefreshPeriod: in.refreshPeriod,
|
RefreshPeriod: in.refreshPeriod,
|
||||||
RefreshSessionIfNeeded: in.refreshSession,
|
RefreshSession: in.refreshSession,
|
||||||
ValidateSessionState: in.validateSession,
|
ValidateSession: in.validateSession,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the handler with a next handler that will capture the session
|
// Create the handler with a next handler that will capture the session
|
||||||
@ -208,6 +220,21 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
existingSession: nil,
|
existingSession: nil,
|
||||||
expectedSession: &sessionsapi.SessionState{
|
expectedSession: &sessionsapi.SessionState{
|
||||||
RefreshToken: "Refreshed",
|
RefreshToken: "Refreshed",
|
||||||
|
CreatedAt: &now,
|
||||||
|
ExpiresOn: &createdFuture,
|
||||||
|
},
|
||||||
|
store: defaultSessionStore,
|
||||||
|
refreshPeriod: 1 * time.Minute,
|
||||||
|
refreshSession: defaultRefreshFunc,
|
||||||
|
validateSession: defaultValidateFunc,
|
||||||
|
}),
|
||||||
|
Entry("when the provider refresh fails but validation succeeds", storedSessionLoaderTableInput{
|
||||||
|
requestHeaders: http.Header{
|
||||||
|
"Cookie": []string{"_oauth2_proxy=RefreshError"},
|
||||||
|
},
|
||||||
|
existingSession: nil,
|
||||||
|
expectedSession: &sessionsapi.SessionState{
|
||||||
|
RefreshToken: "RefreshError",
|
||||||
CreatedAt: &createdPast,
|
CreatedAt: &createdPast,
|
||||||
ExpiresOn: &createdFuture,
|
ExpiresOn: &createdFuture,
|
||||||
},
|
},
|
||||||
@ -216,7 +243,7 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
refreshSession: defaultRefreshFunc,
|
refreshSession: defaultRefreshFunc,
|
||||||
validateSession: defaultValidateFunc,
|
validateSession: defaultValidateFunc,
|
||||||
}),
|
}),
|
||||||
Entry("when the provider refresh fails", storedSessionLoaderTableInput{
|
Entry("when the provider refresh fails and validation fails", storedSessionLoaderTableInput{
|
||||||
requestHeaders: http.Header{
|
requestHeaders: http.Header{
|
||||||
"Cookie": []string{"_oauth2_proxy=RefreshError"},
|
"Cookie": []string{"_oauth2_proxy=RefreshError"},
|
||||||
},
|
},
|
||||||
@ -225,7 +252,7 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
store: defaultSessionStore,
|
store: defaultSessionStore,
|
||||||
refreshPeriod: 1 * time.Minute,
|
refreshPeriod: 1 * time.Minute,
|
||||||
refreshSession: defaultRefreshFunc,
|
refreshSession: defaultRefreshFunc,
|
||||||
validateSession: defaultValidateFunc,
|
validateSession: func(context.Context, *sessionsapi.SessionState) bool { return false },
|
||||||
}),
|
}),
|
||||||
Entry("when the session is not refreshed and is no longer valid", storedSessionLoaderTableInput{
|
Entry("when the session is not refreshed and is no longer valid", storedSessionLoaderTableInput{
|
||||||
requestHeaders: http.Header{
|
requestHeaders: http.Header{
|
||||||
@ -261,18 +288,20 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
s := &storedSessionLoader{
|
s := &storedSessionLoader{
|
||||||
refreshPeriod: in.refreshPeriod,
|
refreshPeriod: in.refreshPeriod,
|
||||||
store: &fakeSessionStore{},
|
store: &fakeSessionStore{},
|
||||||
refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
|
sessionRefresher: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
|
||||||
refreshed = true
|
refreshed = true
|
||||||
switch ss.RefreshToken {
|
switch ss.RefreshToken {
|
||||||
case refresh:
|
case refresh:
|
||||||
return true, nil
|
return true, nil
|
||||||
case noRefresh:
|
case noRefresh:
|
||||||
return false, nil
|
return false, nil
|
||||||
|
case notImplemented:
|
||||||
|
return false, providers.ErrNotImplemented
|
||||||
default:
|
default:
|
||||||
return false, errors.New("error refreshing session")
|
return false, errors.New("error refreshing session")
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool {
|
sessionValidator: func(_ context.Context, ss *sessionsapi.SessionState) bool {
|
||||||
validated = true
|
validated = true
|
||||||
return ss.AccessToken != "Invalid"
|
return ss.AccessToken != "Invalid"
|
||||||
},
|
},
|
||||||
@ -326,7 +355,7 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
},
|
},
|
||||||
expectedErr: nil,
|
expectedErr: nil,
|
||||||
expectRefreshed: true,
|
expectRefreshed: true,
|
||||||
expectValidated: false,
|
expectValidated: true,
|
||||||
}),
|
}),
|
||||||
Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{
|
Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{
|
||||||
refreshPeriod: 1 * time.Minute,
|
refreshPeriod: 1 * time.Minute,
|
||||||
@ -339,15 +368,25 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
expectRefreshed: true,
|
expectRefreshed: true,
|
||||||
expectValidated: true,
|
expectValidated: true,
|
||||||
}),
|
}),
|
||||||
Entry("when the provider refresh fails", refreshSessionIfNeededTableInput{
|
Entry("when the provider doesn't implement refresh but validation succeeds", refreshSessionIfNeededTableInput{
|
||||||
|
refreshPeriod: 1 * time.Minute,
|
||||||
|
session: &sessionsapi.SessionState{
|
||||||
|
RefreshToken: notImplemented,
|
||||||
|
CreatedAt: &createdPast,
|
||||||
|
},
|
||||||
|
expectedErr: nil,
|
||||||
|
expectRefreshed: true,
|
||||||
|
expectValidated: true,
|
||||||
|
}),
|
||||||
|
Entry("when the provider refresh fails but validation succeeds", refreshSessionIfNeededTableInput{
|
||||||
refreshPeriod: 1 * time.Minute,
|
refreshPeriod: 1 * time.Minute,
|
||||||
session: &sessionsapi.SessionState{
|
session: &sessionsapi.SessionState{
|
||||||
RefreshToken: "RefreshError",
|
RefreshToken: "RefreshError",
|
||||||
CreatedAt: &createdPast,
|
CreatedAt: &createdPast,
|
||||||
},
|
},
|
||||||
expectedErr: errors.New("error refreshing access token: error refreshing session"),
|
expectedErr: nil,
|
||||||
expectRefreshed: true,
|
expectRefreshed: true,
|
||||||
expectValidated: false,
|
expectValidated: true,
|
||||||
}),
|
}),
|
||||||
Entry("when the session is not refreshed by the provider and validation fails", refreshSessionIfNeededTableInput{
|
Entry("when the session is not refreshed by the provider and validation fails", refreshSessionIfNeededTableInput{
|
||||||
refreshPeriod: 1 * time.Minute,
|
refreshPeriod: 1 * time.Minute,
|
||||||
@ -364,12 +403,11 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("refreshSessionWithProvider", func() {
|
Context("refreshSession", func() {
|
||||||
type refreshSessionWithProviderTableInput struct {
|
type refreshSessionWithProviderTableInput struct {
|
||||||
session *sessionsapi.SessionState
|
session *sessionsapi.SessionState
|
||||||
expectedErr error
|
expectedErr error
|
||||||
expectRefreshed bool
|
expectSaved bool
|
||||||
expectSaved bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
@ -388,12 +426,14 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
|
sessionRefresher: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
|
||||||
switch ss.RefreshToken {
|
switch ss.RefreshToken {
|
||||||
case refresh:
|
case refresh:
|
||||||
return true, nil
|
return true, nil
|
||||||
case noRefresh:
|
case noRefresh:
|
||||||
return false, nil
|
return false, nil
|
||||||
|
case notImplemented:
|
||||||
|
return false, providers.ErrNotImplemented
|
||||||
default:
|
default:
|
||||||
return false, errors.New("error refreshing session")
|
return false, errors.New("error refreshing session")
|
||||||
}
|
}
|
||||||
@ -402,30 +442,34 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
|
|
||||||
req := httptest.NewRequest("", "/", nil)
|
req := httptest.NewRequest("", "/", nil)
|
||||||
req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{})
|
req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{})
|
||||||
refreshed, err := s.refreshSessionWithProvider(nil, req, in.session)
|
err := s.refreshSession(nil, req, in.session)
|
||||||
if in.expectedErr != nil {
|
if in.expectedErr != nil {
|
||||||
Expect(err).To(MatchError(in.expectedErr))
|
Expect(err).To(MatchError(in.expectedErr))
|
||||||
} else {
|
} else {
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
}
|
}
|
||||||
Expect(refreshed).To(Equal(in.expectRefreshed))
|
|
||||||
Expect(saved).To(Equal(in.expectSaved))
|
Expect(saved).To(Equal(in.expectSaved))
|
||||||
},
|
},
|
||||||
Entry("when the provider does not refresh the session", refreshSessionWithProviderTableInput{
|
Entry("when the provider does not refresh the session", refreshSessionWithProviderTableInput{
|
||||||
session: &sessionsapi.SessionState{
|
session: &sessionsapi.SessionState{
|
||||||
RefreshToken: noRefresh,
|
RefreshToken: noRefresh,
|
||||||
},
|
},
|
||||||
expectedErr: nil,
|
expectedErr: nil,
|
||||||
expectRefreshed: false,
|
expectSaved: false,
|
||||||
expectSaved: false,
|
|
||||||
}),
|
}),
|
||||||
Entry("when the provider refreshes the session", refreshSessionWithProviderTableInput{
|
Entry("when the provider refreshes the session", refreshSessionWithProviderTableInput{
|
||||||
session: &sessionsapi.SessionState{
|
session: &sessionsapi.SessionState{
|
||||||
RefreshToken: refresh,
|
RefreshToken: refresh,
|
||||||
},
|
},
|
||||||
expectedErr: nil,
|
expectedErr: nil,
|
||||||
expectRefreshed: true,
|
expectSaved: true,
|
||||||
expectSaved: true,
|
}),
|
||||||
|
Entry("when the provider doesn't implement refresh", refreshSessionWithProviderTableInput{
|
||||||
|
session: &sessionsapi.SessionState{
|
||||||
|
RefreshToken: notImplemented,
|
||||||
|
},
|
||||||
|
expectedErr: nil,
|
||||||
|
expectSaved: true,
|
||||||
}),
|
}),
|
||||||
Entry("when the provider returns an error", refreshSessionWithProviderTableInput{
|
Entry("when the provider returns an error", refreshSessionWithProviderTableInput{
|
||||||
session: &sessionsapi.SessionState{
|
session: &sessionsapi.SessionState{
|
||||||
@ -433,18 +477,16 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
CreatedAt: &now,
|
CreatedAt: &now,
|
||||||
ExpiresOn: &now,
|
ExpiresOn: &now,
|
||||||
},
|
},
|
||||||
expectedErr: errors.New("error refreshing access token: error refreshing session"),
|
expectedErr: errors.New("error refreshing tokens: error refreshing session"),
|
||||||
expectRefreshed: false,
|
expectSaved: false,
|
||||||
expectSaved: false,
|
|
||||||
}),
|
}),
|
||||||
Entry("when the saving the session returns an error", refreshSessionWithProviderTableInput{
|
Entry("when the saving the session returns an error", refreshSessionWithProviderTableInput{
|
||||||
session: &sessionsapi.SessionState{
|
session: &sessionsapi.SessionState{
|
||||||
RefreshToken: refresh,
|
RefreshToken: refresh,
|
||||||
AccessToken: "NoSave",
|
AccessToken: "NoSave",
|
||||||
},
|
},
|
||||||
expectedErr: errors.New("error saving session: unable to save session"),
|
expectedErr: errors.New("error saving session: unable to save session"),
|
||||||
expectRefreshed: false,
|
expectSaved: true,
|
||||||
expectSaved: true,
|
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
@ -454,7 +496,7 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
s = &storedSessionLoader{
|
s = &storedSessionLoader{
|
||||||
validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool {
|
sessionValidator: func(_ context.Context, ss *sessionsapi.SessionState) bool {
|
||||||
return ss.AccessToken == "Valid"
|
return ss.AccessToken == "Valid"
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -36,8 +36,7 @@ type SessionStore struct {
|
|||||||
// within Cookies set on the HTTP response writer
|
// within Cookies set on the HTTP response writer
|
||||||
func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error {
|
func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error {
|
||||||
if ss.CreatedAt == nil || ss.CreatedAt.IsZero() {
|
if ss.CreatedAt == nil || ss.CreatedAt.IsZero() {
|
||||||
now := time.Now()
|
ss.CreatedAtNow()
|
||||||
ss.CreatedAt = &now
|
|
||||||
}
|
}
|
||||||
value, err := s.cookieForSession(ss)
|
value, err := s.cookieForSession(ss)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -30,8 +30,7 @@ func NewManager(store Store, cookieOpts *options.Cookie) *Manager {
|
|||||||
// from the persistent data store.
|
// from the persistent data store.
|
||||||
func (m *Manager) Save(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error {
|
func (m *Manager) Save(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error {
|
||||||
if s.CreatedAt == nil || s.CreatedAt.IsZero() {
|
if s.CreatedAt == nil || s.CreatedAt.IsZero() {
|
||||||
now := time.Now()
|
s.CreatedAtNow()
|
||||||
s.CreatedAt = &now
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tckt, err := decodeTicketFromRequest(req, m.Options)
|
tckt, err := decodeTicketFromRequest(req, m.Options)
|
||||||
|
@ -142,16 +142,13 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (*
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
created := time.Now()
|
|
||||||
expires := time.Unix(jsonResponse.ExpiresOn, 0)
|
|
||||||
|
|
||||||
session := &sessions.SessionState{
|
session := &sessions.SessionState{
|
||||||
AccessToken: jsonResponse.AccessToken,
|
AccessToken: jsonResponse.AccessToken,
|
||||||
IDToken: jsonResponse.IDToken,
|
IDToken: jsonResponse.IDToken,
|
||||||
CreatedAt: &created,
|
|
||||||
ExpiresOn: &expires,
|
|
||||||
RefreshToken: jsonResponse.RefreshToken,
|
RefreshToken: jsonResponse.RefreshToken,
|
||||||
}
|
}
|
||||||
|
session.CreatedAtNow()
|
||||||
|
session.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0))
|
||||||
|
|
||||||
email, err := p.verifyTokenAndExtractEmail(ctx, session.IDToken)
|
email, err := p.verifyTokenAndExtractEmail(ctx, session.IDToken)
|
||||||
|
|
||||||
@ -239,28 +236,29 @@ func (p *AzureProvider) verifyTokenAndExtractEmail(ctx context.Context, token st
|
|||||||
return email, nil
|
return email, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
// RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
|
||||||
// RefreshToken to fetch a new ID token if required
|
func (p *AzureProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||||
func (p *AzureProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
if s == nil || s.RefreshToken == "" {
|
||||||
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
|
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
origExpiration := s.ExpiresOn
|
|
||||||
|
|
||||||
err := p.redeemRefreshToken(ctx, s)
|
err := p.redeemRefreshToken(ctx, s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("unable to redeem refresh token: %v", err)
|
return false, fmt.Errorf("unable to redeem refresh token: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Printf("refreshed id token %s (expired on %s)\n", s, origExpiration)
|
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) {
|
func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error {
|
||||||
|
clientSecret, err := p.GetClientSecret()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
params := url.Values{}
|
params := url.Values{}
|
||||||
params.Add("client_id", p.ClientID)
|
params.Add("client_id", p.ClientID)
|
||||||
params.Add("client_secret", p.ClientSecret)
|
params.Add("client_secret", clientSecret)
|
||||||
params.Add("refresh_token", s.RefreshToken)
|
params.Add("refresh_token", s.RefreshToken)
|
||||||
params.Add("grant_type", "refresh_token")
|
params.Add("grant_type", "refresh_token")
|
||||||
|
|
||||||
@ -278,18 +276,16 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess
|
|||||||
SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
||||||
Do().
|
Do().
|
||||||
UnmarshalInto(&jsonResponse)
|
UnmarshalInto(&jsonResponse)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
expires := time.Unix(jsonResponse.ExpiresOn, 0)
|
|
||||||
s.AccessToken = jsonResponse.AccessToken
|
s.AccessToken = jsonResponse.AccessToken
|
||||||
s.IDToken = jsonResponse.IDToken
|
s.IDToken = jsonResponse.IDToken
|
||||||
s.RefreshToken = jsonResponse.RefreshToken
|
s.RefreshToken = jsonResponse.RefreshToken
|
||||||
s.CreatedAt = &now
|
|
||||||
s.ExpiresOn = &expires
|
s.CreatedAtNow()
|
||||||
|
s.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0))
|
||||||
|
|
||||||
email, err := p.verifyTokenAndExtractEmail(ctx, s.IDToken)
|
email, err := p.verifyTokenAndExtractEmail(ctx, s.IDToken)
|
||||||
|
|
||||||
@ -312,7 +308,7 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeAzureHeader(accessToken string) http.Header {
|
func makeAzureHeader(accessToken string) http.Header {
|
||||||
|
@ -340,17 +340,7 @@ func TestAzureProviderProtectedResourceConfigured(t *testing.T) {
|
|||||||
assert.Contains(t, result, "resource="+url.QueryEscape("http://my.resource.test"))
|
assert.Contains(t, result, "resource="+url.QueryEscape("http://my.resource.test"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAzureProviderNotRefreshWhenNotExpired(t *testing.T) {
|
func TestAzureProviderRefresh(t *testing.T) {
|
||||||
p := testAzureProvider("")
|
|
||||||
|
|
||||||
expires := time.Now().Add(time.Duration(1) * time.Hour)
|
|
||||||
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
|
|
||||||
refreshNeeded, err := p.RefreshSessionIfNeeded(context.Background(), session)
|
|
||||||
assert.Equal(t, nil, err)
|
|
||||||
assert.False(t, refreshNeeded)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAzureProviderRefreshWhenExpired(t *testing.T) {
|
|
||||||
email := "foo@example.com"
|
email := "foo@example.com"
|
||||||
idToken := idTokenClaims{Email: email}
|
idToken := idTokenClaims{Email: email}
|
||||||
idTokenString, err := newSignedTestIDToken(idToken)
|
idTokenString, err := newSignedTestIDToken(idToken)
|
||||||
@ -373,9 +363,10 @@ func TestAzureProviderRefreshWhenExpired(t *testing.T) {
|
|||||||
|
|
||||||
expires := time.Now().Add(time.Duration(-1) * time.Hour)
|
expires := time.Now().Add(time.Duration(-1) * time.Hour)
|
||||||
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
|
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
|
||||||
refreshNeeded, err := p.RefreshSessionIfNeeded(context.Background(), session)
|
|
||||||
|
refreshed, err := p.RefreshSession(context.Background(), session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.True(t, refreshNeeded)
|
assert.True(t, refreshed)
|
||||||
assert.NotEqual(t, session, nil)
|
assert.NotEqual(t, session, nil)
|
||||||
assert.Equal(t, "new_some_access_token", session.AccessToken)
|
assert.Equal(t, "new_some_access_token", session.AccessToken)
|
||||||
assert.Equal(t, "new_some_refresh_token", session.RefreshToken)
|
assert.Equal(t, "new_some_refresh_token", session.RefreshToken)
|
||||||
|
@ -88,7 +88,7 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess
|
|||||||
return r.Email, nil
|
return r.Email, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateSessionState validates the AccessToken
|
// ValidateSession validates the AccessToken
|
||||||
func (p *FacebookProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
|
func (p *FacebookProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
|
||||||
return validateToken(ctx, p, s.AccessToken, makeOIDCHeader(s.AccessToken))
|
return validateToken(ctx, p, s.AccessToken, makeOIDCHeader(s.AccessToken))
|
||||||
}
|
}
|
||||||
|
@ -121,10 +121,9 @@ func (p *GitLabProvider) SetProjectScope() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
// RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
|
||||||
// RefreshToken to fetch a new ID token if required
|
func (p *GitLabProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||||
func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
if s == nil || s.RefreshToken == "" {
|
||||||
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
|
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -139,10 +138,10 @@ func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) {
|
func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error {
|
||||||
clientSecret, err := p.GetClientSecret()
|
clientSecret, err := p.GetClientSecret()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
c := oauth2.Config{
|
c := oauth2.Config{
|
||||||
@ -164,13 +163,9 @@ func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.Ses
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to update session: %v", err)
|
return fmt.Errorf("unable to update session: %v", err)
|
||||||
}
|
}
|
||||||
s.AccessToken = newSession.AccessToken
|
*s = *newSession
|
||||||
s.IDToken = newSession.IDToken
|
|
||||||
s.RefreshToken = newSession.RefreshToken
|
return nil
|
||||||
s.CreatedAt = newSession.CreatedAt
|
|
||||||
s.ExpiresOn = newSession.ExpiresOn
|
|
||||||
s.Email = newSession.Email
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type gitlabUserInfo struct {
|
type gitlabUserInfo struct {
|
||||||
@ -264,14 +259,16 @@ func (p *GitLabProvider) createSession(ctx context.Context, token *oauth2.Token)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
created := time.Now()
|
ss := &sessions.SessionState{
|
||||||
return &sessions.SessionState{
|
|
||||||
AccessToken: token.AccessToken,
|
AccessToken: token.AccessToken,
|
||||||
IDToken: getIDToken(token),
|
IDToken: getIDToken(token),
|
||||||
RefreshToken: token.RefreshToken,
|
RefreshToken: token.RefreshToken,
|
||||||
CreatedAt: &created,
|
}
|
||||||
ExpiresOn: &idToken.Expiry,
|
|
||||||
}, nil
|
ss.CreatedAtNow()
|
||||||
|
ss.SetExpiresOn(idToken.Expiry)
|
||||||
|
|
||||||
|
return ss, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateSession checks that the session's IDToken is still valid
|
// ValidateSession checks that the session's IDToken is still valid
|
||||||
|
@ -163,23 +163,22 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
created := time.Now()
|
ss := &sessions.SessionState{
|
||||||
expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second)
|
|
||||||
|
|
||||||
return &sessions.SessionState{
|
|
||||||
AccessToken: jsonResponse.AccessToken,
|
AccessToken: jsonResponse.AccessToken,
|
||||||
IDToken: jsonResponse.IDToken,
|
IDToken: jsonResponse.IDToken,
|
||||||
CreatedAt: &created,
|
|
||||||
ExpiresOn: &expires,
|
|
||||||
RefreshToken: jsonResponse.RefreshToken,
|
RefreshToken: jsonResponse.RefreshToken,
|
||||||
Email: c.Email,
|
Email: c.Email,
|
||||||
User: c.Subject,
|
User: c.Subject,
|
||||||
}, nil
|
}
|
||||||
|
ss.CreatedAtNow()
|
||||||
|
ss.ExpiresIn(time.Duration(jsonResponse.ExpiresIn) * time.Second)
|
||||||
|
|
||||||
|
return ss, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// EnrichSession checks the listed Google Groups configured and adds any
|
// EnrichSession checks the listed Google Groups configured and adds any
|
||||||
// that the user is a member of to session.Groups.
|
// that the user is a member of to session.Groups.
|
||||||
func (p *GoogleProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error {
|
func (p *GoogleProvider) EnrichSession(_ context.Context, s *sessions.SessionState) error {
|
||||||
// TODO (@NickMeves) - Move to pure EnrichSession logic and stop
|
// TODO (@NickMeves) - Move to pure EnrichSession logic and stop
|
||||||
// reusing legacy `groupValidator`.
|
// reusing legacy `groupValidator`.
|
||||||
//
|
//
|
||||||
@ -266,14 +265,13 @@ func userInGroup(service *admin.Service, group string, email string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
// RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
|
||||||
// RefreshToken to fetch a new ID token if required
|
func (p *GoogleProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||||
func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
if s == nil || s.RefreshToken == "" {
|
||||||
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
|
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
newToken, newIDToken, duration, err := p.redeemRefreshToken(ctx, s.RefreshToken)
|
err := p.redeemRefreshToken(ctx, s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
@ -286,26 +284,20 @@ func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions
|
|||||||
return false, fmt.Errorf("%s is no longer in the group(s)", s.Email)
|
return false, fmt.Errorf("%s is no longer in the group(s)", s.Email)
|
||||||
}
|
}
|
||||||
|
|
||||||
origExpiration := s.ExpiresOn
|
|
||||||
expires := time.Now().Add(duration).Truncate(time.Second)
|
|
||||||
s.AccessToken = newToken
|
|
||||||
s.IDToken = newIDToken
|
|
||||||
s.ExpiresOn = &expires
|
|
||||||
logger.Printf("refreshed access token %s (expired on %s)", s, origExpiration)
|
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken string) (token string, idToken string, expires time.Duration, err error) {
|
func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error {
|
||||||
// https://developers.google.com/identity/protocols/OAuth2WebServer#refresh
|
// https://developers.google.com/identity/protocols/OAuth2WebServer#refresh
|
||||||
clientSecret, err := p.GetClientSecret()
|
clientSecret, err := p.GetClientSecret()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
params := url.Values{}
|
params := url.Values{}
|
||||||
params.Add("client_id", p.ClientID)
|
params.Add("client_id", p.ClientID)
|
||||||
params.Add("client_secret", clientSecret)
|
params.Add("client_secret", clientSecret)
|
||||||
params.Add("refresh_token", refreshToken)
|
params.Add("refresh_token", s.RefreshToken)
|
||||||
params.Add("grant_type", "refresh_token")
|
params.Add("grant_type", "refresh_token")
|
||||||
|
|
||||||
var data struct {
|
var data struct {
|
||||||
@ -322,11 +314,14 @@ func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken st
|
|||||||
Do().
|
Do().
|
||||||
UnmarshalInto(&data)
|
UnmarshalInto(&data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", 0, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
token = data.AccessToken
|
s.AccessToken = data.AccessToken
|
||||||
idToken = data.IDToken
|
s.IDToken = data.IDToken
|
||||||
expires = time.Duration(data.ExpiresIn) * time.Second
|
|
||||||
return
|
s.CreatedAtNow()
|
||||||
|
s.ExpiresIn(time.Duration(data.ExpiresIn) * time.Second)
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -93,7 +93,7 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess
|
|||||||
return email, nil
|
return email, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateSessionState validates the AccessToken
|
// ValidateSession validates the AccessToken
|
||||||
func (p *LinkedInProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
|
func (p *LinkedInProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
|
||||||
return validateToken(ctx, p, s.AccessToken, makeLinkedInHeader(s.AccessToken))
|
return validateToken(ctx, p, s.AccessToken, makeLinkedInHeader(s.AccessToken))
|
||||||
}
|
}
|
||||||
|
@ -159,7 +159,7 @@ func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||||
func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
|
func (p *LoginGovProvider) Redeem(ctx context.Context, _, code string) (*sessions.SessionState, error) {
|
||||||
if code == "" {
|
if code == "" {
|
||||||
return nil, ErrMissingCode
|
return nil, ErrMissingCode
|
||||||
}
|
}
|
||||||
@ -214,17 +214,16 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
created := time.Now()
|
session := &sessions.SessionState{
|
||||||
expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second)
|
|
||||||
|
|
||||||
// Store the data that we found in the session state
|
|
||||||
return &sessions.SessionState{
|
|
||||||
AccessToken: jsonResponse.AccessToken,
|
AccessToken: jsonResponse.AccessToken,
|
||||||
IDToken: jsonResponse.IDToken,
|
IDToken: jsonResponse.IDToken,
|
||||||
CreatedAt: &created,
|
|
||||||
ExpiresOn: &expires,
|
|
||||||
Email: email,
|
Email: email,
|
||||||
}, nil
|
}
|
||||||
|
|
||||||
|
session.CreatedAtNow()
|
||||||
|
session.ExpiresIn(time.Duration(jsonResponse.ExpiresIn) * time.Second)
|
||||||
|
|
||||||
|
return session, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLoginURL overrides GetLoginURL to add login.gov parameters
|
// GetLoginURL overrides GetLoginURL to add login.gov parameters
|
||||||
|
@ -143,10 +143,9 @@ func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionS
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
// RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
|
||||||
// RefreshToken to fetch a new Access Token (and optional ID token) if required
|
func (p *OIDCProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||||
func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
if s == nil || s.RefreshToken == "" {
|
||||||
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
|
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -155,7 +154,6 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.S
|
|||||||
return false, fmt.Errorf("unable to redeem refresh token: %v", err)
|
return false, fmt.Errorf("unable to redeem refresh token: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Printf("refreshed session: %s", s)
|
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -227,7 +225,9 @@ func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string)
|
|||||||
ss.AccessToken = token
|
ss.AccessToken = token
|
||||||
ss.IDToken = token
|
ss.IDToken = token
|
||||||
ss.RefreshToken = ""
|
ss.RefreshToken = ""
|
||||||
ss.ExpiresOn = &idToken.Expiry
|
|
||||||
|
ss.CreatedAtNow()
|
||||||
|
ss.SetExpiresOn(idToken.Expiry)
|
||||||
|
|
||||||
return ss, nil
|
return ss, nil
|
||||||
}
|
}
|
||||||
@ -257,9 +257,8 @@ func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, r
|
|||||||
ss.RefreshToken = token.RefreshToken
|
ss.RefreshToken = token.RefreshToken
|
||||||
ss.IDToken = getIDToken(token)
|
ss.IDToken = getIDToken(token)
|
||||||
|
|
||||||
created := time.Now()
|
ss.CreatedAtNow()
|
||||||
ss.CreatedAt = &created
|
ss.SetExpiresOn(token.Expiry)
|
||||||
ss.ExpiresOn = &token.Expiry
|
|
||||||
|
|
||||||
return ss, nil
|
return ss, nil
|
||||||
}
|
}
|
||||||
|
@ -487,7 +487,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) {
|
|||||||
User: "11223344",
|
User: "11223344",
|
||||||
}
|
}
|
||||||
|
|
||||||
refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession)
|
refreshed, err := provider.RefreshSession(context.Background(), existingSession)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, refreshed, true)
|
assert.Equal(t, refreshed, true)
|
||||||
assert.Equal(t, "janedoe@example.com", existingSession.Email)
|
assert.Equal(t, "janedoe@example.com", existingSession.Email)
|
||||||
@ -520,7 +520,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) {
|
|||||||
Email: "changeit",
|
Email: "changeit",
|
||||||
User: "changeit",
|
User: "changeit",
|
||||||
}
|
}
|
||||||
refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession)
|
refreshed, err := provider.RefreshSession(context.Background(), existingSession)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, refreshed, true)
|
assert.Equal(t, refreshed, true)
|
||||||
assert.Equal(t, defaultIDToken.Email, existingSession.Email)
|
assert.Equal(t, defaultIDToken.Email, existingSession.Email)
|
||||||
|
@ -6,7 +6,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||||
@ -85,9 +84,13 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (*s
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
// TODO (@NickMeves): Uses OAuth `expires_in` to set an expiration
|
||||||
if token := values.Get("access_token"); token != "" {
|
if token := values.Get("access_token"); token != "" {
|
||||||
created := time.Now()
|
ss := &sessions.SessionState{
|
||||||
return &sessions.SessionState{AccessToken: token, CreatedAt: &created}, nil
|
AccessToken: token,
|
||||||
|
}
|
||||||
|
ss.CreatedAtNow()
|
||||||
|
return ss, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, fmt.Errorf("no access token found %s", result.Body())
|
return nil, fmt.Errorf("no access token found %s", result.Body())
|
||||||
@ -126,10 +129,9 @@ func (p *ProviderData) ValidateSession(ctx context.Context, s *sessions.SessionS
|
|||||||
return validateToken(ctx, p, s.AccessToken, nil)
|
return validateToken(ctx, p, s.AccessToken, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshSessionIfNeeded should refresh the user's session if required and
|
// RefreshSession refreshes the user's session
|
||||||
// do nothing if a refresh is not required
|
func (p *ProviderData) RefreshSession(_ context.Context, _ *sessions.SessionState) (bool, error) {
|
||||||
func (p *ProviderData) RefreshSessionIfNeeded(_ context.Context, _ *sessions.SessionState) (bool, error) {
|
return false, ErrNotImplemented
|
||||||
return false, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateSessionFromToken converts Bearer IDTokens into sessions
|
// CreateSessionFromToken converts Bearer IDTokens into sessions
|
||||||
|
@ -14,12 +14,20 @@ import (
|
|||||||
func TestRefresh(t *testing.T) {
|
func TestRefresh(t *testing.T) {
|
||||||
p := &ProviderData{}
|
p := &ProviderData{}
|
||||||
|
|
||||||
expires := time.Now().Add(time.Duration(-11) * time.Minute)
|
now := time.Unix(1234567890, 10)
|
||||||
refreshed, err := p.RefreshSessionIfNeeded(context.Background(), &sessions.SessionState{
|
expires := time.Unix(1234567890, 0)
|
||||||
ExpiresOn: &expires,
|
|
||||||
})
|
ss := &sessions.SessionState{}
|
||||||
assert.Equal(t, false, refreshed)
|
ss.Clock.Set(now)
|
||||||
assert.Equal(t, nil, err)
|
ss.SetExpiresOn(expires)
|
||||||
|
|
||||||
|
refreshed, err := p.RefreshSession(context.Background(), ss)
|
||||||
|
assert.False(t, refreshed)
|
||||||
|
assert.Equal(t, ErrNotImplemented, err)
|
||||||
|
|
||||||
|
refreshed, err = p.RefreshSession(context.Background(), nil)
|
||||||
|
assert.False(t, refreshed)
|
||||||
|
assert.Equal(t, ErrNotImplemented, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAcrValuesNotConfigured(t *testing.T) {
|
func TestAcrValuesNotConfigured(t *testing.T) {
|
||||||
|
@ -9,14 +9,14 @@ import (
|
|||||||
// Provider represents an upstream identity provider implementation
|
// Provider represents an upstream identity provider implementation
|
||||||
type Provider interface {
|
type Provider interface {
|
||||||
Data() *ProviderData
|
Data() *ProviderData
|
||||||
|
GetLoginURL(redirectURI, finalRedirect string, nonce string) string
|
||||||
|
Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error)
|
||||||
// Deprecated: Migrate to EnrichSession
|
// Deprecated: Migrate to EnrichSession
|
||||||
GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error)
|
GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error)
|
||||||
GetLoginURL(redirectURI, state, nonce string) string
|
|
||||||
Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error)
|
|
||||||
EnrichSession(ctx context.Context, s *sessions.SessionState) error
|
EnrichSession(ctx context.Context, s *sessions.SessionState) error
|
||||||
Authorize(ctx context.Context, s *sessions.SessionState) (bool, error)
|
Authorize(ctx context.Context, s *sessions.SessionState) (bool, error)
|
||||||
ValidateSession(ctx context.Context, s *sessions.SessionState) bool
|
ValidateSession(ctx context.Context, s *sessions.SessionState) bool
|
||||||
RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error)
|
RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error)
|
||||||
CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error)
|
CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user