1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-08-10 22:51:31 +02:00

Try to wait for lock, when obtaining lock failed

This commit is contained in:
Kevin Kreitner
2021-10-18 09:21:21 +02:00
committed by Joel Speed
parent 360c753d6f
commit 2781ea1c95
2 changed files with 69 additions and 21 deletions

View File

@@ -118,20 +118,14 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req
return nil
}
wasLocked, err := s.waitForPossibleSessionLock(session, req)
wasRefreshed, err := s.checkForConcurrentRefresh(session, req)
if err != nil {
return err
}
// If session was locked, fetch current state, because
// it should be updated after lock is released.
if wasLocked {
logger.Printf("Update session from store instead of refreshing")
err = s.updateSessionFromStore(req, session)
if err != nil {
logger.Errorf("Unable to update session from store: %v", err)
}
} else {
// If session was already refreshed via a concurrent request locked skip refreshing,
// because the refreshed session is already loaded from storage
if !wasRefreshed {
logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age())
err = s.refreshSession(rw, req, session)
if err != nil {
@@ -150,7 +144,15 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req
func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error {
err := session.ObtainLock(req.Context(), SessionLockExpireTime)
if err != nil {
logger.Errorf("unable to obtain lock (skipping refresh): %v", err)
logger.Errorf("Unable to obtain lock: %v", err)
wasRefreshed, err := s.checkForConcurrentRefresh(session, req)
if err != nil {
logger.Errorf("Unable to wait for obtained lock: %v", err)
return err
}
if !wasRefreshed {
return errors.New("unable to obtain lock and session was also not refreshed via concurrent request")
}
return nil
}
defer func() {
@@ -224,6 +226,27 @@ func (s *storedSessionLoader) waitForPossibleSessionLock(session *sessionsapi.Se
return wasLocked, nil
}
// checkForConcurrentRefresh returns true if the session is already refreshed via a concurrent request.
func (s *storedSessionLoader) checkForConcurrentRefresh(session *sessionsapi.SessionState, req *http.Request) (bool, error) {
wasLocked, err := s.waitForPossibleSessionLock(session, req)
if err != nil {
return false, err
}
refreshed := false
if wasLocked {
logger.Printf("Update session from store instead of refreshing")
err = s.updateSessionFromStore(req, session)
if err != nil {
logger.Errorf("Unable to update session from store: %v", err)
return false, err
}
refreshed = true
}
return refreshed, nil
}
// validateSession checks whether the session has expired and performs
// provider validation on the session.
// An error implies the session is not longer valid.

View File

@@ -19,15 +19,16 @@ import (
)
type TestLock struct {
Locked bool
WasObtained bool
WasRefreshed bool
WasReleased bool
PeekedCount int
ObtainError error
PeekError error
RefreshError error
ReleaseError error
Locked bool
WasObtained bool
WasRefreshed bool
WasReleased bool
PeekedCount int
LockedOnPeekCount int
ObtainError error
PeekError error
RefreshError error
ReleaseError error
}
func (l *TestLock) Obtain(_ context.Context, _ time.Duration) error {
@@ -46,6 +47,11 @@ func (l *TestLock) Peek(_ context.Context) (bool, error) {
locked := l.Locked
l.Locked = false
l.PeekedCount++
// mainly used to test case when peek initially returns false,
// but when trying to obtain lock, it returns true.
if l.LockedOnPeekCount == l.PeekedCount {
return true, nil
}
return locked, nil
}
@@ -588,6 +594,25 @@ var _ = Describe("Stored Session Suite", func() {
PeekedCount: 2,
},
}),
Entry("when obtaining lock failed, but concurrent request refreshed", refreshSessionIfNeededTableInput{
refreshPeriod: 1 * time.Minute,
session: &sessionsapi.SessionState{
RefreshToken: noRefresh,
CreatedAt: &createdPast,
Lock: &TestLock{
ObtainError: errors.New("not able to obtain lock"),
LockedOnPeekCount: 2,
},
},
expectedErr: nil,
expectRefreshed: false,
expectValidated: true,
expectedLockState: TestLock{
PeekedCount: 3,
LockedOnPeekCount: 2,
ObtainError: errors.New("not able to obtain lock"),
},
}),
Entry("when obtaining lock failed", refreshSessionIfNeededTableInput{
refreshPeriod: 1 * time.Minute,
session: &sessionsapi.SessionState{
@@ -601,7 +626,7 @@ var _ = Describe("Stored Session Suite", func() {
expectRefreshed: false,
expectValidated: true,
expectedLockState: TestLock{
PeekedCount: 1,
PeekedCount: 2,
ObtainError: errors.New("not able to obtain lock"),
},
}),