From e2c7ff6ddd874faef4e0644f0709ce39d7a5440f Mon Sep 17 00:00:00 2001 From: Kevin Kreitner Date: Tue, 19 Jan 2021 17:28:58 +0100 Subject: [PATCH 1/3] Use session to lock to protect concurrent refreshes --- pkg/middleware/stored_session.go | 109 ++++++++- pkg/middleware/stored_session_test.go | 338 ++++++++++++++++++++++++-- 2 files changed, 409 insertions(+), 38 deletions(-) diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go index 4cbc47eb..0fb1e645 100644 --- a/pkg/middleware/stored_session.go +++ b/pkg/middleware/stored_session.go @@ -14,6 +14,11 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/v7/providers" ) +const ( + SessionLockExpireTime = 5 * time.Second + SessionLockPeekDelay = 50 * time.Millisecond +) + // StoredSessionLoaderOptions contains all of the requirements to construct // a stored session loader. // All options must be provided. @@ -91,13 +96,10 @@ func (s *storedSessionLoader) loadSession(next http.Handler) http.Handler { // that is is valid. func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *http.Request) (*sessionsapi.SessionState, error) { session, err := s.store.Load(req) - if err != nil { + if err != nil || session == nil { + // No session was found in the storage or error occurred, nothing more to do return nil, err } - if session == nil { - // No session was found in the storage, nothing more to do - return nil, nil - } err = s.refreshSessionIfNeeded(rw, req, session) if err != nil { @@ -116,12 +118,21 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req return nil } - logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age()) - err := s.refreshSession(rw, req, session) + wasRefreshed, err := s.checkForConcurrentRefresh(session, req) if err != nil { - // If a preemptive refresh fails, we still keep the session - // if validateSession succeeds. - logger.Errorf("Unable to refresh session: %v", err) + return err + } + + // If session was already refreshed via a concurrent request locked skip refreshing, + // because the refreshed session is already loaded from storage + if !wasRefreshed { + logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age()) + err = s.refreshSession(rw, req, session) + if err != nil { + // If a preemptive refresh fails, we still keep the session + // if validateSession succeeds. + logger.Errorf("Unable to refresh session: %v", err) + } } // Validate all sessions after any Redeem/Refresh operation (fail or success) @@ -131,6 +142,18 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req // refreshSession attempts to refresh the session with the provider // and will save the session if it was updated. func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { + err := session.ObtainLock(req.Context(), SessionLockExpireTime) + if err != nil { + logger.Errorf("Unable to obtain lock: %v", err) + return s.handleObtainLockError(req, session) + } + defer func() { + err = session.ReleaseLock(req.Context()) + if err != nil { + logger.Errorf("unable to release lock: %v", err) + } + }() + refreshed, err := s.sessionRefresher(req.Context(), session) if err != nil && !errors.Is(err, providers.ErrNotImplemented) { return fmt.Errorf("error refreshing tokens: %v", err) @@ -159,11 +182,75 @@ func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.R err = s.store.Save(rw, req, session) if err != nil { logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err) - return fmt.Errorf("error saving session: %v", err) + err = fmt.Errorf("error saving session: %v", err) + } + return err +} + +func (s *storedSessionLoader) handleObtainLockError(req *http.Request, session *sessionsapi.SessionState) error { + wasRefreshed, err := s.checkForConcurrentRefresh(session, req) + if err != nil { + logger.Errorf("Unable to wait for obtained lock: %v", err) + return err + } + if !wasRefreshed { + return errors.New("unable to obtain lock and session was also not refreshed via concurrent request") } return nil } +func (s *storedSessionLoader) updateSessionFromStore(req *http.Request, session *sessionsapi.SessionState) error { + sessionStored, err := s.store.Load(req) + if err != nil { + return fmt.Errorf("unable to load updated session from store: %v", err) + } + + if sessionStored == nil { + return fmt.Errorf("no session available to udpate from store") + } + *session = *sessionStored + + return nil +} + +func (s *storedSessionLoader) waitForPossibleSessionLock(session *sessionsapi.SessionState, req *http.Request) (bool, error) { + var wasLocked bool + isLocked, err := session.PeekLock(req.Context()) + for isLocked { + wasLocked = true + // delay next peek lock + time.Sleep(SessionLockPeekDelay) + isLocked, err = session.PeekLock(req.Context()) + } + + if err != nil { + return false, err + } + + return wasLocked, nil +} + +// checkForConcurrentRefresh returns true if the session is already refreshed via a concurrent request. +func (s *storedSessionLoader) checkForConcurrentRefresh(session *sessionsapi.SessionState, req *http.Request) (bool, error) { + wasLocked, err := s.waitForPossibleSessionLock(session, req) + if err != nil { + return false, err + } + + refreshed := false + if wasLocked { + logger.Printf("Update session from store instead of refreshing") + err = s.updateSessionFromStore(req, session) + if err != nil { + logger.Errorf("Unable to update session from store: %v", err) + return false, err + } + refreshed = true + } + + return refreshed, nil +} + // validateSession checks whether the session has expired and performs // provider validation on the session. // An error implies the session is not longer valid. diff --git a/pkg/middleware/stored_session_test.go b/pkg/middleware/stored_session_test.go index 782390b6..b1f2ef70 100644 --- a/pkg/middleware/stored_session_test.go +++ b/pkg/middleware/stored_session_test.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "sync" "time" middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" @@ -17,9 +18,104 @@ import ( . "github.com/onsi/gomega" ) +type TestLock struct { + Locked bool + WasObtained bool + WasRefreshed bool + WasReleased bool + PeekedCount int + LockedOnPeekCount int + ObtainError error + PeekError error + RefreshError error + ReleaseError error +} + +func (l *TestLock) Obtain(_ context.Context, _ time.Duration) error { + if l.ObtainError != nil { + return l.ObtainError + } + l.Locked = true + l.WasObtained = true + return nil +} + +func (l *TestLock) Peek(_ context.Context) (bool, error) { + if l.PeekError != nil { + return false, l.PeekError + } + locked := l.Locked + l.Locked = false + l.PeekedCount++ + // mainly used to test case when peek initially returns false, + // but when trying to obtain lock, it returns true. + if l.LockedOnPeekCount == l.PeekedCount { + return true, nil + } + return locked, nil +} + +func (l *TestLock) Refresh(_ context.Context, _ time.Duration) error { + if l.RefreshError != nil { + return l.ReleaseError + } + l.WasRefreshed = true + return nil +} + +func (l *TestLock) Release(_ context.Context) error { + if l.ReleaseError != nil { + return l.ReleaseError + } + l.Locked = false + l.WasReleased = true + return nil +} + +type LockConc struct { + mu sync.Mutex + lock bool + disablePeek bool +} + +func (l *LockConc) Obtain(_ context.Context, _ time.Duration) error { + l.mu.Lock() + if l.lock { + l.mu.Unlock() + return sessionsapi.ErrLockNotObtained + } + l.lock = true + l.mu.Unlock() + return nil +} + +func (l *LockConc) Peek(_ context.Context) (bool, error) { + var response bool + l.mu.Lock() + if l.disablePeek { + response = false + } else { + response = l.lock + } + l.mu.Unlock() + return response, nil +} + +func (l *LockConc) Refresh(_ context.Context, _ time.Duration) error { + return nil +} + +func (l *LockConc) Release(_ context.Context) error { + l.mu.Lock() + l.lock = false + l.mu.Unlock() + return nil +} + var _ = Describe("Stored Session Suite", func() { const ( refresh = "Refresh" + refreshed = "Refreshed" noRefresh = "NoRefresh" notImplemented = "NotImplemented" ) @@ -34,7 +130,7 @@ var _ = Describe("Stored Session Suite", func() { var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { switch ss.RefreshToken { case refresh: - ss.RefreshToken = "Refreshed" + ss.RefreshToken = refreshed return true, nil case noRefresh: return false, nil @@ -181,6 +277,7 @@ var _ = Describe("Stored Session Suite", func() { RefreshToken: noRefresh, CreatedAt: &createdPast, ExpiresOn: &createdFuture, + Lock: &sessionsapi.NoOpLock{}, }, store: defaultSessionStore, refreshPeriod: 1 * time.Minute, @@ -222,6 +319,7 @@ var _ = Describe("Stored Session Suite", func() { RefreshToken: "Refreshed", CreatedAt: &now, ExpiresOn: &createdFuture, + Lock: &sessionsapi.NoOpLock{}, }, store: defaultSessionStore, refreshPeriod: 1 * time.Minute, @@ -237,6 +335,7 @@ var _ = Describe("Stored Session Suite", func() { RefreshToken: "RefreshError", CreatedAt: &createdPast, ExpiresOn: &createdFuture, + Lock: &sessionsapi.NoOpLock{}, }, store: defaultSessionStore, refreshPeriod: 1 * time.Minute, @@ -266,15 +365,109 @@ var _ = Describe("Stored Session Suite", func() { validateSession: defaultValidateFunc, }), ) + + type storedSessionLoaderConcurrentTableInput struct { + existingSession *sessionsapi.SessionState + refreshPeriod time.Duration + numConcReqs int + } + + DescribeTable("when serving concurrent requests", + func(in storedSessionLoaderConcurrentTableInput) { + lockConc := &LockConc{} + + refreshedChan := make(chan bool, in.numConcReqs) + for i := 0; i < in.numConcReqs; i++ { + go func(refreshedChan chan bool, lockConc sessionsapi.Lock) { + existingSession := *in.existingSession // deep copy existingSession state + existingSession.Lock = lockConc + store := &fakeSessionStore{ + LoadFunc: func(req *http.Request) (*sessionsapi.SessionState, error) { + return &existingSession, nil + }, + SaveFunc: func(http.ResponseWriter, *http.Request, *sessionsapi.SessionState) error { + return nil + }, + } + + scope := &middlewareapi.RequestScope{ + Session: nil, + } + + // Set up the request with the request header and a request scope + req := httptest.NewRequest("", "/", nil) + req = middlewareapi.AddRequestScope(req, scope) + + rw := httptest.NewRecorder() + + sessionRefreshed := false + opts := &StoredSessionLoaderOptions{ + SessionStore: store, + RefreshPeriod: in.refreshPeriod, + RefreshSession: func(ctx context.Context, s *sessionsapi.SessionState) (bool, error) { + time.Sleep(10 * time.Millisecond) + sessionRefreshed = true + return true, nil + }, + ValidateSession: func(context.Context, *sessionsapi.SessionState) bool { + return true + }, + } + + handler := NewStoredSessionLoader(opts)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + handler.ServeHTTP(rw, req) + + refreshedChan <- sessionRefreshed + }(refreshedChan, lockConc) + } + var refreshedSlice []bool + for i := 0; i < in.numConcReqs; i++ { + refreshedSlice = append(refreshedSlice, <-refreshedChan) + } + sessionRefreshedCount := 0 + for _, sessionRefreshed := range refreshedSlice { + if sessionRefreshed { + sessionRefreshedCount++ + } + } + Expect(sessionRefreshedCount).To(Equal(1)) + }, + Entry("with two concurrent requests", storedSessionLoaderConcurrentTableInput{ + existingSession: &sessionsapi.SessionState{ + RefreshToken: refresh, + CreatedAt: &createdPast, + }, + numConcReqs: 2, + refreshPeriod: 1 * time.Minute, + }), + Entry("with 5 concurrent requests", storedSessionLoaderConcurrentTableInput{ + existingSession: &sessionsapi.SessionState{ + RefreshToken: refresh, + CreatedAt: &createdPast, + }, + numConcReqs: 5, + refreshPeriod: 1 * time.Minute, + }), + Entry("with one request", storedSessionLoaderConcurrentTableInput{ + existingSession: &sessionsapi.SessionState{ + RefreshToken: refresh, + CreatedAt: &createdPast, + }, + numConcReqs: 1, + refreshPeriod: 1 * time.Minute, + }), + ) }) Context("refreshSessionIfNeeded", func() { type refreshSessionIfNeededTableInput struct { - refreshPeriod time.Duration - session *sessionsapi.SessionState - expectedErr error - expectRefreshed bool - expectValidated bool + refreshPeriod time.Duration + sessionStored bool + session *sessionsapi.SessionState + expectedErr error + expectRefreshed bool + expectValidated bool + expectedLockState TestLock } createdPast := time.Now().Add(-5 * time.Minute) @@ -285,9 +478,18 @@ var _ = Describe("Stored Session Suite", func() { refreshed := false validated := false + store := &fakeSessionStore{} + if in.sessionStored { + store = &fakeSessionStore{ + LoadFunc: func(req *http.Request) (*sessionsapi.SessionState, error) { + return in.session, nil + }, + } + } + s := &storedSessionLoader{ refreshPeriod: in.refreshPeriod, - store: &fakeSessionStore{}, + store: store, sessionRefresher: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { refreshed = true switch ss.RefreshToken { @@ -316,46 +518,117 @@ var _ = Describe("Stored Session Suite", func() { } Expect(refreshed).To(Equal(in.expectRefreshed)) Expect(validated).To(Equal(in.expectValidated)) + testLock, ok := in.session.Lock.(*TestLock) + Expect(ok).To(Equal(true)) + + Expect(testLock).To(Equal(&in.expectedLockState)) }, Entry("when the refresh period is 0, and the session does not need refreshing", refreshSessionIfNeededTableInput{ refreshPeriod: time.Duration(0), session: &sessionsapi.SessionState{ RefreshToken: refresh, CreatedAt: &createdFuture, + Lock: &TestLock{}, }, - expectedErr: nil, - expectRefreshed: false, - expectValidated: false, + expectedErr: nil, + expectRefreshed: false, + expectValidated: false, + expectedLockState: TestLock{}, }), Entry("when the refresh period is 0, and the session needs refreshing", refreshSessionIfNeededTableInput{ refreshPeriod: time.Duration(0), session: &sessionsapi.SessionState{ RefreshToken: refresh, CreatedAt: &createdPast, + Lock: &TestLock{}, }, - expectedErr: nil, - expectRefreshed: false, - expectValidated: false, + expectedErr: nil, + expectRefreshed: false, + expectValidated: false, + expectedLockState: TestLock{}, }), Entry("when the session does not need refreshing", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, session: &sessionsapi.SessionState{ RefreshToken: refresh, CreatedAt: &createdFuture, + Lock: &TestLock{}, }, - expectedErr: nil, - expectRefreshed: false, - expectValidated: false, + expectedErr: nil, + expectRefreshed: false, + expectValidated: false, + expectedLockState: TestLock{}, }), Entry("when the session is refreshed by the provider", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, session: &sessionsapi.SessionState{ RefreshToken: refresh, CreatedAt: &createdPast, + Lock: &TestLock{}, }, expectedErr: nil, expectRefreshed: true, expectValidated: true, + expectedLockState: TestLock{ + Locked: false, + WasObtained: true, + WasReleased: true, + PeekedCount: 1, + }, + }), + Entry("when the session is locked and instead loaded from storage", refreshSessionIfNeededTableInput{ + refreshPeriod: 1 * time.Minute, + session: &sessionsapi.SessionState{ + RefreshToken: noRefresh, + CreatedAt: &createdPast, + Lock: &TestLock{ + Locked: true, + }, + }, + sessionStored: true, + expectedErr: nil, + expectRefreshed: false, + expectValidated: true, + expectedLockState: TestLock{ + Locked: false, + PeekedCount: 2, + }, + }), + Entry("when obtaining lock failed, but concurrent request refreshed", refreshSessionIfNeededTableInput{ + refreshPeriod: 1 * time.Minute, + session: &sessionsapi.SessionState{ + RefreshToken: noRefresh, + CreatedAt: &createdPast, + Lock: &TestLock{ + ObtainError: errors.New("not able to obtain lock"), + LockedOnPeekCount: 2, + }, + }, + expectedErr: nil, + expectRefreshed: false, + expectValidated: true, + expectedLockState: TestLock{ + PeekedCount: 3, + LockedOnPeekCount: 2, + ObtainError: errors.New("not able to obtain lock"), + }, + }), + Entry("when obtaining lock failed", refreshSessionIfNeededTableInput{ + refreshPeriod: 1 * time.Minute, + session: &sessionsapi.SessionState{ + RefreshToken: noRefresh, + CreatedAt: &createdPast, + Lock: &TestLock{ + ObtainError: errors.New("not able to obtain lock"), + }, + }, + expectedErr: nil, + expectRefreshed: false, + expectValidated: true, + expectedLockState: TestLock{ + PeekedCount: 2, + ObtainError: errors.New("not able to obtain lock"), + }, }), Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, @@ -363,42 +636,53 @@ var _ = Describe("Stored Session Suite", func() { RefreshToken: noRefresh, CreatedAt: &createdPast, ExpiresOn: &createdFuture, + Lock: &TestLock{}, }, expectedErr: nil, expectRefreshed: true, expectValidated: true, + expectedLockState: TestLock{ + Locked: false, + WasObtained: true, + WasReleased: true, + PeekedCount: 1, + }, }), - Entry("when the provider doesn't implement refresh but validation succeeds", refreshSessionIfNeededTableInput{ + Entry("when the provider doesn't implement refresh", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, session: &sessionsapi.SessionState{ RefreshToken: notImplemented, CreatedAt: &createdPast, + Lock: &TestLock{}, }, expectedErr: nil, expectRefreshed: true, expectValidated: true, - }), - Entry("when the provider refresh fails but validation succeeds", refreshSessionIfNeededTableInput{ - refreshPeriod: 1 * time.Minute, - session: &sessionsapi.SessionState{ - RefreshToken: "RefreshError", - CreatedAt: &createdPast, + expectedLockState: TestLock{ + Locked: false, + WasObtained: true, + WasReleased: true, + PeekedCount: 1, }, - expectedErr: nil, - expectRefreshed: true, - expectValidated: true, }), - Entry("when the session is not refreshed by the provider and validation fails", refreshSessionIfNeededTableInput{ + Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, session: &sessionsapi.SessionState{ AccessToken: "Invalid", RefreshToken: noRefresh, CreatedAt: &createdPast, ExpiresOn: &createdFuture, + Lock: &TestLock{}, }, expectedErr: errors.New("session is invalid"), expectRefreshed: true, expectValidated: true, + expectedLockState: TestLock{ + Locked: false, + WasObtained: true, + WasReleased: true, + PeekedCount: 1, + }, }), ) }) From 54d42c58298aa2d041b51a67d3a50617bb7e7e7c Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Sun, 5 Dec 2021 23:56:08 +0000 Subject: [PATCH 2/3] Implement refresh relying on obtaining lock --- pkg/middleware/stored_session.go | 171 ++++++------- pkg/middleware/stored_session_test.go | 353 +++++++++++--------------- 2 files changed, 228 insertions(+), 296 deletions(-) diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go index 0fb1e645..1afe6d0c 100644 --- a/pkg/middleware/stored_session.go +++ b/pkg/middleware/stored_session.go @@ -15,8 +15,20 @@ import ( ) const ( - SessionLockExpireTime = 5 * time.Second - SessionLockPeekDelay = 50 * time.Millisecond + // When attempting to obtain the lock, if it's not done before this timeout + // then exit and fail the refresh attempt. + // TODO: This should probably be configurable by the end user. + sessionRefreshObtainTimeout = 5 * time.Second + + // Maximum time allowed for a session refresh attempt. + // If the refresh request isn't finished within this time, the lock will be + // released. + // TODO: This should probably be configurable by the end user. + sessionRefreshLockDuration = 2 * time.Second + + // How long to wait after failing to obtain the lock before trying again. + // TODO: This should probably be configurable by the end user. + sessionRefreshRetryPeriod = 10 * time.Millisecond ) // StoredSessionLoaderOptions contains all of the requirements to construct @@ -113,47 +125,86 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h // is older than the refresh period. // Success or fail, we will then validate the session. func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { - if s.refreshPeriod <= time.Duration(0) || session.Age() < s.refreshPeriod { + if !needsRefresh(s.refreshPeriod, session) { // Refresh is disabled or the session is not old enough, do nothing return nil } - wasRefreshed, err := s.checkForConcurrentRefresh(session, req) - if err != nil { - return err + var lockObtained bool + ctx, cancel := context.WithTimeout(context.Background(), sessionRefreshObtainTimeout) + defer cancel() + + for !lockObtained { + select { + case <-ctx.Done(): + return errors.New("timeout obtaining session lock") + default: + err := session.ObtainLock(req.Context(), sessionRefreshLockDuration) + if err != nil && !errors.Is(err, sessionsapi.ErrLockNotObtained) { + return fmt.Errorf("error occurred while trying to obtain lock: %v", err) + } else if errors.Is(err, sessionsapi.ErrLockNotObtained) { + time.Sleep(sessionRefreshRetryPeriod) + continue + } + // No error means we obtained the lock + lockObtained = true + } } - // If session was already refreshed via a concurrent request locked skip refreshing, - // because the refreshed session is already loaded from storage - if !wasRefreshed { - logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age()) - err = s.refreshSession(rw, req, session) - if err != nil { - // If a preemptive refresh fails, we still keep the session - // if validateSession succeeds. - logger.Errorf("Unable to refresh session: %v", err) + // The rest of this function is carried out under lock, but we must release it + // wherever we exit from this function. + defer func() { + if session == nil { + return } + if err := session.ReleaseLock(req.Context()); err != nil { + logger.Errorf("unable to release lock: %v", err) + } + }() + + // Reload the session in case it was changed underneath us. + freshSession, err := s.store.Load(req) + if err != nil { + return fmt.Errorf("could not load session: %v", err) + } + if freshSession == nil { + return errors.New("session no longer exists, it may have been removed by another request") + } + // Restore the state of the fresh session into the original pointer. + // This is important so that changes are passed up the to the parent scope. + lock := session.Lock + *session = *freshSession + + // Ensure we maintain the session lock after we have refreshed the session. + // Loading from the session store creates a new lock in the session. + session.Lock = lock + + if !needsRefresh(s.refreshPeriod, session) { + // The session must have already been refreshed while we were waiting to + // obtain the lock. + return nil + } + + // We are holding the lock and the session needs a refresh + logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age()) + if err := s.refreshSession(rw, req, session); err != nil { + // If a preemptive refresh fails, we still keep the session + // if validateSession succeeds. + logger.Errorf("Unable to refresh session: %v", err) } // Validate all sessions after any Redeem/Refresh operation (fail or success) return s.validateSession(req.Context(), session) } +// needsRefresh determines whether we should attempt to refresh a session or not. +func needsRefresh(refreshPeriod time.Duration, session *sessionsapi.SessionState) bool { + return refreshPeriod > time.Duration(0) && session.Age() > refreshPeriod +} + // refreshSession attempts to refresh the session with the provider // and will save the session if it was updated. func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { - err := session.ObtainLock(req.Context(), SessionLockExpireTime) - if err != nil { - logger.Errorf("Unable to obtain lock: %v", err) - return s.handleObtainLockError(req, session) - } - defer func() { - err = session.ReleaseLock(req.Context()) - if err != nil { - logger.Errorf("unable to release lock: %v", err) - } - }() - refreshed, err := s.sessionRefresher(req.Context(), session) if err != nil && !errors.Is(err, providers.ErrNotImplemented) { return fmt.Errorf("error refreshing tokens: %v", err) @@ -182,75 +233,11 @@ func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.R err = s.store.Save(rw, req, session) if err != nil { logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err) - err = fmt.Errorf("error saving session: %v", err) - } - return err -} - -func (s *storedSessionLoader) handleObtainLockError(req *http.Request, session *sessionsapi.SessionState) error { - wasRefreshed, err := s.checkForConcurrentRefresh(session, req) - if err != nil { - logger.Errorf("Unable to wait for obtained lock: %v", err) - return err - } - if !wasRefreshed { - return errors.New("unable to obtain lock and session was also not refreshed via concurrent request") + return fmt.Errorf("error saving session: %v", err) } return nil } -func (s *storedSessionLoader) updateSessionFromStore(req *http.Request, session *sessionsapi.SessionState) error { - sessionStored, err := s.store.Load(req) - if err != nil { - return fmt.Errorf("unable to load updated session from store: %v", err) - } - - if sessionStored == nil { - return fmt.Errorf("no session available to udpate from store") - } - *session = *sessionStored - - return nil -} - -func (s *storedSessionLoader) waitForPossibleSessionLock(session *sessionsapi.SessionState, req *http.Request) (bool, error) { - var wasLocked bool - isLocked, err := session.PeekLock(req.Context()) - for isLocked { - wasLocked = true - // delay next peek lock - time.Sleep(SessionLockPeekDelay) - isLocked, err = session.PeekLock(req.Context()) - } - - if err != nil { - return false, err - } - - return wasLocked, nil -} - -// checkForConcurrentRefresh returns true if the session is already refreshed via a concurrent request. -func (s *storedSessionLoader) checkForConcurrentRefresh(session *sessionsapi.SessionState, req *http.Request) (bool, error) { - wasLocked, err := s.waitForPossibleSessionLock(session, req) - if err != nil { - return false, err - } - - refreshed := false - if wasLocked { - logger.Printf("Update session from store instead of refreshing") - err = s.updateSessionFromStore(req, session) - if err != nil { - logger.Errorf("Unable to update session from store: %v", err) - return false, err - } - refreshed = true - } - - return refreshed, nil -} - // validateSession checks whether the session has expired and performs // provider validation on the session. // An error implies the session is not longer valid. diff --git a/pkg/middleware/stored_session_test.go b/pkg/middleware/stored_session_test.go index b1f2ef70..c462278f 100644 --- a/pkg/middleware/stored_session_test.go +++ b/pkg/middleware/stored_session_test.go @@ -18,97 +18,67 @@ import ( . "github.com/onsi/gomega" ) -type TestLock struct { - Locked bool - WasObtained bool - WasRefreshed bool - WasReleased bool - PeekedCount int - LockedOnPeekCount int - ObtainError error - PeekError error - RefreshError error - ReleaseError error +type testLock struct { + locked bool + obtainOnAttempt int + obtainAttempts int + obtainError error } -func (l *TestLock) Obtain(_ context.Context, _ time.Duration) error { - if l.ObtainError != nil { - return l.ObtainError - } - l.Locked = true - l.WasObtained = true - return nil -} - -func (l *TestLock) Peek(_ context.Context) (bool, error) { - if l.PeekError != nil { - return false, l.PeekError - } - locked := l.Locked - l.Locked = false - l.PeekedCount++ - // mainly used to test case when peek initially returns false, - // but when trying to obtain lock, it returns true. - if l.LockedOnPeekCount == l.PeekedCount { - return true, nil - } - return locked, nil -} - -func (l *TestLock) Refresh(_ context.Context, _ time.Duration) error { - if l.RefreshError != nil { - return l.ReleaseError - } - l.WasRefreshed = true - return nil -} - -func (l *TestLock) Release(_ context.Context) error { - if l.ReleaseError != nil { - return l.ReleaseError - } - l.Locked = false - l.WasReleased = true - return nil -} - -type LockConc struct { - mu sync.Mutex - lock bool - disablePeek bool -} - -func (l *LockConc) Obtain(_ context.Context, _ time.Duration) error { - l.mu.Lock() - if l.lock { - l.mu.Unlock() +func (l *testLock) Obtain(_ context.Context, _ time.Duration) error { + l.obtainAttempts++ + if l.obtainAttempts < l.obtainOnAttempt { return sessionsapi.ErrLockNotObtained } - l.lock = true - l.mu.Unlock() - return nil -} - -func (l *LockConc) Peek(_ context.Context) (bool, error) { - var response bool - l.mu.Lock() - if l.disablePeek { - response = false - } else { - response = l.lock + if l.obtainError != nil { + return l.obtainError } - l.mu.Unlock() - return response, nil -} - -func (l *LockConc) Refresh(_ context.Context, _ time.Duration) error { + l.locked = true return nil } -func (l *LockConc) Release(_ context.Context) error { +func (l *testLock) Peek(_ context.Context) (bool, error) { + return l.locked, nil +} + +func (l *testLock) Refresh(_ context.Context, _ time.Duration) error { + return nil +} + +func (l *testLock) Release(_ context.Context) error { + l.locked = false + return nil +} + +type testLockConcurrent struct { + mu sync.RWMutex + locked bool +} + +func (l *testLockConcurrent) Obtain(_ context.Context, _ time.Duration) error { l.mu.Lock() - l.lock = false - l.mu.Unlock() + defer l.mu.Unlock() + if l.locked { + return sessionsapi.ErrLockNotObtained + } + l.locked = true + return nil +} + +func (l *testLockConcurrent) Peek(_ context.Context) (bool, error) { + l.mu.RLock() + defer l.mu.RUnlock() + return l.locked, nil +} + +func (l *testLockConcurrent) Refresh(_ context.Context, _ time.Duration) error { + return nil +} + +func (l *testLockConcurrent) Release(_ context.Context) error { + l.mu.Lock() + defer l.mu.Unlock() + l.locked = false return nil } @@ -374,22 +344,29 @@ var _ = Describe("Stored Session Suite", func() { DescribeTable("when serving concurrent requests", func(in storedSessionLoaderConcurrentTableInput) { - lockConc := &LockConc{} + lockConc := &testLockConcurrent{} + + lock := &sync.RWMutex{} + existingSession := *in.existingSession // deep copy existingSession state + existingSession.Lock = lockConc + store := &fakeSessionStore{ + LoadFunc: func(req *http.Request) (*sessionsapi.SessionState, error) { + lock.RLock() + defer lock.RUnlock() + session := existingSession + return &session, nil + }, + SaveFunc: func(_ http.ResponseWriter, _ *http.Request, session *sessionsapi.SessionState) error { + lock.Lock() + defer lock.Unlock() + existingSession = *session + return nil + }, + } refreshedChan := make(chan bool, in.numConcReqs) for i := 0; i < in.numConcReqs; i++ { go func(refreshedChan chan bool, lockConc sessionsapi.Lock) { - existingSession := *in.existingSession // deep copy existingSession state - existingSession.Lock = lockConc - store := &fakeSessionStore{ - LoadFunc: func(req *http.Request) (*sessionsapi.SessionState, error) { - return &existingSession, nil - }, - SaveFunc: func(http.ResponseWriter, *http.Request, *sessionsapi.SessionState) error { - return nil - }, - } - scope := &middlewareapi.RequestScope{ Session: nil, } @@ -461,13 +438,13 @@ var _ = Describe("Stored Session Suite", func() { Context("refreshSessionIfNeeded", func() { type refreshSessionIfNeededTableInput struct { - refreshPeriod time.Duration - sessionStored bool - session *sessionsapi.SessionState - expectedErr error - expectRefreshed bool - expectValidated bool - expectedLockState TestLock + refreshPeriod time.Duration + session *sessionsapi.SessionState + concurrentSessionRefresh bool + expectedErr error + expectRefreshed bool + expectValidated bool + expectedLockObtained bool } createdPast := time.Now().Add(-5 * time.Minute) @@ -478,13 +455,23 @@ var _ = Describe("Stored Session Suite", func() { refreshed := false validated := false - store := &fakeSessionStore{} - if in.sessionStored { - store = &fakeSessionStore{ - LoadFunc: func(req *http.Request) (*sessionsapi.SessionState, error) { - return in.session, nil - }, - } + session := &sessionsapi.SessionState{} + *session = *in.session + if in.concurrentSessionRefresh { + // Update the session that Load returns. + // This simulates a concurrent refresh in the background. + session.CreatedAt = &createdFuture + } + store := &fakeSessionStore{ + LoadFunc: func(req *http.Request) (*sessionsapi.SessionState, error) { + // Loading the session from the provider creates a new lock + session.Lock = &testLock{} + return session, nil + }, + SaveFunc: func(_ http.ResponseWriter, _ *http.Request, s *sessionsapi.SessionState) error { + *session = *s + return nil + }, } s := &storedSessionLoader{ @@ -518,117 +505,90 @@ var _ = Describe("Stored Session Suite", func() { } Expect(refreshed).To(Equal(in.expectRefreshed)) Expect(validated).To(Equal(in.expectValidated)) - testLock, ok := in.session.Lock.(*TestLock) + testLock, ok := in.session.Lock.(*testLock) Expect(ok).To(Equal(true)) - Expect(testLock).To(Equal(&in.expectedLockState)) + if in.expectedLockObtained { + Expect(testLock.obtainAttempts).Should(BeNumerically(">", 0), "Expected at least one attempt at obtaining the session lock") + } + Expect(testLock.locked).To(BeFalse(), "Expected lock should always be released") }, Entry("when the refresh period is 0, and the session does not need refreshing", refreshSessionIfNeededTableInput{ refreshPeriod: time.Duration(0), session: &sessionsapi.SessionState{ RefreshToken: refresh, CreatedAt: &createdFuture, - Lock: &TestLock{}, + Lock: &testLock{}, }, - expectedErr: nil, - expectRefreshed: false, - expectValidated: false, - expectedLockState: TestLock{}, + expectedErr: nil, + expectRefreshed: false, + expectValidated: false, + expectedLockObtained: false, }), Entry("when the refresh period is 0, and the session needs refreshing", refreshSessionIfNeededTableInput{ refreshPeriod: time.Duration(0), session: &sessionsapi.SessionState{ RefreshToken: refresh, CreatedAt: &createdPast, - Lock: &TestLock{}, + Lock: &testLock{}, }, - expectedErr: nil, - expectRefreshed: false, - expectValidated: false, - expectedLockState: TestLock{}, + expectedErr: nil, + expectRefreshed: false, + expectValidated: false, + expectedLockObtained: false, }), Entry("when the session does not need refreshing", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, session: &sessionsapi.SessionState{ RefreshToken: refresh, CreatedAt: &createdFuture, - Lock: &TestLock{}, + Lock: &testLock{}, }, - expectedErr: nil, - expectRefreshed: false, - expectValidated: false, - expectedLockState: TestLock{}, + expectedErr: nil, + expectRefreshed: false, + expectValidated: false, + expectedLockObtained: false, }), Entry("when the session is refreshed by the provider", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, session: &sessionsapi.SessionState{ RefreshToken: refresh, CreatedAt: &createdPast, - Lock: &TestLock{}, - }, - expectedErr: nil, - expectRefreshed: true, - expectValidated: true, - expectedLockState: TestLock{ - Locked: false, - WasObtained: true, - WasReleased: true, - PeekedCount: 1, - }, - }), - Entry("when the session is locked and instead loaded from storage", refreshSessionIfNeededTableInput{ - refreshPeriod: 1 * time.Minute, - session: &sessionsapi.SessionState{ - RefreshToken: noRefresh, - CreatedAt: &createdPast, - Lock: &TestLock{ - Locked: true, - }, - }, - sessionStored: true, - expectedErr: nil, - expectRefreshed: false, - expectValidated: true, - expectedLockState: TestLock{ - Locked: false, - PeekedCount: 2, + Lock: &testLock{}, }, + expectedErr: nil, + expectRefreshed: true, + expectValidated: true, + expectedLockObtained: true, }), Entry("when obtaining lock failed, but concurrent request refreshed", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, session: &sessionsapi.SessionState{ RefreshToken: noRefresh, CreatedAt: &createdPast, - Lock: &TestLock{ - ObtainError: errors.New("not able to obtain lock"), - LockedOnPeekCount: 2, + Lock: &testLock{ + obtainOnAttempt: 4, }, }, - expectedErr: nil, - expectRefreshed: false, - expectValidated: true, - expectedLockState: TestLock{ - PeekedCount: 3, - LockedOnPeekCount: 2, - ObtainError: errors.New("not able to obtain lock"), - }, + concurrentSessionRefresh: true, + expectedErr: nil, + expectRefreshed: false, + expectValidated: false, + expectedLockObtained: true, }), - Entry("when obtaining lock failed", refreshSessionIfNeededTableInput{ + Entry("when obtaining lock failed with a valid session", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, session: &sessionsapi.SessionState{ RefreshToken: noRefresh, CreatedAt: &createdPast, - Lock: &TestLock{ - ObtainError: errors.New("not able to obtain lock"), + Lock: &testLock{ + obtainError: sessionsapi.ErrLockNotObtained, }, }, - expectedErr: nil, - expectRefreshed: false, - expectValidated: true, - expectedLockState: TestLock{ - PeekedCount: 2, - ObtainError: errors.New("not able to obtain lock"), - }, + expectedErr: errors.New("timeout obtaining session lock"), + expectRefreshed: false, + expectValidated: false, + expectedLockObtained: true, }), Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, @@ -636,34 +596,24 @@ var _ = Describe("Stored Session Suite", func() { RefreshToken: noRefresh, CreatedAt: &createdPast, ExpiresOn: &createdFuture, - Lock: &TestLock{}, - }, - expectedErr: nil, - expectRefreshed: true, - expectValidated: true, - expectedLockState: TestLock{ - Locked: false, - WasObtained: true, - WasReleased: true, - PeekedCount: 1, + Lock: &testLock{}, }, + expectedErr: nil, + expectRefreshed: true, + expectValidated: true, + expectedLockObtained: true, }), Entry("when the provider doesn't implement refresh", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, session: &sessionsapi.SessionState{ RefreshToken: notImplemented, CreatedAt: &createdPast, - Lock: &TestLock{}, - }, - expectedErr: nil, - expectRefreshed: true, - expectValidated: true, - expectedLockState: TestLock{ - Locked: false, - WasObtained: true, - WasReleased: true, - PeekedCount: 1, + Lock: &testLock{}, }, + expectedErr: nil, + expectRefreshed: true, + expectValidated: true, + expectedLockObtained: true, }), Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, @@ -672,17 +622,12 @@ var _ = Describe("Stored Session Suite", func() { RefreshToken: noRefresh, CreatedAt: &createdPast, ExpiresOn: &createdFuture, - Lock: &TestLock{}, - }, - expectedErr: errors.New("session is invalid"), - expectRefreshed: true, - expectValidated: true, - expectedLockState: TestLock{ - Locked: false, - WasObtained: true, - WasReleased: true, - PeekedCount: 1, + Lock: &testLock{}, }, + expectedErr: errors.New("session is invalid"), + expectRefreshed: true, + expectValidated: true, + expectedLockObtained: true, }), ) }) From da92648e54cbc78d80cd314918c43c6b79725ad2 Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Mon, 14 Feb 2022 14:29:16 +0000 Subject: [PATCH 3/3] Add changelog entry for session locking --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 05ae8c92..690b4831 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ ## Changes since v7.2.1 +- [#1468](https://github.com/oauth2-proxy/oauth2-proxy/pull/1468) Implement session locking with session state lock (@JoelSpeed, @Bibob7) - [#1489](https://github.com/oauth2-proxy/oauth2-proxy/pull/1489) Fix Docker Buildx push to include build version (@JoelSpeed) - [#1477](https://github.com/oauth2-proxy/oauth2-proxy/pull/1477) Remove provider documentation for `Microsoft Azure AD` (@omBratteng) - [#1204](https://github.com/oauth2-proxy/oauth2-proxy/pull/1204) Added configuration for audience claim (`--oidc-extra-audience`) and ability to specify extra audiences (`--oidc-extra-audience`) allowed passing audience verification. This enables support for AWS Cognito and other issuers that have custom audience claims. Also, this adds the ability to allow multiple audiences. (@kschu91)