diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go index d018cae3..bdc4b4cf 100644 --- a/pkg/middleware/stored_session.go +++ b/pkg/middleware/stored_session.go @@ -109,6 +109,12 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h return nil, fmt.Errorf("error refreshing access token for session (%s): %v", session, err) } + // Validate all sessions after any Redeem/Refresh operation (fail or success) + err = s.validateSession(req.Context(), session) + if err != nil { + return nil, err + } + return session, nil } @@ -121,36 +127,35 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req return nil } - var wasLocked bool - var err error - var isLocked bool - for isLocked, err = session.PeekLock(req.Context()); isLocked; isLocked, err = session.PeekLock(req.Context()) { - wasLocked = true - // delay next peek lock - time.Sleep(SessionLockPeekDelay) - } - + wasLocked, err := s.waitForPossibleSessionLock(session, req) if err != nil { return err } - // If session was locked fetch current state + // If session was locked, fetch current state, because + // it should be updated after lock is released. if wasLocked { - var sessionStored *sessionsapi.SessionState - sessionStored, err = s.store.Load(req) + err = s.updateSessionFromStore(req, session) if err != nil { - return err + logger.Errorf("Unable to load updated session from store: %v", err) } - - if session == nil || sessionStored == nil { - return nil - } - *session = *sessionStored - - return nil + return err } - err = session.ObtainLock(req.Context(), SessionLockExpireTime) + logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age()) + err = s.refreshSession(rw, req, session) + if err != nil { + // If a preemptive refresh fails, we still keep the session + // if validateSession succeeds. + logger.Errorf("Unable to refresh session: %v", err) + } + return err +} + +// refreshSession attempts to refresh the session with the provider +// and will save the session if it was updated. +func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { + err := session.ObtainLock(req.Context(), SessionLockExpireTime) if err != nil { logger.Errorf("unable to obtain lock (skipping refresh): %v", err) return nil @@ -162,21 +167,6 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req } }() - logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age()) - err = s.refreshSession(rw, req, session) - if err != nil { - // If a preemptive refresh fails, we still keep the session - // if validateSession succeeds. - logger.Errorf("Unable to refresh session: %v", err) - } - - // Validate all sessions after any Redeem/Refresh operation (fail or success) - return s.validateSession(req.Context(), session) -} - -// refreshSession attempts to refresh the session with the provider -// and will save the session if it was updated. -func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { refreshed, err := s.sessionRefresher(req.Context(), session) if err != nil && !errors.Is(err, providers.ErrNotImplemented) { return fmt.Errorf("error refreshing tokens: %v", err) @@ -205,11 +195,42 @@ func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.R err = s.store.Save(rw, req, session) if err != nil { logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err) - return fmt.Errorf("error saving session: %v", err) + err = fmt.Errorf("error saving session: %v", err) } + return err +} + +func (s *storedSessionLoader) updateSessionFromStore(req *http.Request, session *sessionsapi.SessionState) error { + sessionStored, err := s.store.Load(req) + if err != nil { + return err + } + + if session == nil || sessionStored == nil { + return nil + } + *session = *sessionStored + return nil } +func (s *storedSessionLoader) waitForPossibleSessionLock(session *sessionsapi.SessionState, req *http.Request) (bool, error) { + var wasLocked bool + var err error + var isLocked bool + for isLocked, err = session.PeekLock(req.Context()); isLocked; isLocked, err = session.PeekLock(req.Context()) { + wasLocked = true + // delay next peek lock + time.Sleep(SessionLockPeekDelay) + } + + if err != nil { + return false, err + } + + return wasLocked, nil +} + // validateSession checks whether the session has expired and performs // provider validation on the session. // An error implies the session is not longer valid.