mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-06-04 23:37:29 +02:00
Use session to lock to protect concurrent refreshes
This commit is contained in:
parent
dc5d2a5cd7
commit
e2c7ff6ddd
@ -14,6 +14,11 @@ import (
|
|||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/providers"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/providers"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
SessionLockExpireTime = 5 * time.Second
|
||||||
|
SessionLockPeekDelay = 50 * time.Millisecond
|
||||||
|
)
|
||||||
|
|
||||||
// StoredSessionLoaderOptions contains all of the requirements to construct
|
// StoredSessionLoaderOptions contains all of the requirements to construct
|
||||||
// a stored session loader.
|
// a stored session loader.
|
||||||
// All options must be provided.
|
// All options must be provided.
|
||||||
@ -91,13 +96,10 @@ func (s *storedSessionLoader) loadSession(next http.Handler) http.Handler {
|
|||||||
// that is is valid.
|
// that is is valid.
|
||||||
func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *http.Request) (*sessionsapi.SessionState, error) {
|
func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *http.Request) (*sessionsapi.SessionState, error) {
|
||||||
session, err := s.store.Load(req)
|
session, err := s.store.Load(req)
|
||||||
if err != nil {
|
if err != nil || session == nil {
|
||||||
|
// No session was found in the storage or error occurred, nothing more to do
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if session == nil {
|
|
||||||
// No session was found in the storage, nothing more to do
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
err = s.refreshSessionIfNeeded(rw, req, session)
|
err = s.refreshSessionIfNeeded(rw, req, session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -116,13 +118,22 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
wasRefreshed, err := s.checkForConcurrentRefresh(session, req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// If session was already refreshed via a concurrent request locked skip refreshing,
|
||||||
|
// because the refreshed session is already loaded from storage
|
||||||
|
if !wasRefreshed {
|
||||||
logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age())
|
logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age())
|
||||||
err := s.refreshSession(rw, req, session)
|
err = s.refreshSession(rw, req, session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If a preemptive refresh fails, we still keep the session
|
// If a preemptive refresh fails, we still keep the session
|
||||||
// if validateSession succeeds.
|
// if validateSession succeeds.
|
||||||
logger.Errorf("Unable to refresh session: %v", err)
|
logger.Errorf("Unable to refresh session: %v", err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Validate all sessions after any Redeem/Refresh operation (fail or success)
|
// Validate all sessions after any Redeem/Refresh operation (fail or success)
|
||||||
return s.validateSession(req.Context(), session)
|
return s.validateSession(req.Context(), session)
|
||||||
@ -131,6 +142,18 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req
|
|||||||
// refreshSession attempts to refresh the session with the provider
|
// refreshSession attempts to refresh the session with the provider
|
||||||
// and will save the session if it was updated.
|
// and will save the session if it was updated.
|
||||||
func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error {
|
func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error {
|
||||||
|
err := session.ObtainLock(req.Context(), SessionLockExpireTime)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("Unable to obtain lock: %v", err)
|
||||||
|
return s.handleObtainLockError(req, session)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err = session.ReleaseLock(req.Context())
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("unable to release lock: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
refreshed, err := s.sessionRefresher(req.Context(), session)
|
refreshed, err := s.sessionRefresher(req.Context(), session)
|
||||||
if err != nil && !errors.Is(err, providers.ErrNotImplemented) {
|
if err != nil && !errors.Is(err, providers.ErrNotImplemented) {
|
||||||
return fmt.Errorf("error refreshing tokens: %v", err)
|
return fmt.Errorf("error refreshing tokens: %v", err)
|
||||||
@ -159,11 +182,75 @@ func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.R
|
|||||||
err = s.store.Save(rw, req, session)
|
err = s.store.Save(rw, req, session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err)
|
logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err)
|
||||||
return fmt.Errorf("error saving session: %v", err)
|
err = fmt.Errorf("error saving session: %v", err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *storedSessionLoader) handleObtainLockError(req *http.Request, session *sessionsapi.SessionState) error {
|
||||||
|
wasRefreshed, err := s.checkForConcurrentRefresh(session, req)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("Unable to wait for obtained lock: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !wasRefreshed {
|
||||||
|
return errors.New("unable to obtain lock and session was also not refreshed via concurrent request")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *storedSessionLoader) updateSessionFromStore(req *http.Request, session *sessionsapi.SessionState) error {
|
||||||
|
sessionStored, err := s.store.Load(req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to load updated session from store: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if sessionStored == nil {
|
||||||
|
return fmt.Errorf("no session available to udpate from store")
|
||||||
|
}
|
||||||
|
*session = *sessionStored
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *storedSessionLoader) waitForPossibleSessionLock(session *sessionsapi.SessionState, req *http.Request) (bool, error) {
|
||||||
|
var wasLocked bool
|
||||||
|
isLocked, err := session.PeekLock(req.Context())
|
||||||
|
for isLocked {
|
||||||
|
wasLocked = true
|
||||||
|
// delay next peek lock
|
||||||
|
time.Sleep(SessionLockPeekDelay)
|
||||||
|
isLocked, err = session.PeekLock(req.Context())
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return wasLocked, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkForConcurrentRefresh returns true if the session is already refreshed via a concurrent request.
|
||||||
|
func (s *storedSessionLoader) checkForConcurrentRefresh(session *sessionsapi.SessionState, req *http.Request) (bool, error) {
|
||||||
|
wasLocked, err := s.waitForPossibleSessionLock(session, req)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshed := false
|
||||||
|
if wasLocked {
|
||||||
|
logger.Printf("Update session from store instead of refreshing")
|
||||||
|
err = s.updateSessionFromStore(req, session)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("Unable to update session from store: %v", err)
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
refreshed = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return refreshed, nil
|
||||||
|
}
|
||||||
|
|
||||||
// validateSession checks whether the session has expired and performs
|
// validateSession checks whether the session has expired and performs
|
||||||
// provider validation on the session.
|
// provider validation on the session.
|
||||||
// An error implies the session is not longer valid.
|
// An error implies the session is not longer valid.
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||||
@ -17,9 +18,104 @@ import (
|
|||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type TestLock struct {
|
||||||
|
Locked bool
|
||||||
|
WasObtained bool
|
||||||
|
WasRefreshed bool
|
||||||
|
WasReleased bool
|
||||||
|
PeekedCount int
|
||||||
|
LockedOnPeekCount int
|
||||||
|
ObtainError error
|
||||||
|
PeekError error
|
||||||
|
RefreshError error
|
||||||
|
ReleaseError error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *TestLock) Obtain(_ context.Context, _ time.Duration) error {
|
||||||
|
if l.ObtainError != nil {
|
||||||
|
return l.ObtainError
|
||||||
|
}
|
||||||
|
l.Locked = true
|
||||||
|
l.WasObtained = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *TestLock) Peek(_ context.Context) (bool, error) {
|
||||||
|
if l.PeekError != nil {
|
||||||
|
return false, l.PeekError
|
||||||
|
}
|
||||||
|
locked := l.Locked
|
||||||
|
l.Locked = false
|
||||||
|
l.PeekedCount++
|
||||||
|
// mainly used to test case when peek initially returns false,
|
||||||
|
// but when trying to obtain lock, it returns true.
|
||||||
|
if l.LockedOnPeekCount == l.PeekedCount {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
return locked, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *TestLock) Refresh(_ context.Context, _ time.Duration) error {
|
||||||
|
if l.RefreshError != nil {
|
||||||
|
return l.ReleaseError
|
||||||
|
}
|
||||||
|
l.WasRefreshed = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *TestLock) Release(_ context.Context) error {
|
||||||
|
if l.ReleaseError != nil {
|
||||||
|
return l.ReleaseError
|
||||||
|
}
|
||||||
|
l.Locked = false
|
||||||
|
l.WasReleased = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type LockConc struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
lock bool
|
||||||
|
disablePeek bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LockConc) Obtain(_ context.Context, _ time.Duration) error {
|
||||||
|
l.mu.Lock()
|
||||||
|
if l.lock {
|
||||||
|
l.mu.Unlock()
|
||||||
|
return sessionsapi.ErrLockNotObtained
|
||||||
|
}
|
||||||
|
l.lock = true
|
||||||
|
l.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LockConc) Peek(_ context.Context) (bool, error) {
|
||||||
|
var response bool
|
||||||
|
l.mu.Lock()
|
||||||
|
if l.disablePeek {
|
||||||
|
response = false
|
||||||
|
} else {
|
||||||
|
response = l.lock
|
||||||
|
}
|
||||||
|
l.mu.Unlock()
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LockConc) Refresh(_ context.Context, _ time.Duration) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LockConc) Release(_ context.Context) error {
|
||||||
|
l.mu.Lock()
|
||||||
|
l.lock = false
|
||||||
|
l.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
var _ = Describe("Stored Session Suite", func() {
|
var _ = Describe("Stored Session Suite", func() {
|
||||||
const (
|
const (
|
||||||
refresh = "Refresh"
|
refresh = "Refresh"
|
||||||
|
refreshed = "Refreshed"
|
||||||
noRefresh = "NoRefresh"
|
noRefresh = "NoRefresh"
|
||||||
notImplemented = "NotImplemented"
|
notImplemented = "NotImplemented"
|
||||||
)
|
)
|
||||||
@ -34,7 +130,7 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
|
var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
|
||||||
switch ss.RefreshToken {
|
switch ss.RefreshToken {
|
||||||
case refresh:
|
case refresh:
|
||||||
ss.RefreshToken = "Refreshed"
|
ss.RefreshToken = refreshed
|
||||||
return true, nil
|
return true, nil
|
||||||
case noRefresh:
|
case noRefresh:
|
||||||
return false, nil
|
return false, nil
|
||||||
@ -181,6 +277,7 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
RefreshToken: noRefresh,
|
RefreshToken: noRefresh,
|
||||||
CreatedAt: &createdPast,
|
CreatedAt: &createdPast,
|
||||||
ExpiresOn: &createdFuture,
|
ExpiresOn: &createdFuture,
|
||||||
|
Lock: &sessionsapi.NoOpLock{},
|
||||||
},
|
},
|
||||||
store: defaultSessionStore,
|
store: defaultSessionStore,
|
||||||
refreshPeriod: 1 * time.Minute,
|
refreshPeriod: 1 * time.Minute,
|
||||||
@ -222,6 +319,7 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
RefreshToken: "Refreshed",
|
RefreshToken: "Refreshed",
|
||||||
CreatedAt: &now,
|
CreatedAt: &now,
|
||||||
ExpiresOn: &createdFuture,
|
ExpiresOn: &createdFuture,
|
||||||
|
Lock: &sessionsapi.NoOpLock{},
|
||||||
},
|
},
|
||||||
store: defaultSessionStore,
|
store: defaultSessionStore,
|
||||||
refreshPeriod: 1 * time.Minute,
|
refreshPeriod: 1 * time.Minute,
|
||||||
@ -237,6 +335,7 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
RefreshToken: "RefreshError",
|
RefreshToken: "RefreshError",
|
||||||
CreatedAt: &createdPast,
|
CreatedAt: &createdPast,
|
||||||
ExpiresOn: &createdFuture,
|
ExpiresOn: &createdFuture,
|
||||||
|
Lock: &sessionsapi.NoOpLock{},
|
||||||
},
|
},
|
||||||
store: defaultSessionStore,
|
store: defaultSessionStore,
|
||||||
refreshPeriod: 1 * time.Minute,
|
refreshPeriod: 1 * time.Minute,
|
||||||
@ -266,15 +365,109 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
validateSession: defaultValidateFunc,
|
validateSession: defaultValidateFunc,
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type storedSessionLoaderConcurrentTableInput struct {
|
||||||
|
existingSession *sessionsapi.SessionState
|
||||||
|
refreshPeriod time.Duration
|
||||||
|
numConcReqs int
|
||||||
|
}
|
||||||
|
|
||||||
|
DescribeTable("when serving concurrent requests",
|
||||||
|
func(in storedSessionLoaderConcurrentTableInput) {
|
||||||
|
lockConc := &LockConc{}
|
||||||
|
|
||||||
|
refreshedChan := make(chan bool, in.numConcReqs)
|
||||||
|
for i := 0; i < in.numConcReqs; i++ {
|
||||||
|
go func(refreshedChan chan bool, lockConc sessionsapi.Lock) {
|
||||||
|
existingSession := *in.existingSession // deep copy existingSession state
|
||||||
|
existingSession.Lock = lockConc
|
||||||
|
store := &fakeSessionStore{
|
||||||
|
LoadFunc: func(req *http.Request) (*sessionsapi.SessionState, error) {
|
||||||
|
return &existingSession, nil
|
||||||
|
},
|
||||||
|
SaveFunc: func(http.ResponseWriter, *http.Request, *sessionsapi.SessionState) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
scope := &middlewareapi.RequestScope{
|
||||||
|
Session: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up the request with the request header and a request scope
|
||||||
|
req := httptest.NewRequest("", "/", nil)
|
||||||
|
req = middlewareapi.AddRequestScope(req, scope)
|
||||||
|
|
||||||
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
|
sessionRefreshed := false
|
||||||
|
opts := &StoredSessionLoaderOptions{
|
||||||
|
SessionStore: store,
|
||||||
|
RefreshPeriod: in.refreshPeriod,
|
||||||
|
RefreshSession: func(ctx context.Context, s *sessionsapi.SessionState) (bool, error) {
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
sessionRefreshed = true
|
||||||
|
return true, nil
|
||||||
|
},
|
||||||
|
ValidateSession: func(context.Context, *sessionsapi.SessionState) bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := NewStoredSessionLoader(opts)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
|
handler.ServeHTTP(rw, req)
|
||||||
|
|
||||||
|
refreshedChan <- sessionRefreshed
|
||||||
|
}(refreshedChan, lockConc)
|
||||||
|
}
|
||||||
|
var refreshedSlice []bool
|
||||||
|
for i := 0; i < in.numConcReqs; i++ {
|
||||||
|
refreshedSlice = append(refreshedSlice, <-refreshedChan)
|
||||||
|
}
|
||||||
|
sessionRefreshedCount := 0
|
||||||
|
for _, sessionRefreshed := range refreshedSlice {
|
||||||
|
if sessionRefreshed {
|
||||||
|
sessionRefreshedCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Expect(sessionRefreshedCount).To(Equal(1))
|
||||||
|
},
|
||||||
|
Entry("with two concurrent requests", storedSessionLoaderConcurrentTableInput{
|
||||||
|
existingSession: &sessionsapi.SessionState{
|
||||||
|
RefreshToken: refresh,
|
||||||
|
CreatedAt: &createdPast,
|
||||||
|
},
|
||||||
|
numConcReqs: 2,
|
||||||
|
refreshPeriod: 1 * time.Minute,
|
||||||
|
}),
|
||||||
|
Entry("with 5 concurrent requests", storedSessionLoaderConcurrentTableInput{
|
||||||
|
existingSession: &sessionsapi.SessionState{
|
||||||
|
RefreshToken: refresh,
|
||||||
|
CreatedAt: &createdPast,
|
||||||
|
},
|
||||||
|
numConcReqs: 5,
|
||||||
|
refreshPeriod: 1 * time.Minute,
|
||||||
|
}),
|
||||||
|
Entry("with one request", storedSessionLoaderConcurrentTableInput{
|
||||||
|
existingSession: &sessionsapi.SessionState{
|
||||||
|
RefreshToken: refresh,
|
||||||
|
CreatedAt: &createdPast,
|
||||||
|
},
|
||||||
|
numConcReqs: 1,
|
||||||
|
refreshPeriod: 1 * time.Minute,
|
||||||
|
}),
|
||||||
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("refreshSessionIfNeeded", func() {
|
Context("refreshSessionIfNeeded", func() {
|
||||||
type refreshSessionIfNeededTableInput struct {
|
type refreshSessionIfNeededTableInput struct {
|
||||||
refreshPeriod time.Duration
|
refreshPeriod time.Duration
|
||||||
|
sessionStored bool
|
||||||
session *sessionsapi.SessionState
|
session *sessionsapi.SessionState
|
||||||
expectedErr error
|
expectedErr error
|
||||||
expectRefreshed bool
|
expectRefreshed bool
|
||||||
expectValidated bool
|
expectValidated bool
|
||||||
|
expectedLockState TestLock
|
||||||
}
|
}
|
||||||
|
|
||||||
createdPast := time.Now().Add(-5 * time.Minute)
|
createdPast := time.Now().Add(-5 * time.Minute)
|
||||||
@ -285,9 +478,18 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
refreshed := false
|
refreshed := false
|
||||||
validated := false
|
validated := false
|
||||||
|
|
||||||
|
store := &fakeSessionStore{}
|
||||||
|
if in.sessionStored {
|
||||||
|
store = &fakeSessionStore{
|
||||||
|
LoadFunc: func(req *http.Request) (*sessionsapi.SessionState, error) {
|
||||||
|
return in.session, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
s := &storedSessionLoader{
|
s := &storedSessionLoader{
|
||||||
refreshPeriod: in.refreshPeriod,
|
refreshPeriod: in.refreshPeriod,
|
||||||
store: &fakeSessionStore{},
|
store: store,
|
||||||
sessionRefresher: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
|
sessionRefresher: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
|
||||||
refreshed = true
|
refreshed = true
|
||||||
switch ss.RefreshToken {
|
switch ss.RefreshToken {
|
||||||
@ -316,46 +518,117 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
}
|
}
|
||||||
Expect(refreshed).To(Equal(in.expectRefreshed))
|
Expect(refreshed).To(Equal(in.expectRefreshed))
|
||||||
Expect(validated).To(Equal(in.expectValidated))
|
Expect(validated).To(Equal(in.expectValidated))
|
||||||
|
testLock, ok := in.session.Lock.(*TestLock)
|
||||||
|
Expect(ok).To(Equal(true))
|
||||||
|
|
||||||
|
Expect(testLock).To(Equal(&in.expectedLockState))
|
||||||
},
|
},
|
||||||
Entry("when the refresh period is 0, and the session does not need refreshing", refreshSessionIfNeededTableInput{
|
Entry("when the refresh period is 0, and the session does not need refreshing", refreshSessionIfNeededTableInput{
|
||||||
refreshPeriod: time.Duration(0),
|
refreshPeriod: time.Duration(0),
|
||||||
session: &sessionsapi.SessionState{
|
session: &sessionsapi.SessionState{
|
||||||
RefreshToken: refresh,
|
RefreshToken: refresh,
|
||||||
CreatedAt: &createdFuture,
|
CreatedAt: &createdFuture,
|
||||||
|
Lock: &TestLock{},
|
||||||
},
|
},
|
||||||
expectedErr: nil,
|
expectedErr: nil,
|
||||||
expectRefreshed: false,
|
expectRefreshed: false,
|
||||||
expectValidated: false,
|
expectValidated: false,
|
||||||
|
expectedLockState: TestLock{},
|
||||||
}),
|
}),
|
||||||
Entry("when the refresh period is 0, and the session needs refreshing", refreshSessionIfNeededTableInput{
|
Entry("when the refresh period is 0, and the session needs refreshing", refreshSessionIfNeededTableInput{
|
||||||
refreshPeriod: time.Duration(0),
|
refreshPeriod: time.Duration(0),
|
||||||
session: &sessionsapi.SessionState{
|
session: &sessionsapi.SessionState{
|
||||||
RefreshToken: refresh,
|
RefreshToken: refresh,
|
||||||
CreatedAt: &createdPast,
|
CreatedAt: &createdPast,
|
||||||
|
Lock: &TestLock{},
|
||||||
},
|
},
|
||||||
expectedErr: nil,
|
expectedErr: nil,
|
||||||
expectRefreshed: false,
|
expectRefreshed: false,
|
||||||
expectValidated: false,
|
expectValidated: false,
|
||||||
|
expectedLockState: TestLock{},
|
||||||
}),
|
}),
|
||||||
Entry("when the session does not need refreshing", refreshSessionIfNeededTableInput{
|
Entry("when the session does not need refreshing", refreshSessionIfNeededTableInput{
|
||||||
refreshPeriod: 1 * time.Minute,
|
refreshPeriod: 1 * time.Minute,
|
||||||
session: &sessionsapi.SessionState{
|
session: &sessionsapi.SessionState{
|
||||||
RefreshToken: refresh,
|
RefreshToken: refresh,
|
||||||
CreatedAt: &createdFuture,
|
CreatedAt: &createdFuture,
|
||||||
|
Lock: &TestLock{},
|
||||||
},
|
},
|
||||||
expectedErr: nil,
|
expectedErr: nil,
|
||||||
expectRefreshed: false,
|
expectRefreshed: false,
|
||||||
expectValidated: false,
|
expectValidated: false,
|
||||||
|
expectedLockState: TestLock{},
|
||||||
}),
|
}),
|
||||||
Entry("when the session is refreshed by the provider", refreshSessionIfNeededTableInput{
|
Entry("when the session is refreshed by the provider", refreshSessionIfNeededTableInput{
|
||||||
refreshPeriod: 1 * time.Minute,
|
refreshPeriod: 1 * time.Minute,
|
||||||
session: &sessionsapi.SessionState{
|
session: &sessionsapi.SessionState{
|
||||||
RefreshToken: refresh,
|
RefreshToken: refresh,
|
||||||
CreatedAt: &createdPast,
|
CreatedAt: &createdPast,
|
||||||
|
Lock: &TestLock{},
|
||||||
},
|
},
|
||||||
expectedErr: nil,
|
expectedErr: nil,
|
||||||
expectRefreshed: true,
|
expectRefreshed: true,
|
||||||
expectValidated: true,
|
expectValidated: true,
|
||||||
|
expectedLockState: TestLock{
|
||||||
|
Locked: false,
|
||||||
|
WasObtained: true,
|
||||||
|
WasReleased: true,
|
||||||
|
PeekedCount: 1,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
Entry("when the session is locked and instead loaded from storage", refreshSessionIfNeededTableInput{
|
||||||
|
refreshPeriod: 1 * time.Minute,
|
||||||
|
session: &sessionsapi.SessionState{
|
||||||
|
RefreshToken: noRefresh,
|
||||||
|
CreatedAt: &createdPast,
|
||||||
|
Lock: &TestLock{
|
||||||
|
Locked: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
sessionStored: true,
|
||||||
|
expectedErr: nil,
|
||||||
|
expectRefreshed: false,
|
||||||
|
expectValidated: true,
|
||||||
|
expectedLockState: TestLock{
|
||||||
|
Locked: false,
|
||||||
|
PeekedCount: 2,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
Entry("when obtaining lock failed, but concurrent request refreshed", refreshSessionIfNeededTableInput{
|
||||||
|
refreshPeriod: 1 * time.Minute,
|
||||||
|
session: &sessionsapi.SessionState{
|
||||||
|
RefreshToken: noRefresh,
|
||||||
|
CreatedAt: &createdPast,
|
||||||
|
Lock: &TestLock{
|
||||||
|
ObtainError: errors.New("not able to obtain lock"),
|
||||||
|
LockedOnPeekCount: 2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedErr: nil,
|
||||||
|
expectRefreshed: false,
|
||||||
|
expectValidated: true,
|
||||||
|
expectedLockState: TestLock{
|
||||||
|
PeekedCount: 3,
|
||||||
|
LockedOnPeekCount: 2,
|
||||||
|
ObtainError: errors.New("not able to obtain lock"),
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
Entry("when obtaining lock failed", refreshSessionIfNeededTableInput{
|
||||||
|
refreshPeriod: 1 * time.Minute,
|
||||||
|
session: &sessionsapi.SessionState{
|
||||||
|
RefreshToken: noRefresh,
|
||||||
|
CreatedAt: &createdPast,
|
||||||
|
Lock: &TestLock{
|
||||||
|
ObtainError: errors.New("not able to obtain lock"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedErr: nil,
|
||||||
|
expectRefreshed: false,
|
||||||
|
expectValidated: true,
|
||||||
|
expectedLockState: TestLock{
|
||||||
|
PeekedCount: 2,
|
||||||
|
ObtainError: errors.New("not able to obtain lock"),
|
||||||
|
},
|
||||||
}),
|
}),
|
||||||
Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{
|
Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{
|
||||||
refreshPeriod: 1 * time.Minute,
|
refreshPeriod: 1 * time.Minute,
|
||||||
@ -363,42 +636,53 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
RefreshToken: noRefresh,
|
RefreshToken: noRefresh,
|
||||||
CreatedAt: &createdPast,
|
CreatedAt: &createdPast,
|
||||||
ExpiresOn: &createdFuture,
|
ExpiresOn: &createdFuture,
|
||||||
|
Lock: &TestLock{},
|
||||||
},
|
},
|
||||||
expectedErr: nil,
|
expectedErr: nil,
|
||||||
expectRefreshed: true,
|
expectRefreshed: true,
|
||||||
expectValidated: true,
|
expectValidated: true,
|
||||||
|
expectedLockState: TestLock{
|
||||||
|
Locked: false,
|
||||||
|
WasObtained: true,
|
||||||
|
WasReleased: true,
|
||||||
|
PeekedCount: 1,
|
||||||
|
},
|
||||||
}),
|
}),
|
||||||
Entry("when the provider doesn't implement refresh but validation succeeds", refreshSessionIfNeededTableInput{
|
Entry("when the provider doesn't implement refresh", refreshSessionIfNeededTableInput{
|
||||||
refreshPeriod: 1 * time.Minute,
|
refreshPeriod: 1 * time.Minute,
|
||||||
session: &sessionsapi.SessionState{
|
session: &sessionsapi.SessionState{
|
||||||
RefreshToken: notImplemented,
|
RefreshToken: notImplemented,
|
||||||
CreatedAt: &createdPast,
|
CreatedAt: &createdPast,
|
||||||
|
Lock: &TestLock{},
|
||||||
},
|
},
|
||||||
expectedErr: nil,
|
expectedErr: nil,
|
||||||
expectRefreshed: true,
|
expectRefreshed: true,
|
||||||
expectValidated: true,
|
expectValidated: true,
|
||||||
}),
|
expectedLockState: TestLock{
|
||||||
Entry("when the provider refresh fails but validation succeeds", refreshSessionIfNeededTableInput{
|
Locked: false,
|
||||||
refreshPeriod: 1 * time.Minute,
|
WasObtained: true,
|
||||||
session: &sessionsapi.SessionState{
|
WasReleased: true,
|
||||||
RefreshToken: "RefreshError",
|
PeekedCount: 1,
|
||||||
CreatedAt: &createdPast,
|
|
||||||
},
|
},
|
||||||
expectedErr: nil,
|
|
||||||
expectRefreshed: true,
|
|
||||||
expectValidated: true,
|
|
||||||
}),
|
}),
|
||||||
Entry("when the session is not refreshed by the provider and validation fails", refreshSessionIfNeededTableInput{
|
Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{
|
||||||
refreshPeriod: 1 * time.Minute,
|
refreshPeriod: 1 * time.Minute,
|
||||||
session: &sessionsapi.SessionState{
|
session: &sessionsapi.SessionState{
|
||||||
AccessToken: "Invalid",
|
AccessToken: "Invalid",
|
||||||
RefreshToken: noRefresh,
|
RefreshToken: noRefresh,
|
||||||
CreatedAt: &createdPast,
|
CreatedAt: &createdPast,
|
||||||
ExpiresOn: &createdFuture,
|
ExpiresOn: &createdFuture,
|
||||||
|
Lock: &TestLock{},
|
||||||
},
|
},
|
||||||
expectedErr: errors.New("session is invalid"),
|
expectedErr: errors.New("session is invalid"),
|
||||||
expectRefreshed: true,
|
expectRefreshed: true,
|
||||||
expectValidated: true,
|
expectValidated: true,
|
||||||
|
expectedLockState: TestLock{
|
||||||
|
Locked: false,
|
||||||
|
WasObtained: true,
|
||||||
|
WasReleased: true,
|
||||||
|
PeekedCount: 1,
|
||||||
|
},
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
Loading…
x
Reference in New Issue
Block a user