1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-08-10 22:51:31 +02:00

Try to wait for lock, when obtaining lock failed

This commit is contained in:
Kevin Kreitner
2021-10-18 09:21:21 +02:00
committed by Joel Speed
parent 360c753d6f
commit 2781ea1c95
2 changed files with 69 additions and 21 deletions

View File

@@ -118,20 +118,14 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req
return nil return nil
} }
wasLocked, err := s.waitForPossibleSessionLock(session, req) wasRefreshed, err := s.checkForConcurrentRefresh(session, req)
if err != nil { if err != nil {
return err return err
} }
// If session was locked, fetch current state, because // If session was already refreshed via a concurrent request locked skip refreshing,
// it should be updated after lock is released. // because the refreshed session is already loaded from storage
if wasLocked { if !wasRefreshed {
logger.Printf("Update session from store instead of refreshing")
err = s.updateSessionFromStore(req, session)
if err != nil {
logger.Errorf("Unable to update session from store: %v", err)
}
} else {
logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age()) logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age())
err = s.refreshSession(rw, req, session) err = s.refreshSession(rw, req, session)
if err != nil { if err != nil {
@@ -150,7 +144,15 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req
func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error {
err := session.ObtainLock(req.Context(), SessionLockExpireTime) err := session.ObtainLock(req.Context(), SessionLockExpireTime)
if err != nil { if err != nil {
logger.Errorf("unable to obtain lock (skipping refresh): %v", err) logger.Errorf("Unable to obtain lock: %v", err)
wasRefreshed, err := s.checkForConcurrentRefresh(session, req)
if err != nil {
logger.Errorf("Unable to wait for obtained lock: %v", err)
return err
}
if !wasRefreshed {
return errors.New("unable to obtain lock and session was also not refreshed via concurrent request")
}
return nil return nil
} }
defer func() { defer func() {
@@ -224,6 +226,27 @@ func (s *storedSessionLoader) waitForPossibleSessionLock(session *sessionsapi.Se
return wasLocked, nil return wasLocked, nil
} }
// checkForConcurrentRefresh returns true if the session is already refreshed via a concurrent request.
func (s *storedSessionLoader) checkForConcurrentRefresh(session *sessionsapi.SessionState, req *http.Request) (bool, error) {
wasLocked, err := s.waitForPossibleSessionLock(session, req)
if err != nil {
return false, err
}
refreshed := false
if wasLocked {
logger.Printf("Update session from store instead of refreshing")
err = s.updateSessionFromStore(req, session)
if err != nil {
logger.Errorf("Unable to update session from store: %v", err)
return false, err
}
refreshed = true
}
return refreshed, nil
}
// validateSession checks whether the session has expired and performs // validateSession checks whether the session has expired and performs
// provider validation on the session. // provider validation on the session.
// An error implies the session is not longer valid. // An error implies the session is not longer valid.

View File

@@ -19,15 +19,16 @@ import (
) )
type TestLock struct { type TestLock struct {
Locked bool Locked bool
WasObtained bool WasObtained bool
WasRefreshed bool WasRefreshed bool
WasReleased bool WasReleased bool
PeekedCount int PeekedCount int
ObtainError error LockedOnPeekCount int
PeekError error ObtainError error
RefreshError error PeekError error
ReleaseError error RefreshError error
ReleaseError error
} }
func (l *TestLock) Obtain(_ context.Context, _ time.Duration) error { func (l *TestLock) Obtain(_ context.Context, _ time.Duration) error {
@@ -46,6 +47,11 @@ func (l *TestLock) Peek(_ context.Context) (bool, error) {
locked := l.Locked locked := l.Locked
l.Locked = false l.Locked = false
l.PeekedCount++ l.PeekedCount++
// mainly used to test case when peek initially returns false,
// but when trying to obtain lock, it returns true.
if l.LockedOnPeekCount == l.PeekedCount {
return true, nil
}
return locked, nil return locked, nil
} }
@@ -588,6 +594,25 @@ var _ = Describe("Stored Session Suite", func() {
PeekedCount: 2, PeekedCount: 2,
}, },
}), }),
Entry("when obtaining lock failed, but concurrent request refreshed", refreshSessionIfNeededTableInput{
refreshPeriod: 1 * time.Minute,
session: &sessionsapi.SessionState{
RefreshToken: noRefresh,
CreatedAt: &createdPast,
Lock: &TestLock{
ObtainError: errors.New("not able to obtain lock"),
LockedOnPeekCount: 2,
},
},
expectedErr: nil,
expectRefreshed: false,
expectValidated: true,
expectedLockState: TestLock{
PeekedCount: 3,
LockedOnPeekCount: 2,
ObtainError: errors.New("not able to obtain lock"),
},
}),
Entry("when obtaining lock failed", refreshSessionIfNeededTableInput{ Entry("when obtaining lock failed", refreshSessionIfNeededTableInput{
refreshPeriod: 1 * time.Minute, refreshPeriod: 1 * time.Minute,
session: &sessionsapi.SessionState{ session: &sessionsapi.SessionState{
@@ -601,7 +626,7 @@ var _ = Describe("Stored Session Suite", func() {
expectRefreshed: false, expectRefreshed: false,
expectValidated: true, expectValidated: true,
expectedLockState: TestLock{ expectedLockState: TestLock{
PeekedCount: 1, PeekedCount: 2,
ObtainError: errors.New("not able to obtain lock"), ObtainError: errors.New("not able to obtain lock"),
}, },
}), }),