From ad8ce2f6a48d27eb4ccd11ad63ae6a58ce64087f Mon Sep 17 00:00:00 2001 From: Kevin Kreitner Date: Mon, 11 Oct 2021 15:36:33 +0200 Subject: [PATCH] Add concurrent requests tests --- pkg/middleware/stored_session_test.go | 136 +++++++++++++++++++++++++- 1 file changed, 135 insertions(+), 1 deletion(-) diff --git a/pkg/middleware/stored_session_test.go b/pkg/middleware/stored_session_test.go index 4342d464..6ce816ab 100644 --- a/pkg/middleware/stored_session_test.go +++ b/pkg/middleware/stored_session_test.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "sync" "time" middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" @@ -65,9 +66,50 @@ func (l *TestLock) Release(_ context.Context) error { return nil } +type LockConc struct { + mu sync.Mutex + lock bool + disablePeek bool +} + +func (l *LockConc) Obtain(_ context.Context, _ time.Duration) error { + l.mu.Lock() + if l.lock { + l.mu.Unlock() + return sessionsapi.ErrLockNotObtained + } + l.lock = true + l.mu.Unlock() + return nil +} + +func (l *LockConc) Peek(_ context.Context) (bool, error) { + var response bool + l.mu.Lock() + if l.disablePeek { + response = false + } else { + response = l.lock + } + l.mu.Unlock() + return response, nil +} + +func (l *LockConc) Refresh(_ context.Context, _ time.Duration) error { + return nil +} + +func (l *LockConc) Release(_ context.Context) error { + l.mu.Lock() + l.lock = false + l.mu.Unlock() + return nil +} + var _ = Describe("Stored Session Suite", func() { const ( refresh = "Refresh" + refreshed = "Refreshed" noRefresh = "NoRefresh" notImplemented = "NotImplemented" ) @@ -82,7 +124,7 @@ var _ = Describe("Stored Session Suite", func() { var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { switch ss.RefreshToken { case refresh: - ss.RefreshToken = "Refreshed" + ss.RefreshToken = refreshed return true, nil case noRefresh: return false, nil @@ -317,6 +359,98 @@ var _ = Describe("Stored Session Suite", func() { validateSession: defaultValidateFunc, }), ) + + type storedSessionLoaderConcurrentTableInput struct { + existingSession *sessionsapi.SessionState + refreshPeriod time.Duration + numConcReqs int + } + + DescribeTable("when serving concurrent requests", + func(in storedSessionLoaderConcurrentTableInput) { + lockConc := &LockConc{} + + refreshedChan := make(chan bool, in.numConcReqs) + for i := 0; i < in.numConcReqs; i++ { + go func(refreshedChan chan bool, lockConc sessionsapi.Lock) { + existingSession := *in.existingSession // deep copy existingSession state + existingSession.Lock = lockConc + store := &fakeSessionStore{ + LoadFunc: func(req *http.Request) (*sessionsapi.SessionState, error) { + return &existingSession, nil + }, + SaveFunc: func(http.ResponseWriter, *http.Request, *sessionsapi.SessionState) error { + return nil + }, + } + + scope := &middlewareapi.RequestScope{ + Session: nil, + } + + // Set up the request with the request header and a request scope + req := httptest.NewRequest("", "/", nil) + req = middlewareapi.AddRequestScope(req, scope) + + rw := httptest.NewRecorder() + + sessionRefreshed := false + opts := &StoredSessionLoaderOptions{ + SessionStore: store, + RefreshPeriod: in.refreshPeriod, + RefreshSession: func(ctx context.Context, s *sessionsapi.SessionState) (bool, error) { + time.Sleep(10 * time.Millisecond) + sessionRefreshed = true + return true, nil + }, + ValidateSession: func(context.Context, *sessionsapi.SessionState) bool { + return true + }, + } + + handler := NewStoredSessionLoader(opts)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + handler.ServeHTTP(rw, req) + + refreshedChan <- sessionRefreshed + }(refreshedChan, lockConc) + } + var refreshedSlice []bool + for i := 0; i < in.numConcReqs; i++ { + refreshedSlice = append(refreshedSlice, <-refreshedChan) + } + sessionRefreshedCount := 0 + for _, sessionRefreshed := range refreshedSlice { + if sessionRefreshed { + sessionRefreshedCount++ + } + } + Expect(sessionRefreshedCount).To(Equal(1)) + }, + Entry("with two concurrent requests", storedSessionLoaderConcurrentTableInput{ + existingSession: &sessionsapi.SessionState{ + RefreshToken: refresh, + CreatedAt: &createdPast, + }, + numConcReqs: 2, + refreshPeriod: 1 * time.Minute, + }), + Entry("with 5 concurrent requests", storedSessionLoaderConcurrentTableInput{ + existingSession: &sessionsapi.SessionState{ + RefreshToken: refresh, + CreatedAt: &createdPast, + }, + numConcReqs: 5, + refreshPeriod: 1 * time.Minute, + }), + Entry("with one request", storedSessionLoaderConcurrentTableInput{ + existingSession: &sessionsapi.SessionState{ + RefreshToken: refresh, + CreatedAt: &createdPast, + }, + numConcReqs: 1, + refreshPeriod: 1 * time.Minute, + }), + ) }) Context("refreshSessionIfNeeded", func() {