You've already forked oauth2-proxy
mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-07-17 01:52:30 +02:00
Refactor StoredSessionHandler
This commit is contained in:
committed by
Joel Speed
parent
518e619289
commit
86ba2f41ce
@ -109,6 +109,12 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h
|
|||||||
return nil, fmt.Errorf("error refreshing access token for session (%s): %v", session, err)
|
return nil, fmt.Errorf("error refreshing access token for session (%s): %v", session, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate all sessions after any Redeem/Refresh operation (fail or success)
|
||||||
|
err = s.validateSession(req.Context(), session)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return session, nil
|
return session, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -121,36 +127,35 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var wasLocked bool
|
wasLocked, err := s.waitForPossibleSessionLock(session, req)
|
||||||
var err error
|
|
||||||
var isLocked bool
|
|
||||||
for isLocked, err = session.PeekLock(req.Context()); isLocked; isLocked, err = session.PeekLock(req.Context()) {
|
|
||||||
wasLocked = true
|
|
||||||
// delay next peek lock
|
|
||||||
time.Sleep(SessionLockPeekDelay)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// If session was locked fetch current state
|
// If session was locked, fetch current state, because
|
||||||
|
// it should be updated after lock is released.
|
||||||
if wasLocked {
|
if wasLocked {
|
||||||
var sessionStored *sessionsapi.SessionState
|
err = s.updateSessionFromStore(req, session)
|
||||||
sessionStored, err = s.store.Load(req)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
logger.Errorf("Unable to load updated session from store: %v", err)
|
||||||
}
|
}
|
||||||
|
return err
|
||||||
if session == nil || sessionStored == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
*session = *sessionStored
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = session.ObtainLock(req.Context(), SessionLockExpireTime)
|
logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age())
|
||||||
|
err = s.refreshSession(rw, req, session)
|
||||||
|
if err != nil {
|
||||||
|
// If a preemptive refresh fails, we still keep the session
|
||||||
|
// if validateSession succeeds.
|
||||||
|
logger.Errorf("Unable to refresh session: %v", err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// refreshSession attempts to refresh the session with the provider
|
||||||
|
// and will save the session if it was updated.
|
||||||
|
func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error {
|
||||||
|
err := session.ObtainLock(req.Context(), SessionLockExpireTime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("unable to obtain lock (skipping refresh): %v", err)
|
logger.Errorf("unable to obtain lock (skipping refresh): %v", err)
|
||||||
return nil
|
return nil
|
||||||
@ -162,21 +167,6 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age())
|
|
||||||
err = s.refreshSession(rw, req, session)
|
|
||||||
if err != nil {
|
|
||||||
// If a preemptive refresh fails, we still keep the session
|
|
||||||
// if validateSession succeeds.
|
|
||||||
logger.Errorf("Unable to refresh session: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate all sessions after any Redeem/Refresh operation (fail or success)
|
|
||||||
return s.validateSession(req.Context(), session)
|
|
||||||
}
|
|
||||||
|
|
||||||
// refreshSession attempts to refresh the session with the provider
|
|
||||||
// and will save the session if it was updated.
|
|
||||||
func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error {
|
|
||||||
refreshed, err := s.sessionRefresher(req.Context(), session)
|
refreshed, err := s.sessionRefresher(req.Context(), session)
|
||||||
if err != nil && !errors.Is(err, providers.ErrNotImplemented) {
|
if err != nil && !errors.Is(err, providers.ErrNotImplemented) {
|
||||||
return fmt.Errorf("error refreshing tokens: %v", err)
|
return fmt.Errorf("error refreshing tokens: %v", err)
|
||||||
@ -205,11 +195,42 @@ func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.R
|
|||||||
err = s.store.Save(rw, req, session)
|
err = s.store.Save(rw, req, session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err)
|
logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err)
|
||||||
return fmt.Errorf("error saving session: %v", err)
|
err = fmt.Errorf("error saving session: %v", err)
|
||||||
}
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *storedSessionLoader) updateSessionFromStore(req *http.Request, session *sessionsapi.SessionState) error {
|
||||||
|
sessionStored, err := s.store.Load(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if session == nil || sessionStored == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
*session = *sessionStored
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *storedSessionLoader) waitForPossibleSessionLock(session *sessionsapi.SessionState, req *http.Request) (bool, error) {
|
||||||
|
var wasLocked bool
|
||||||
|
var err error
|
||||||
|
var isLocked bool
|
||||||
|
for isLocked, err = session.PeekLock(req.Context()); isLocked; isLocked, err = session.PeekLock(req.Context()) {
|
||||||
|
wasLocked = true
|
||||||
|
// delay next peek lock
|
||||||
|
time.Sleep(SessionLockPeekDelay)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return wasLocked, nil
|
||||||
|
}
|
||||||
|
|
||||||
// validateSession checks whether the session has expired and performs
|
// validateSession checks whether the session has expired and performs
|
||||||
// provider validation on the session.
|
// provider validation on the session.
|
||||||
// An error implies the session is not longer valid.
|
// An error implies the session is not longer valid.
|
||||||
|
Reference in New Issue
Block a user