From 2781ea1c9523ee21a1280845a79054e8d8b4d138 Mon Sep 17 00:00:00 2001 From: Kevin Kreitner Date: Mon, 18 Oct 2021 09:21:21 +0200 Subject: [PATCH] Try to wait for lock, when obtaining lock failed --- pkg/middleware/stored_session.go | 45 ++++++++++++++++++++------- pkg/middleware/stored_session_test.go | 45 +++++++++++++++++++++------ 2 files changed, 69 insertions(+), 21 deletions(-) diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go index de1591ac..c9770380 100644 --- a/pkg/middleware/stored_session.go +++ b/pkg/middleware/stored_session.go @@ -118,20 +118,14 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req return nil } - wasLocked, err := s.waitForPossibleSessionLock(session, req) + wasRefreshed, err := s.checkForConcurrentRefresh(session, req) if err != nil { return err } - // If session was locked, fetch current state, because - // it should be updated after lock is released. - if wasLocked { - logger.Printf("Update session from store instead of refreshing") - err = s.updateSessionFromStore(req, session) - if err != nil { - logger.Errorf("Unable to update session from store: %v", err) - } - } else { + // If session was already refreshed via a concurrent request locked skip refreshing, + // because the refreshed session is already loaded from storage + if !wasRefreshed { logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age()) err = s.refreshSession(rw, req, session) if err != nil { @@ -150,7 +144,15 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { err := session.ObtainLock(req.Context(), SessionLockExpireTime) if err != nil { - logger.Errorf("unable to obtain lock (skipping refresh): %v", err) + logger.Errorf("Unable to obtain lock: %v", err) + wasRefreshed, err := s.checkForConcurrentRefresh(session, req) + if err != nil { + logger.Errorf("Unable to wait for obtained lock: %v", err) + return err + } + if !wasRefreshed { + return errors.New("unable to obtain lock and session was also not refreshed via concurrent request") + } return nil } defer func() { @@ -224,6 +226,27 @@ func (s *storedSessionLoader) waitForPossibleSessionLock(session *sessionsapi.Se return wasLocked, nil } +// checkForConcurrentRefresh returns true if the session is already refreshed via a concurrent request. +func (s *storedSessionLoader) checkForConcurrentRefresh(session *sessionsapi.SessionState, req *http.Request) (bool, error) { + wasLocked, err := s.waitForPossibleSessionLock(session, req) + if err != nil { + return false, err + } + + refreshed := false + if wasLocked { + logger.Printf("Update session from store instead of refreshing") + err = s.updateSessionFromStore(req, session) + if err != nil { + logger.Errorf("Unable to update session from store: %v", err) + return false, err + } + refreshed = true + } + + return refreshed, nil +} + // validateSession checks whether the session has expired and performs // provider validation on the session. // An error implies the session is not longer valid. diff --git a/pkg/middleware/stored_session_test.go b/pkg/middleware/stored_session_test.go index 7c74116a..b1f2ef70 100644 --- a/pkg/middleware/stored_session_test.go +++ b/pkg/middleware/stored_session_test.go @@ -19,15 +19,16 @@ import ( ) type TestLock struct { - Locked bool - WasObtained bool - WasRefreshed bool - WasReleased bool - PeekedCount int - ObtainError error - PeekError error - RefreshError error - ReleaseError error + Locked bool + WasObtained bool + WasRefreshed bool + WasReleased bool + PeekedCount int + LockedOnPeekCount int + ObtainError error + PeekError error + RefreshError error + ReleaseError error } func (l *TestLock) Obtain(_ context.Context, _ time.Duration) error { @@ -46,6 +47,11 @@ func (l *TestLock) Peek(_ context.Context) (bool, error) { locked := l.Locked l.Locked = false l.PeekedCount++ + // mainly used to test case when peek initially returns false, + // but when trying to obtain lock, it returns true. + if l.LockedOnPeekCount == l.PeekedCount { + return true, nil + } return locked, nil } @@ -588,6 +594,25 @@ var _ = Describe("Stored Session Suite", func() { PeekedCount: 2, }, }), + Entry("when obtaining lock failed, but concurrent request refreshed", refreshSessionIfNeededTableInput{ + refreshPeriod: 1 * time.Minute, + session: &sessionsapi.SessionState{ + RefreshToken: noRefresh, + CreatedAt: &createdPast, + Lock: &TestLock{ + ObtainError: errors.New("not able to obtain lock"), + LockedOnPeekCount: 2, + }, + }, + expectedErr: nil, + expectRefreshed: false, + expectValidated: true, + expectedLockState: TestLock{ + PeekedCount: 3, + LockedOnPeekCount: 2, + ObtainError: errors.New("not able to obtain lock"), + }, + }), Entry("when obtaining lock failed", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, session: &sessionsapi.SessionState{ @@ -601,7 +626,7 @@ var _ = Describe("Stored Session Suite", func() { expectRefreshed: false, expectValidated: true, expectedLockState: TestLock{ - PeekedCount: 1, + PeekedCount: 2, ObtainError: errors.New("not able to obtain lock"), }, }),