1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-06-15 00:15:00 +02:00

RefreshSessions immediately when called

This commit is contained in:
Nick Meves
2021-03-06 15:33:13 -08:00
parent 5f4ac25b1e
commit 7e80e5596b
13 changed files with 74 additions and 79 deletions

View File

@ -24,12 +24,12 @@ type StoredSessionLoaderOptions struct {
RefreshPeriod time.Duration
// Provider based sesssion refreshing
RefreshSessionIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error)
RefreshSession func(context.Context, *sessionsapi.SessionState) (bool, error)
// Provider based session validation.
// If the sesssion is older than `RefreshPeriod` but the provider doesn't
// refresh it, we must re-validate using this validation.
ValidateSessionState func(context.Context, *sessionsapi.SessionState) bool
ValidateSession func(context.Context, *sessionsapi.SessionState) bool
}
// NewStoredSessionLoader creates a new storedSessionLoader which loads
@ -38,10 +38,10 @@ type StoredSessionLoaderOptions struct {
// If a session was loader by a previous handler, it will not be replaced.
func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor {
ss := &storedSessionLoader{
store: opts.SessionStore,
refreshPeriod: opts.RefreshPeriod,
refreshSessionWithProviderIfNeeded: opts.RefreshSessionIfNeeded,
validateSessionState: opts.ValidateSessionState,
store: opts.SessionStore,
refreshPeriod: opts.RefreshPeriod,
sessionRefresher: opts.RefreshSession,
sessionValidator: opts.ValidateSession,
}
return ss.loadSession
}
@ -49,10 +49,10 @@ func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor
// storedSessionLoader is responsible for loading sessions from cookie
// identified sessions in the session store.
type storedSessionLoader struct {
store sessionsapi.SessionStore
refreshPeriod time.Duration
refreshSessionWithProviderIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error)
validateSessionState func(context.Context, *sessionsapi.SessionState) bool
store sessionsapi.SessionStore
refreshPeriod time.Duration
sessionRefresher func(context.Context, *sessionsapi.SessionState) (bool, error)
sessionValidator func(context.Context, *sessionsapi.SessionState) bool
}
// loadSession attempts to load a session as identified by the request cookies.
@ -120,37 +120,38 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req
}
logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod)
refreshed, err := s.refreshSessionWithProvider(rw, req, session)
err := s.refreshSession(rw, req, session)
if err != nil {
return err
}
if !refreshed {
// Session wasn't refreshed, so make sure it's still valid
return s.validateSession(req.Context(), session)
}
return nil
// Validate all sessions after any Redeem/Refresh operation
return s.validateSession(req.Context(), session)
}
// refreshSessionWithProvider attempts to refresh the sessinon with the provider
// refreshSession attempts to refresh the session with the provider
// and will save the session if it was updated.
func (s *storedSessionLoader) refreshSessionWithProvider(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) (bool, error) {
refreshed, err := s.refreshSessionWithProviderIfNeeded(req.Context(), session)
func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error {
refreshed, err := s.sessionRefresher(req.Context(), session)
if err != nil {
return false, fmt.Errorf("error refreshing access token: %v", err)
return fmt.Errorf("error refreshing access token: %v", err)
}
if !refreshed {
return false, nil
return nil
}
// If we refreshed, update the `CreatedAt` time to reset the refresh timer
// TODO: Implement
// session.CreatedAtNow()
// Because the session was refreshed, make sure to save it
err = s.store.Save(rw, req, session)
if err != nil {
logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err)
return false, fmt.Errorf("error saving session: %v", err)
return fmt.Errorf("error saving session: %v", err)
}
return true, nil
return nil
}
// validateSession checks whether the session has expired and performs
@ -161,7 +162,7 @@ func (s *storedSessionLoader) validateSession(ctx context.Context, session *sess
return errors.New("session is expired")
}
if !s.validateSessionState(ctx, session) {
if !s.sessionValidator(ctx, session) {
return errors.New("session is invalid")
}