From e2c7ff6ddd874faef4e0644f0709ce39d7a5440f Mon Sep 17 00:00:00 2001 From: Kevin Kreitner Date: Tue, 19 Jan 2021 17:28:58 +0100 Subject: [PATCH] Use session to lock to protect concurrent refreshes --- pkg/middleware/stored_session.go | 109 ++++++++- pkg/middleware/stored_session_test.go | 338 ++++++++++++++++++++++++-- 2 files changed, 409 insertions(+), 38 deletions(-) diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go index 4cbc47eb..0fb1e645 100644 --- a/pkg/middleware/stored_session.go +++ b/pkg/middleware/stored_session.go @@ -14,6 +14,11 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/v7/providers" ) +const ( + SessionLockExpireTime = 5 * time.Second + SessionLockPeekDelay = 50 * time.Millisecond +) + // StoredSessionLoaderOptions contains all of the requirements to construct // a stored session loader. // All options must be provided. @@ -91,13 +96,10 @@ func (s *storedSessionLoader) loadSession(next http.Handler) http.Handler { // that is is valid. func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *http.Request) (*sessionsapi.SessionState, error) { session, err := s.store.Load(req) - if err != nil { + if err != nil || session == nil { + // No session was found in the storage or error occurred, nothing more to do return nil, err } - if session == nil { - // No session was found in the storage, nothing more to do - return nil, nil - } err = s.refreshSessionIfNeeded(rw, req, session) if err != nil { @@ -116,12 +118,21 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req return nil } - logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age()) - err := s.refreshSession(rw, req, session) + wasRefreshed, err := s.checkForConcurrentRefresh(session, req) if err != nil { - // If a preemptive refresh fails, we still keep the session - // if validateSession succeeds. - logger.Errorf("Unable to refresh session: %v", err) + return err + } + + // If session was already refreshed via a concurrent request locked skip refreshing, + // because the refreshed session is already loaded from storage + if !wasRefreshed { + logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age()) + err = s.refreshSession(rw, req, session) + if err != nil { + // If a preemptive refresh fails, we still keep the session + // if validateSession succeeds. + logger.Errorf("Unable to refresh session: %v", err) + } } // Validate all sessions after any Redeem/Refresh operation (fail or success) @@ -131,6 +142,18 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req // refreshSession attempts to refresh the session with the provider // and will save the session if it was updated. func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { + err := session.ObtainLock(req.Context(), SessionLockExpireTime) + if err != nil { + logger.Errorf("Unable to obtain lock: %v", err) + return s.handleObtainLockError(req, session) + } + defer func() { + err = session.ReleaseLock(req.Context()) + if err != nil { + logger.Errorf("unable to release lock: %v", err) + } + }() + refreshed, err := s.sessionRefresher(req.Context(), session) if err != nil && !errors.Is(err, providers.ErrNotImplemented) { return fmt.Errorf("error refreshing tokens: %v", err) @@ -159,11 +182,75 @@ func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.R err = s.store.Save(rw, req, session) if err != nil { logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err) - return fmt.Errorf("error saving session: %v", err) + err = fmt.Errorf("error saving session: %v", err) + } + return err +} + +func (s *storedSessionLoader) handleObtainLockError(req *http.Request, session *sessionsapi.SessionState) error { + wasRefreshed, err := s.checkForConcurrentRefresh(session, req) + if err != nil { + logger.Errorf("Unable to wait for obtained lock: %v", err) + return err + } + if !wasRefreshed { + return errors.New("unable to obtain lock and session was also not refreshed via concurrent request") } return nil } +func (s *storedSessionLoader) updateSessionFromStore(req *http.Request, session *sessionsapi.SessionState) error { + sessionStored, err := s.store.Load(req) + if err != nil { + return fmt.Errorf("unable to load updated session from store: %v", err) + } + + if sessionStored == nil { + return fmt.Errorf("no session available to udpate from store") + } + *session = *sessionStored + + return nil +} + +func (s *storedSessionLoader) waitForPossibleSessionLock(session *sessionsapi.SessionState, req *http.Request) (bool, error) { + var wasLocked bool + isLocked, err := session.PeekLock(req.Context()) + for isLocked { + wasLocked = true + // delay next peek lock + time.Sleep(SessionLockPeekDelay) + isLocked, err = session.PeekLock(req.Context()) + } + + if err != nil { + return false, err + } + + return wasLocked, nil +} + +// checkForConcurrentRefresh returns true if the session is already refreshed via a concurrent request. +func (s *storedSessionLoader) checkForConcurrentRefresh(session *sessionsapi.SessionState, req *http.Request) (bool, error) { + wasLocked, err := s.waitForPossibleSessionLock(session, req) + if err != nil { + return false, err + } + + refreshed := false + if wasLocked { + logger.Printf("Update session from store instead of refreshing") + err = s.updateSessionFromStore(req, session) + if err != nil { + logger.Errorf("Unable to update session from store: %v", err) + return false, err + } + refreshed = true + } + + return refreshed, nil +} + // validateSession checks whether the session has expired and performs // provider validation on the session. // An error implies the session is not longer valid. diff --git a/pkg/middleware/stored_session_test.go b/pkg/middleware/stored_session_test.go index 782390b6..b1f2ef70 100644 --- a/pkg/middleware/stored_session_test.go +++ b/pkg/middleware/stored_session_test.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "sync" "time" middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" @@ -17,9 +18,104 @@ import ( . "github.com/onsi/gomega" ) +type TestLock struct { + Locked bool + WasObtained bool + WasRefreshed bool + WasReleased bool + PeekedCount int + LockedOnPeekCount int + ObtainError error + PeekError error + RefreshError error + ReleaseError error +} + +func (l *TestLock) Obtain(_ context.Context, _ time.Duration) error { + if l.ObtainError != nil { + return l.ObtainError + } + l.Locked = true + l.WasObtained = true + return nil +} + +func (l *TestLock) Peek(_ context.Context) (bool, error) { + if l.PeekError != nil { + return false, l.PeekError + } + locked := l.Locked + l.Locked = false + l.PeekedCount++ + // mainly used to test case when peek initially returns false, + // but when trying to obtain lock, it returns true. + if l.LockedOnPeekCount == l.PeekedCount { + return true, nil + } + return locked, nil +} + +func (l *TestLock) Refresh(_ context.Context, _ time.Duration) error { + if l.RefreshError != nil { + return l.ReleaseError + } + l.WasRefreshed = true + return nil +} + +func (l *TestLock) Release(_ context.Context) error { + if l.ReleaseError != nil { + return l.ReleaseError + } + l.Locked = false + l.WasReleased = true + return nil +} + +type LockConc struct { + mu sync.Mutex + lock bool + disablePeek bool +} + +func (l *LockConc) Obtain(_ context.Context, _ time.Duration) error { + l.mu.Lock() + if l.lock { + l.mu.Unlock() + return sessionsapi.ErrLockNotObtained + } + l.lock = true + l.mu.Unlock() + return nil +} + +func (l *LockConc) Peek(_ context.Context) (bool, error) { + var response bool + l.mu.Lock() + if l.disablePeek { + response = false + } else { + response = l.lock + } + l.mu.Unlock() + return response, nil +} + +func (l *LockConc) Refresh(_ context.Context, _ time.Duration) error { + return nil +} + +func (l *LockConc) Release(_ context.Context) error { + l.mu.Lock() + l.lock = false + l.mu.Unlock() + return nil +} + var _ = Describe("Stored Session Suite", func() { const ( refresh = "Refresh" + refreshed = "Refreshed" noRefresh = "NoRefresh" notImplemented = "NotImplemented" ) @@ -34,7 +130,7 @@ var _ = Describe("Stored Session Suite", func() { var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { switch ss.RefreshToken { case refresh: - ss.RefreshToken = "Refreshed" + ss.RefreshToken = refreshed return true, nil case noRefresh: return false, nil @@ -181,6 +277,7 @@ var _ = Describe("Stored Session Suite", func() { RefreshToken: noRefresh, CreatedAt: &createdPast, ExpiresOn: &createdFuture, + Lock: &sessionsapi.NoOpLock{}, }, store: defaultSessionStore, refreshPeriod: 1 * time.Minute, @@ -222,6 +319,7 @@ var _ = Describe("Stored Session Suite", func() { RefreshToken: "Refreshed", CreatedAt: &now, ExpiresOn: &createdFuture, + Lock: &sessionsapi.NoOpLock{}, }, store: defaultSessionStore, refreshPeriod: 1 * time.Minute, @@ -237,6 +335,7 @@ var _ = Describe("Stored Session Suite", func() { RefreshToken: "RefreshError", CreatedAt: &createdPast, ExpiresOn: &createdFuture, + Lock: &sessionsapi.NoOpLock{}, }, store: defaultSessionStore, refreshPeriod: 1 * time.Minute, @@ -266,15 +365,109 @@ var _ = Describe("Stored Session Suite", func() { validateSession: defaultValidateFunc, }), ) + + type storedSessionLoaderConcurrentTableInput struct { + existingSession *sessionsapi.SessionState + refreshPeriod time.Duration + numConcReqs int + } + + DescribeTable("when serving concurrent requests", + func(in storedSessionLoaderConcurrentTableInput) { + lockConc := &LockConc{} + + refreshedChan := make(chan bool, in.numConcReqs) + for i := 0; i < in.numConcReqs; i++ { + go func(refreshedChan chan bool, lockConc sessionsapi.Lock) { + existingSession := *in.existingSession // deep copy existingSession state + existingSession.Lock = lockConc + store := &fakeSessionStore{ + LoadFunc: func(req *http.Request) (*sessionsapi.SessionState, error) { + return &existingSession, nil + }, + SaveFunc: func(http.ResponseWriter, *http.Request, *sessionsapi.SessionState) error { + return nil + }, + } + + scope := &middlewareapi.RequestScope{ + Session: nil, + } + + // Set up the request with the request header and a request scope + req := httptest.NewRequest("", "/", nil) + req = middlewareapi.AddRequestScope(req, scope) + + rw := httptest.NewRecorder() + + sessionRefreshed := false + opts := &StoredSessionLoaderOptions{ + SessionStore: store, + RefreshPeriod: in.refreshPeriod, + RefreshSession: func(ctx context.Context, s *sessionsapi.SessionState) (bool, error) { + time.Sleep(10 * time.Millisecond) + sessionRefreshed = true + return true, nil + }, + ValidateSession: func(context.Context, *sessionsapi.SessionState) bool { + return true + }, + } + + handler := NewStoredSessionLoader(opts)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + handler.ServeHTTP(rw, req) + + refreshedChan <- sessionRefreshed + }(refreshedChan, lockConc) + } + var refreshedSlice []bool + for i := 0; i < in.numConcReqs; i++ { + refreshedSlice = append(refreshedSlice, <-refreshedChan) + } + sessionRefreshedCount := 0 + for _, sessionRefreshed := range refreshedSlice { + if sessionRefreshed { + sessionRefreshedCount++ + } + } + Expect(sessionRefreshedCount).To(Equal(1)) + }, + Entry("with two concurrent requests", storedSessionLoaderConcurrentTableInput{ + existingSession: &sessionsapi.SessionState{ + RefreshToken: refresh, + CreatedAt: &createdPast, + }, + numConcReqs: 2, + refreshPeriod: 1 * time.Minute, + }), + Entry("with 5 concurrent requests", storedSessionLoaderConcurrentTableInput{ + existingSession: &sessionsapi.SessionState{ + RefreshToken: refresh, + CreatedAt: &createdPast, + }, + numConcReqs: 5, + refreshPeriod: 1 * time.Minute, + }), + Entry("with one request", storedSessionLoaderConcurrentTableInput{ + existingSession: &sessionsapi.SessionState{ + RefreshToken: refresh, + CreatedAt: &createdPast, + }, + numConcReqs: 1, + refreshPeriod: 1 * time.Minute, + }), + ) }) Context("refreshSessionIfNeeded", func() { type refreshSessionIfNeededTableInput struct { - refreshPeriod time.Duration - session *sessionsapi.SessionState - expectedErr error - expectRefreshed bool - expectValidated bool + refreshPeriod time.Duration + sessionStored bool + session *sessionsapi.SessionState + expectedErr error + expectRefreshed bool + expectValidated bool + expectedLockState TestLock } createdPast := time.Now().Add(-5 * time.Minute) @@ -285,9 +478,18 @@ var _ = Describe("Stored Session Suite", func() { refreshed := false validated := false + store := &fakeSessionStore{} + if in.sessionStored { + store = &fakeSessionStore{ + LoadFunc: func(req *http.Request) (*sessionsapi.SessionState, error) { + return in.session, nil + }, + } + } + s := &storedSessionLoader{ refreshPeriod: in.refreshPeriod, - store: &fakeSessionStore{}, + store: store, sessionRefresher: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { refreshed = true switch ss.RefreshToken { @@ -316,46 +518,117 @@ var _ = Describe("Stored Session Suite", func() { } Expect(refreshed).To(Equal(in.expectRefreshed)) Expect(validated).To(Equal(in.expectValidated)) + testLock, ok := in.session.Lock.(*TestLock) + Expect(ok).To(Equal(true)) + + Expect(testLock).To(Equal(&in.expectedLockState)) }, Entry("when the refresh period is 0, and the session does not need refreshing", refreshSessionIfNeededTableInput{ refreshPeriod: time.Duration(0), session: &sessionsapi.SessionState{ RefreshToken: refresh, CreatedAt: &createdFuture, + Lock: &TestLock{}, }, - expectedErr: nil, - expectRefreshed: false, - expectValidated: false, + expectedErr: nil, + expectRefreshed: false, + expectValidated: false, + expectedLockState: TestLock{}, }), Entry("when the refresh period is 0, and the session needs refreshing", refreshSessionIfNeededTableInput{ refreshPeriod: time.Duration(0), session: &sessionsapi.SessionState{ RefreshToken: refresh, CreatedAt: &createdPast, + Lock: &TestLock{}, }, - expectedErr: nil, - expectRefreshed: false, - expectValidated: false, + expectedErr: nil, + expectRefreshed: false, + expectValidated: false, + expectedLockState: TestLock{}, }), Entry("when the session does not need refreshing", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, session: &sessionsapi.SessionState{ RefreshToken: refresh, CreatedAt: &createdFuture, + Lock: &TestLock{}, }, - expectedErr: nil, - expectRefreshed: false, - expectValidated: false, + expectedErr: nil, + expectRefreshed: false, + expectValidated: false, + expectedLockState: TestLock{}, }), Entry("when the session is refreshed by the provider", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, session: &sessionsapi.SessionState{ RefreshToken: refresh, CreatedAt: &createdPast, + Lock: &TestLock{}, }, expectedErr: nil, expectRefreshed: true, expectValidated: true, + expectedLockState: TestLock{ + Locked: false, + WasObtained: true, + WasReleased: true, + PeekedCount: 1, + }, + }), + Entry("when the session is locked and instead loaded from storage", refreshSessionIfNeededTableInput{ + refreshPeriod: 1 * time.Minute, + session: &sessionsapi.SessionState{ + RefreshToken: noRefresh, + CreatedAt: &createdPast, + Lock: &TestLock{ + Locked: true, + }, + }, + sessionStored: true, + expectedErr: nil, + expectRefreshed: false, + expectValidated: true, + expectedLockState: TestLock{ + Locked: false, + PeekedCount: 2, + }, + }), + Entry("when obtaining lock failed, but concurrent request refreshed", refreshSessionIfNeededTableInput{ + refreshPeriod: 1 * time.Minute, + session: &sessionsapi.SessionState{ + RefreshToken: noRefresh, + CreatedAt: &createdPast, + Lock: &TestLock{ + ObtainError: errors.New("not able to obtain lock"), + LockedOnPeekCount: 2, + }, + }, + expectedErr: nil, + expectRefreshed: false, + expectValidated: true, + expectedLockState: TestLock{ + PeekedCount: 3, + LockedOnPeekCount: 2, + ObtainError: errors.New("not able to obtain lock"), + }, + }), + Entry("when obtaining lock failed", refreshSessionIfNeededTableInput{ + refreshPeriod: 1 * time.Minute, + session: &sessionsapi.SessionState{ + RefreshToken: noRefresh, + CreatedAt: &createdPast, + Lock: &TestLock{ + ObtainError: errors.New("not able to obtain lock"), + }, + }, + expectedErr: nil, + expectRefreshed: false, + expectValidated: true, + expectedLockState: TestLock{ + PeekedCount: 2, + ObtainError: errors.New("not able to obtain lock"), + }, }), Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, @@ -363,42 +636,53 @@ var _ = Describe("Stored Session Suite", func() { RefreshToken: noRefresh, CreatedAt: &createdPast, ExpiresOn: &createdFuture, + Lock: &TestLock{}, }, expectedErr: nil, expectRefreshed: true, expectValidated: true, + expectedLockState: TestLock{ + Locked: false, + WasObtained: true, + WasReleased: true, + PeekedCount: 1, + }, }), - Entry("when the provider doesn't implement refresh but validation succeeds", refreshSessionIfNeededTableInput{ + Entry("when the provider doesn't implement refresh", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, session: &sessionsapi.SessionState{ RefreshToken: notImplemented, CreatedAt: &createdPast, + Lock: &TestLock{}, }, expectedErr: nil, expectRefreshed: true, expectValidated: true, - }), - Entry("when the provider refresh fails but validation succeeds", refreshSessionIfNeededTableInput{ - refreshPeriod: 1 * time.Minute, - session: &sessionsapi.SessionState{ - RefreshToken: "RefreshError", - CreatedAt: &createdPast, + expectedLockState: TestLock{ + Locked: false, + WasObtained: true, + WasReleased: true, + PeekedCount: 1, }, - expectedErr: nil, - expectRefreshed: true, - expectValidated: true, }), - Entry("when the session is not refreshed by the provider and validation fails", refreshSessionIfNeededTableInput{ + Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, session: &sessionsapi.SessionState{ AccessToken: "Invalid", RefreshToken: noRefresh, CreatedAt: &createdPast, ExpiresOn: &createdFuture, + Lock: &TestLock{}, }, expectedErr: errors.New("session is invalid"), expectRefreshed: true, expectValidated: true, + expectedLockState: TestLock{ + Locked: false, + WasObtained: true, + WasReleased: true, + PeekedCount: 1, + }, }), ) })