You've already forked oauth2-proxy
mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-07-17 01:52:30 +02:00
Refactor StoredSessionHandler
This commit is contained in:
committed by
Joel Speed
parent
518e619289
commit
86ba2f41ce
@ -109,6 +109,12 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h
|
||||
return nil, fmt.Errorf("error refreshing access token for session (%s): %v", session, err)
|
||||
}
|
||||
|
||||
// Validate all sessions after any Redeem/Refresh operation (fail or success)
|
||||
err = s.validateSession(req.Context(), session)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
@ -121,36 +127,35 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req
|
||||
return nil
|
||||
}
|
||||
|
||||
var wasLocked bool
|
||||
var err error
|
||||
var isLocked bool
|
||||
for isLocked, err = session.PeekLock(req.Context()); isLocked; isLocked, err = session.PeekLock(req.Context()) {
|
||||
wasLocked = true
|
||||
// delay next peek lock
|
||||
time.Sleep(SessionLockPeekDelay)
|
||||
}
|
||||
|
||||
wasLocked, err := s.waitForPossibleSessionLock(session, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If session was locked fetch current state
|
||||
// If session was locked, fetch current state, because
|
||||
// it should be updated after lock is released.
|
||||
if wasLocked {
|
||||
var sessionStored *sessionsapi.SessionState
|
||||
sessionStored, err = s.store.Load(req)
|
||||
err = s.updateSessionFromStore(req, session)
|
||||
if err != nil {
|
||||
return err
|
||||
logger.Errorf("Unable to load updated session from store: %v", err)
|
||||
}
|
||||
|
||||
if session == nil || sessionStored == nil {
|
||||
return nil
|
||||
}
|
||||
*session = *sessionStored
|
||||
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
err = session.ObtainLock(req.Context(), SessionLockExpireTime)
|
||||
logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age())
|
||||
err = s.refreshSession(rw, req, session)
|
||||
if err != nil {
|
||||
// If a preemptive refresh fails, we still keep the session
|
||||
// if validateSession succeeds.
|
||||
logger.Errorf("Unable to refresh session: %v", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// refreshSession attempts to refresh the session with the provider
|
||||
// and will save the session if it was updated.
|
||||
func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error {
|
||||
err := session.ObtainLock(req.Context(), SessionLockExpireTime)
|
||||
if err != nil {
|
||||
logger.Errorf("unable to obtain lock (skipping refresh): %v", err)
|
||||
return nil
|
||||
@ -162,21 +167,6 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req
|
||||
}
|
||||
}()
|
||||
|
||||
logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age())
|
||||
err = s.refreshSession(rw, req, session)
|
||||
if err != nil {
|
||||
// If a preemptive refresh fails, we still keep the session
|
||||
// if validateSession succeeds.
|
||||
logger.Errorf("Unable to refresh session: %v", err)
|
||||
}
|
||||
|
||||
// Validate all sessions after any Redeem/Refresh operation (fail or success)
|
||||
return s.validateSession(req.Context(), session)
|
||||
}
|
||||
|
||||
// refreshSession attempts to refresh the session with the provider
|
||||
// and will save the session if it was updated.
|
||||
func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error {
|
||||
refreshed, err := s.sessionRefresher(req.Context(), session)
|
||||
if err != nil && !errors.Is(err, providers.ErrNotImplemented) {
|
||||
return fmt.Errorf("error refreshing tokens: %v", err)
|
||||
@ -205,11 +195,42 @@ func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.R
|
||||
err = s.store.Save(rw, req, session)
|
||||
if err != nil {
|
||||
logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err)
|
||||
return fmt.Errorf("error saving session: %v", err)
|
||||
err = fmt.Errorf("error saving session: %v", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *storedSessionLoader) updateSessionFromStore(req *http.Request, session *sessionsapi.SessionState) error {
|
||||
sessionStored, err := s.store.Load(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if session == nil || sessionStored == nil {
|
||||
return nil
|
||||
}
|
||||
*session = *sessionStored
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *storedSessionLoader) waitForPossibleSessionLock(session *sessionsapi.SessionState, req *http.Request) (bool, error) {
|
||||
var wasLocked bool
|
||||
var err error
|
||||
var isLocked bool
|
||||
for isLocked, err = session.PeekLock(req.Context()); isLocked; isLocked, err = session.PeekLock(req.Context()) {
|
||||
wasLocked = true
|
||||
// delay next peek lock
|
||||
time.Sleep(SessionLockPeekDelay)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return wasLocked, nil
|
||||
}
|
||||
|
||||
// validateSession checks whether the session has expired and performs
|
||||
// provider validation on the session.
|
||||
// An error implies the session is not longer valid.
|
||||
|
Reference in New Issue
Block a user