diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go index 0fb1e645..d244a563 100644 --- a/pkg/middleware/stored_session.go +++ b/pkg/middleware/stored_session.go @@ -15,8 +15,20 @@ import ( ) const ( - SessionLockExpireTime = 5 * time.Second - SessionLockPeekDelay = 50 * time.Millisecond + // When attempting to obtain the lock, if it's not done before this timeout + // then exit and fail the refresh attempt. + // TODO: This should probably be configurable by the end user. + sessionRefreshObtainTimeout = 5 * time.Second + + // Maximum time allowed for a session refresh attempt. + // If the refresh request isn't finished within this time, the lock will be + // released. + // TODO: This should probably be configurable by the end user. + sessionRefreshLockDuration = 2 * time.Second + + // How long to wait after failing to obtain the lock before trying again. + // TODO: This should probably be configurable by the end user. + sessionRefreshRetryPeriod = 10 * time.Millisecond ) // StoredSessionLoaderOptions contains all of the requirements to construct @@ -113,47 +125,81 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h // is older than the refresh period. // Success or fail, we will then validate the session. func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { - if s.refreshPeriod <= time.Duration(0) || session.Age() < s.refreshPeriod { + if !needsRefresh(s.refreshPeriod, session) { // Refresh is disabled or the session is not old enough, do nothing return nil } - wasRefreshed, err := s.checkForConcurrentRefresh(session, req) - if err != nil { - return err + var lockObtained bool + ctx, cancel := context.WithTimeout(context.Background(), sessionRefreshObtainTimeout) + defer cancel() + + for !lockObtained { + select { + case <-ctx.Done(): + return errors.New("timeout obtaining session lock") + default: + err := session.ObtainLock(req.Context(), sessionRefreshLockDuration) + if err != nil && !errors.Is(err, sessionsapi.ErrLockNotObtained) { + return fmt.Errorf("error occurred while trying to obtain lock: %v", err) + } else if errors.Is(err, sessionsapi.ErrLockNotObtained) { + time.Sleep(sessionRefreshRetryPeriod) + continue + } + // No error means we obtained the lock + lockObtained = true + } } - // If session was already refreshed via a concurrent request locked skip refreshing, - // because the refreshed session is already loaded from storage - if !wasRefreshed { - logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age()) - err = s.refreshSession(rw, req, session) - if err != nil { - // If a preemptive refresh fails, we still keep the session - // if validateSession succeeds. - logger.Errorf("Unable to refresh session: %v", err) + // The rest of this function is carried out under lock, but we must release it + // wherever we exit from this function. + defer func() { + if session == nil { + return } + if err := session.ReleaseLock(req.Context()); err != nil { + logger.Errorf("unable to release lock: %v", err) + } + }() + + // Reload the session in case it was changed underneath us. + freshSession, err := s.store.Load(req) + if err != nil { + return fmt.Errorf("could not load session: %v", err) + } + if freshSession == nil { + return errors.New("session no longer exists, it may have been removed by another request") + } + // Restore the state of the fresh session into the original pointer. + // This is important so that changes are passed up the to the parent scope. + *session = *freshSession + + if !needsRefresh(s.refreshPeriod, session) { + // The session must have already been refreshed while we were waiting to + // obtain the lock. + return nil + } + + // We are holding the lock and the session needs a refresh + logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age()) + if err := s.refreshSession(rw, req, session); err != nil { + // If a preemptive refresh fails, we still keep the session + // if validateSession succeeds. + logger.Errorf("Unable to refresh session: %v", err) } // Validate all sessions after any Redeem/Refresh operation (fail or success) return s.validateSession(req.Context(), session) } +// needsRefresh determines whether we should attempt to refresh a session or not. +func needsRefresh(refreshPeriod time.Duration, session *sessionsapi.SessionState) bool { + return refreshPeriod > time.Duration(0) && session.Age() > refreshPeriod +} + // refreshSession attempts to refresh the session with the provider // and will save the session if it was updated. func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { - err := session.ObtainLock(req.Context(), SessionLockExpireTime) - if err != nil { - logger.Errorf("Unable to obtain lock: %v", err) - return s.handleObtainLockError(req, session) - } - defer func() { - err = session.ReleaseLock(req.Context()) - if err != nil { - logger.Errorf("unable to release lock: %v", err) - } - }() - refreshed, err := s.sessionRefresher(req.Context(), session) if err != nil && !errors.Is(err, providers.ErrNotImplemented) { return fmt.Errorf("error refreshing tokens: %v", err) @@ -182,75 +228,11 @@ func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.R err = s.store.Save(rw, req, session) if err != nil { logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err) - err = fmt.Errorf("error saving session: %v", err) - } - return err -} - -func (s *storedSessionLoader) handleObtainLockError(req *http.Request, session *sessionsapi.SessionState) error { - wasRefreshed, err := s.checkForConcurrentRefresh(session, req) - if err != nil { - logger.Errorf("Unable to wait for obtained lock: %v", err) - return err - } - if !wasRefreshed { - return errors.New("unable to obtain lock and session was also not refreshed via concurrent request") + return fmt.Errorf("error saving session: %v", err) } return nil } -func (s *storedSessionLoader) updateSessionFromStore(req *http.Request, session *sessionsapi.SessionState) error { - sessionStored, err := s.store.Load(req) - if err != nil { - return fmt.Errorf("unable to load updated session from store: %v", err) - } - - if sessionStored == nil { - return fmt.Errorf("no session available to udpate from store") - } - *session = *sessionStored - - return nil -} - -func (s *storedSessionLoader) waitForPossibleSessionLock(session *sessionsapi.SessionState, req *http.Request) (bool, error) { - var wasLocked bool - isLocked, err := session.PeekLock(req.Context()) - for isLocked { - wasLocked = true - // delay next peek lock - time.Sleep(SessionLockPeekDelay) - isLocked, err = session.PeekLock(req.Context()) - } - - if err != nil { - return false, err - } - - return wasLocked, nil -} - -// checkForConcurrentRefresh returns true if the session is already refreshed via a concurrent request. -func (s *storedSessionLoader) checkForConcurrentRefresh(session *sessionsapi.SessionState, req *http.Request) (bool, error) { - wasLocked, err := s.waitForPossibleSessionLock(session, req) - if err != nil { - return false, err - } - - refreshed := false - if wasLocked { - logger.Printf("Update session from store instead of refreshing") - err = s.updateSessionFromStore(req, session) - if err != nil { - logger.Errorf("Unable to update session from store: %v", err) - return false, err - } - refreshed = true - } - - return refreshed, nil -} - // validateSession checks whether the session has expired and performs // provider validation on the session. // An error implies the session is not longer valid. diff --git a/pkg/middleware/stored_session_test.go b/pkg/middleware/stored_session_test.go index 901ffba2..4891e631 100644 --- a/pkg/middleware/stored_session_test.go +++ b/pkg/middleware/stored_session_test.go @@ -18,97 +18,67 @@ import ( . "github.com/onsi/gomega" ) -type TestLock struct { - Locked bool - WasObtained bool - WasRefreshed bool - WasReleased bool - PeekedCount int - LockedOnPeekCount int - ObtainError error - PeekError error - RefreshError error - ReleaseError error +type testLock struct { + locked bool + obtainOnAttempt int + obtainAttempts int + obtainError error } -func (l *TestLock) Obtain(_ context.Context, _ time.Duration) error { - if l.ObtainError != nil { - return l.ObtainError - } - l.Locked = true - l.WasObtained = true - return nil -} - -func (l *TestLock) Peek(_ context.Context) (bool, error) { - if l.PeekError != nil { - return false, l.PeekError - } - locked := l.Locked - // l.Locked = false - l.PeekedCount++ - // mainly used to test case when peek initially returns false, - // but when trying to obtain lock, it returns true. - if l.LockedOnPeekCount == l.PeekedCount { - return true, nil - } - return locked, nil -} - -func (l *TestLock) Refresh(_ context.Context, _ time.Duration) error { - if l.RefreshError != nil { - return l.ReleaseError - } - l.WasRefreshed = true - return nil -} - -func (l *TestLock) Release(_ context.Context) error { - if l.ReleaseError != nil { - return l.ReleaseError - } - l.Locked = false - l.WasReleased = true - return nil -} - -type LockConc struct { - mu sync.Mutex - lock bool - disablePeek bool -} - -func (l *LockConc) Obtain(_ context.Context, _ time.Duration) error { - l.mu.Lock() - if l.lock { - l.mu.Unlock() +func (l *testLock) Obtain(_ context.Context, _ time.Duration) error { + l.obtainAttempts++ + if l.obtainAttempts < l.obtainOnAttempt { return sessionsapi.ErrLockNotObtained } - l.lock = true - l.mu.Unlock() - return nil -} - -func (l *LockConc) Peek(_ context.Context) (bool, error) { - var response bool - l.mu.Lock() - if l.disablePeek { - response = false - } else { - response = l.lock + if l.obtainError != nil { + return l.obtainError } - l.mu.Unlock() - return response, nil -} - -func (l *LockConc) Refresh(_ context.Context, _ time.Duration) error { + l.locked = true return nil } -func (l *LockConc) Release(_ context.Context) error { +func (l *testLock) Peek(_ context.Context) (bool, error) { + return l.locked, nil +} + +func (l *testLock) Refresh(_ context.Context, _ time.Duration) error { + return nil +} + +func (l *testLock) Release(_ context.Context) error { + l.locked = false + return nil +} + +type testLockConcurrent struct { + mu sync.RWMutex + locked bool +} + +func (l *testLockConcurrent) Obtain(_ context.Context, _ time.Duration) error { l.mu.Lock() - l.lock = false - l.mu.Unlock() + defer l.mu.Unlock() + if l.locked { + return sessionsapi.ErrLockNotObtained + } + l.locked = true + return nil +} + +func (l *testLockConcurrent) Peek(_ context.Context) (bool, error) { + l.mu.RLock() + defer l.mu.RUnlock() + return l.locked, nil +} + +func (l *testLockConcurrent) Refresh(_ context.Context, _ time.Duration) error { + return nil +} + +func (l *testLockConcurrent) Release(_ context.Context) error { + l.mu.Lock() + defer l.mu.Unlock() + l.locked = false return nil } @@ -374,22 +344,29 @@ var _ = Describe("Stored Session Suite", func() { DescribeTable("when serving concurrent requests", func(in storedSessionLoaderConcurrentTableInput) { - lockConc := &LockConc{} + lockConc := &testLockConcurrent{} + + lock := &sync.RWMutex{} + existingSession := *in.existingSession // deep copy existingSession state + existingSession.Lock = lockConc + store := &fakeSessionStore{ + LoadFunc: func(req *http.Request) (*sessionsapi.SessionState, error) { + lock.RLock() + defer lock.RUnlock() + session := existingSession + return &session, nil + }, + SaveFunc: func(_ http.ResponseWriter, _ *http.Request, session *sessionsapi.SessionState) error { + lock.Lock() + defer lock.Unlock() + existingSession = *session + return nil + }, + } refreshedChan := make(chan bool, in.numConcReqs) for i := 0; i < in.numConcReqs; i++ { go func(refreshedChan chan bool, lockConc sessionsapi.Lock) { - existingSession := *in.existingSession // deep copy existingSession state - existingSession.Lock = lockConc - store := &fakeSessionStore{ - LoadFunc: func(req *http.Request) (*sessionsapi.SessionState, error) { - return &existingSession, nil - }, - SaveFunc: func(http.ResponseWriter, *http.Request, *sessionsapi.SessionState) error { - return nil - }, - } - scope := &middlewareapi.RequestScope{ Session: nil, } @@ -461,13 +438,13 @@ var _ = Describe("Stored Session Suite", func() { Context("refreshSessionIfNeeded", func() { type refreshSessionIfNeededTableInput struct { - refreshPeriod time.Duration - sessionStored bool - session *sessionsapi.SessionState - expectedErr error - expectRefreshed bool - expectValidated bool - expectedLockState TestLock + refreshPeriod time.Duration + session *sessionsapi.SessionState + concurrentSessionRefresh bool + expectedErr error + expectRefreshed bool + expectValidated bool + expectedLockObtained bool } createdPast := time.Now().Add(-5 * time.Minute) @@ -478,13 +455,21 @@ var _ = Describe("Stored Session Suite", func() { refreshed := false validated := false - store := &fakeSessionStore{} - if in.sessionStored { - store = &fakeSessionStore{ - LoadFunc: func(req *http.Request) (*sessionsapi.SessionState, error) { - return in.session, nil - }, - } + session := &sessionsapi.SessionState{} + *session = *in.session + if in.concurrentSessionRefresh { + // Update the session that Load returns. + // This simulates a concurrent refresh in the background. + session.CreatedAt = &createdFuture + } + store := &fakeSessionStore{ + LoadFunc: func(req *http.Request) (*sessionsapi.SessionState, error) { + return session, nil + }, + SaveFunc: func(_ http.ResponseWriter, _ *http.Request, s *sessionsapi.SessionState) error { + *session = *s + return nil + }, } s := &storedSessionLoader{ @@ -518,135 +503,91 @@ var _ = Describe("Stored Session Suite", func() { } Expect(refreshed).To(Equal(in.expectRefreshed)) Expect(validated).To(Equal(in.expectValidated)) - testLock, ok := in.session.Lock.(*TestLock) + testLock, ok := in.session.Lock.(*testLock) Expect(ok).To(Equal(true)) - Expect(testLock).To(Equal(&in.expectedLockState)) + if in.expectedLockObtained { + Expect(testLock.obtainAttempts).Should(BeNumerically(">", 0), "Expected at least one attempt at obtaining the session lock") + } + Expect(testLock.locked).To(BeFalse(), "Expected lock should always be released") + // Expect(testLock).To(Equal(&in.expectedLockState)) }, Entry("when the refresh period is 0, and the session does not need refreshing", refreshSessionIfNeededTableInput{ refreshPeriod: time.Duration(0), session: &sessionsapi.SessionState{ RefreshToken: refresh, CreatedAt: &createdFuture, - Lock: &TestLock{}, + Lock: &testLock{}, }, - expectedErr: nil, - expectRefreshed: false, - expectValidated: false, - expectedLockState: TestLock{}, + expectedErr: nil, + expectRefreshed: false, + expectValidated: false, + expectedLockObtained: false, }), Entry("when the refresh period is 0, and the session needs refreshing", refreshSessionIfNeededTableInput{ refreshPeriod: time.Duration(0), session: &sessionsapi.SessionState{ RefreshToken: refresh, CreatedAt: &createdPast, - Lock: &TestLock{}, + Lock: &testLock{}, }, - expectedErr: nil, - expectRefreshed: false, - expectValidated: false, - expectedLockState: TestLock{}, + expectedErr: nil, + expectRefreshed: false, + expectValidated: false, + expectedLockObtained: false, }), Entry("when the session does not need refreshing", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, session: &sessionsapi.SessionState{ RefreshToken: refresh, CreatedAt: &createdFuture, - Lock: &TestLock{}, + Lock: &testLock{}, }, - expectedErr: nil, - expectRefreshed: false, - expectValidated: false, - expectedLockState: TestLock{}, + expectedErr: nil, + expectRefreshed: false, + expectValidated: false, + expectedLockObtained: false, }), Entry("when the session is refreshed by the provider", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, session: &sessionsapi.SessionState{ RefreshToken: refresh, CreatedAt: &createdPast, - Lock: &TestLock{}, - }, - expectedErr: nil, - expectRefreshed: true, - expectValidated: true, - expectedLockState: TestLock{ - Locked: false, - WasObtained: true, - WasReleased: true, - PeekedCount: 1, - }, - }), - PEntry("when the session is locked and instead loaded from storage", refreshSessionIfNeededTableInput{ - refreshPeriod: 1 * time.Minute, - session: &sessionsapi.SessionState{ - RefreshToken: noRefresh, - CreatedAt: &createdPast, - Lock: &TestLock{ - Locked: true, - }, - }, - sessionStored: true, - expectedErr: nil, - expectRefreshed: false, - expectValidated: true, - expectedLockState: TestLock{ - Locked: false, - PeekedCount: 2, + Lock: &testLock{}, }, + expectedErr: nil, + expectRefreshed: true, + expectValidated: true, + expectedLockObtained: true, }), Entry("when obtaining lock failed, but concurrent request refreshed", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, session: &sessionsapi.SessionState{ RefreshToken: noRefresh, CreatedAt: &createdPast, - Lock: &TestLock{ - ObtainError: errors.New("not able to obtain lock"), - LockedOnPeekCount: 2, + Lock: &testLock{ + obtainOnAttempt: 4, }, }, - expectedErr: nil, - expectRefreshed: false, - expectValidated: true, - expectedLockState: TestLock{ - PeekedCount: 3, - LockedOnPeekCount: 2, - ObtainError: errors.New("not able to obtain lock"), - }, + concurrentSessionRefresh: true, + expectedErr: nil, + expectRefreshed: false, + expectValidated: false, + expectedLockObtained: true, }), - Entry("when obtaining lock failed with a valid session", refreshSessionIfNeededTableInput{ + Entry("when obtaining lock failed", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, session: &sessionsapi.SessionState{ RefreshToken: noRefresh, CreatedAt: &createdPast, - Lock: &TestLock{ - ObtainError: errors.New("not able to obtain lock"), + Lock: &testLock{ + obtainError: sessionsapi.ErrLockNotObtained, }, }, - expectedErr: nil, - expectRefreshed: false, - expectValidated: true, - expectedLockState: TestLock{ - PeekedCount: 2, - ObtainError: errors.New("not able to obtain lock"), - }, - }), - Entry("when obtaining lock failed with an invalid session", refreshSessionIfNeededTableInput{ - refreshPeriod: 1 * time.Minute, - session: &sessionsapi.SessionState{ - RefreshToken: noRefresh, - CreatedAt: &createdPast, - ExpiresOn: &createdPast, - Lock: &TestLock{ - ObtainError: errors.New("not able to obtain lock"), - }, - }, - expectedErr: nil, - expectRefreshed: true, - expectValidated: false, - expectedLockState: TestLock{ - PeekedCount: 2, - ObtainError: errors.New("not able to obtain lock"), - }, + expectedErr: errors.New("timeout obtaining session lock"), + expectRefreshed: false, + expectValidated: false, + expectedLockObtained: true, }), Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, @@ -654,34 +595,24 @@ var _ = Describe("Stored Session Suite", func() { RefreshToken: noRefresh, CreatedAt: &createdPast, ExpiresOn: &createdFuture, - Lock: &TestLock{}, - }, - expectedErr: nil, - expectRefreshed: true, - expectValidated: true, - expectedLockState: TestLock{ - Locked: false, - WasObtained: true, - WasReleased: true, - PeekedCount: 1, + Lock: &testLock{}, }, + expectedErr: nil, + expectRefreshed: true, + expectValidated: true, + expectedLockObtained: true, }), Entry("when the provider doesn't implement refresh", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, session: &sessionsapi.SessionState{ RefreshToken: notImplemented, CreatedAt: &createdPast, - Lock: &TestLock{}, - }, - expectedErr: nil, - expectRefreshed: true, - expectValidated: true, - expectedLockState: TestLock{ - Locked: false, - WasObtained: true, - WasReleased: true, - PeekedCount: 1, + Lock: &testLock{}, }, + expectedErr: nil, + expectRefreshed: true, + expectValidated: true, + expectedLockObtained: true, }), Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, @@ -690,17 +621,12 @@ var _ = Describe("Stored Session Suite", func() { RefreshToken: noRefresh, CreatedAt: &createdPast, ExpiresOn: &createdFuture, - Lock: &TestLock{}, - }, - expectedErr: errors.New("session is invalid"), - expectRefreshed: true, - expectValidated: true, - expectedLockState: TestLock{ - Locked: false, - WasObtained: true, - WasReleased: true, - PeekedCount: 1, + Lock: &testLock{}, }, + expectedErr: errors.New("session is invalid"), + expectRefreshed: true, + expectValidated: true, + expectedLockObtained: true, }), ) })