1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-07-17 01:52:30 +02:00

Refactor StoredSessionHandler

This commit is contained in:
Kevin Kreitner
2021-09-29 13:31:18 +02:00
committed by Joel Speed
parent 518e619289
commit 86ba2f41ce

View File

@ -109,6 +109,12 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h
return nil, fmt.Errorf("error refreshing access token for session (%s): %v", session, err)
}
// Validate all sessions after any Redeem/Refresh operation (fail or success)
err = s.validateSession(req.Context(), session)
if err != nil {
return nil, err
}
return session, nil
}
@ -121,36 +127,35 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req
return nil
}
var wasLocked bool
var err error
var isLocked bool
for isLocked, err = session.PeekLock(req.Context()); isLocked; isLocked, err = session.PeekLock(req.Context()) {
wasLocked = true
// delay next peek lock
time.Sleep(SessionLockPeekDelay)
}
wasLocked, err := s.waitForPossibleSessionLock(session, req)
if err != nil {
return err
}
// If session was locked fetch current state
// If session was locked, fetch current state, because
// it should be updated after lock is released.
if wasLocked {
var sessionStored *sessionsapi.SessionState
sessionStored, err = s.store.Load(req)
err = s.updateSessionFromStore(req, session)
if err != nil {
logger.Errorf("Unable to load updated session from store: %v", err)
}
return err
}
if session == nil || sessionStored == nil {
return nil
logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age())
err = s.refreshSession(rw, req, session)
if err != nil {
// If a preemptive refresh fails, we still keep the session
// if validateSession succeeds.
logger.Errorf("Unable to refresh session: %v", err)
}
*session = *sessionStored
return nil
return err
}
err = session.ObtainLock(req.Context(), SessionLockExpireTime)
// refreshSession attempts to refresh the session with the provider
// and will save the session if it was updated.
func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error {
err := session.ObtainLock(req.Context(), SessionLockExpireTime)
if err != nil {
logger.Errorf("unable to obtain lock (skipping refresh): %v", err)
return nil
@ -162,21 +167,6 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req
}
}()
logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age())
err = s.refreshSession(rw, req, session)
if err != nil {
// If a preemptive refresh fails, we still keep the session
// if validateSession succeeds.
logger.Errorf("Unable to refresh session: %v", err)
}
// Validate all sessions after any Redeem/Refresh operation (fail or success)
return s.validateSession(req.Context(), session)
}
// refreshSession attempts to refresh the session with the provider
// and will save the session if it was updated.
func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error {
refreshed, err := s.sessionRefresher(req.Context(), session)
if err != nil && !errors.Is(err, providers.ErrNotImplemented) {
return fmt.Errorf("error refreshing tokens: %v", err)
@ -205,10 +195,41 @@ func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.R
err = s.store.Save(rw, req, session)
if err != nil {
logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err)
return fmt.Errorf("error saving session: %v", err)
err = fmt.Errorf("error saving session: %v", err)
}
return err
}
func (s *storedSessionLoader) updateSessionFromStore(req *http.Request, session *sessionsapi.SessionState) error {
sessionStored, err := s.store.Load(req)
if err != nil {
return err
}
if session == nil || sessionStored == nil {
return nil
}
*session = *sessionStored
return nil
}
func (s *storedSessionLoader) waitForPossibleSessionLock(session *sessionsapi.SessionState, req *http.Request) (bool, error) {
var wasLocked bool
var err error
var isLocked bool
for isLocked, err = session.PeekLock(req.Context()); isLocked; isLocked, err = session.PeekLock(req.Context()) {
wasLocked = true
// delay next peek lock
time.Sleep(SessionLockPeekDelay)
}
if err != nil {
return false, err
}
return wasLocked, nil
}
// validateSession checks whether the session has expired and performs
// provider validation on the session.