You've already forked oauth2-proxy
mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-06-15 00:15:00 +02:00
RefreshSessions immediately when called
This commit is contained in:
@ -24,12 +24,12 @@ type StoredSessionLoaderOptions struct {
|
||||
RefreshPeriod time.Duration
|
||||
|
||||
// Provider based sesssion refreshing
|
||||
RefreshSessionIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error)
|
||||
RefreshSession func(context.Context, *sessionsapi.SessionState) (bool, error)
|
||||
|
||||
// Provider based session validation.
|
||||
// If the sesssion is older than `RefreshPeriod` but the provider doesn't
|
||||
// refresh it, we must re-validate using this validation.
|
||||
ValidateSessionState func(context.Context, *sessionsapi.SessionState) bool
|
||||
ValidateSession func(context.Context, *sessionsapi.SessionState) bool
|
||||
}
|
||||
|
||||
// NewStoredSessionLoader creates a new storedSessionLoader which loads
|
||||
@ -38,10 +38,10 @@ type StoredSessionLoaderOptions struct {
|
||||
// If a session was loader by a previous handler, it will not be replaced.
|
||||
func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor {
|
||||
ss := &storedSessionLoader{
|
||||
store: opts.SessionStore,
|
||||
refreshPeriod: opts.RefreshPeriod,
|
||||
refreshSessionWithProviderIfNeeded: opts.RefreshSessionIfNeeded,
|
||||
validateSessionState: opts.ValidateSessionState,
|
||||
store: opts.SessionStore,
|
||||
refreshPeriod: opts.RefreshPeriod,
|
||||
sessionRefresher: opts.RefreshSession,
|
||||
sessionValidator: opts.ValidateSession,
|
||||
}
|
||||
return ss.loadSession
|
||||
}
|
||||
@ -49,10 +49,10 @@ func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor
|
||||
// storedSessionLoader is responsible for loading sessions from cookie
|
||||
// identified sessions in the session store.
|
||||
type storedSessionLoader struct {
|
||||
store sessionsapi.SessionStore
|
||||
refreshPeriod time.Duration
|
||||
refreshSessionWithProviderIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error)
|
||||
validateSessionState func(context.Context, *sessionsapi.SessionState) bool
|
||||
store sessionsapi.SessionStore
|
||||
refreshPeriod time.Duration
|
||||
sessionRefresher func(context.Context, *sessionsapi.SessionState) (bool, error)
|
||||
sessionValidator func(context.Context, *sessionsapi.SessionState) bool
|
||||
}
|
||||
|
||||
// loadSession attempts to load a session as identified by the request cookies.
|
||||
@ -120,37 +120,38 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req
|
||||
}
|
||||
|
||||
logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod)
|
||||
refreshed, err := s.refreshSessionWithProvider(rw, req, session)
|
||||
err := s.refreshSession(rw, req, session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !refreshed {
|
||||
// Session wasn't refreshed, so make sure it's still valid
|
||||
return s.validateSession(req.Context(), session)
|
||||
}
|
||||
return nil
|
||||
// Validate all sessions after any Redeem/Refresh operation
|
||||
return s.validateSession(req.Context(), session)
|
||||
}
|
||||
|
||||
// refreshSessionWithProvider attempts to refresh the sessinon with the provider
|
||||
// refreshSession attempts to refresh the session with the provider
|
||||
// and will save the session if it was updated.
|
||||
func (s *storedSessionLoader) refreshSessionWithProvider(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) (bool, error) {
|
||||
refreshed, err := s.refreshSessionWithProviderIfNeeded(req.Context(), session)
|
||||
func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error {
|
||||
refreshed, err := s.sessionRefresher(req.Context(), session)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("error refreshing access token: %v", err)
|
||||
return fmt.Errorf("error refreshing access token: %v", err)
|
||||
}
|
||||
|
||||
if !refreshed {
|
||||
return false, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// If we refreshed, update the `CreatedAt` time to reset the refresh timer
|
||||
// TODO: Implement
|
||||
// session.CreatedAtNow()
|
||||
|
||||
// Because the session was refreshed, make sure to save it
|
||||
err = s.store.Save(rw, req, session)
|
||||
if err != nil {
|
||||
logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err)
|
||||
return false, fmt.Errorf("error saving session: %v", err)
|
||||
return fmt.Errorf("error saving session: %v", err)
|
||||
}
|
||||
return true, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateSession checks whether the session has expired and performs
|
||||
@ -161,7 +162,7 @@ func (s *storedSessionLoader) validateSession(ctx context.Context, session *sess
|
||||
return errors.New("session is expired")
|
||||
}
|
||||
|
||||
if !s.validateSessionState(ctx, session) {
|
||||
if !s.sessionValidator(ctx, session) {
|
||||
return errors.New("session is invalid")
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user