From 69d6fc8a0877e1e1c0624406635ceb67cc05b9bc Mon Sep 17 00:00:00 2001
From: Kevin Kreitner <kevin.kreitner@real-digital.de>
Date: Mon, 22 Feb 2021 08:33:53 +0100
Subject: [PATCH] Split RefreshSessionIfNeeded in two methods and use Redis
 lock

---
 go.mod                                 |   1 +
 go.sum                                 |   6 +
 oauthproxy.go                          |   9 +-
 pkg/apis/sessions/interfaces.go        |   4 +-
 pkg/middleware/stored_session.go       |  89 +++++++------
 pkg/middleware/stored_session_test.go  | 175 ++++++++++++++++---------
 pkg/sessions/cookie/session_store.go   |   9 +-
 pkg/sessions/persistence/interfaces.go |   3 +-
 pkg/sessions/persistence/manager.go    |  20 ++-
 pkg/sessions/persistence/ticket.go     |   8 +-
 pkg/sessions/redis/client.go           |  42 ++++--
 pkg/sessions/redis/redis_store.go      |  22 +++-
 providers/azure.go                     |   9 +-
 providers/azure_test.go                |   4 +-
 providers/gitlab.go                    |   7 +-
 providers/google.go                    |   9 +-
 providers/oidc.go                      |   9 +-
 providers/oidc_test.go                 |   4 +-
 providers/provider_default.go          |   7 +-
 providers/provider_default_test.go     |   2 +-
 providers/providers.go                 |   3 +-
 21 files changed, 297 insertions(+), 145 deletions(-)

diff --git a/go.mod b/go.mod
index 7411accd..2a81284e 100644
--- a/go.mod
+++ b/go.mod
@@ -22,6 +22,7 @@ require (
 	github.com/onsi/gomega v1.10.2
 	github.com/pierrec/lz4 v2.5.2+incompatible
 	github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
+	github.com/prometheus/client_golang v0.9.3
 	github.com/spf13/pflag v1.0.5
 	github.com/spf13/viper v1.6.3
 	github.com/stretchr/testify v1.6.1
diff --git a/go.sum b/go.sum
index 7c0170dd..9e3b52dc 100644
--- a/go.sum
+++ b/go.sum
@@ -24,6 +24,7 @@ github.com/alicebob/miniredis/v2 v2.13.0 h1:QPosMaxm+r6Qs+YcCtL2Z2a2RSdC9VfXJLpd
 github.com/alicebob/miniredis/v2 v2.13.0/go.mod h1:0UIBNuf97uxrWhdVBpJvPtafKyGpL2NS2pYe0tYM97k=
 github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8=
 github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
+github.com/beorn7/perks v1.0.0 h1:HWo1m869IqiPhD389kmkxeTalrjNbbJTC8LXupb+sl0=
 github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
 github.com/bitly/go-simplejson v0.5.0 h1:6IH+V8/tVMab511d5bn4M7EwGXZf9Hj6i2xSwkNEM+Y=
 github.com/bitly/go-simplejson v0.5.0/go.mod h1:cXHtHw4XUPsvGaxgjIAn8PhEWG9NfngEKAMDJEczWVA=
@@ -168,6 +169,7 @@ github.com/magiconair/properties v1.8.1/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czP
 github.com/mailru/easyjson v0.0.0-20160728113105-d5b7844b561a/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
 github.com/matryer/is v1.2.0 h1:92UTHpy8CDwaJ08GqLDzhhuixiBUUD1p3AU6PHddz4A=
 github.com/matryer/is v1.2.0/go.mod h1:2fLPjFQM9rhQ15aVEtbuwhJinnOqrmgXPNdZsdwlWXA=
+github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
 github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
 github.com/mbland/hmacauth v0.0.0-20170912233209-44256dfd4bfa h1:hI1uC2A3vJFjwvBn0G0a7QBRdBUp6Y048BtLAHRTKPo=
 github.com/mbland/hmacauth v0.0.0-20170912233209-44256dfd4bfa/go.mod h1:8vxFeeg++MqgCHwehSuwTlYCF0ALyDJbYJ1JsKi7v6s=
@@ -208,13 +210,17 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
 github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU=
 github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA=
 github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
+github.com/prometheus/client_golang v0.9.3 h1:9iH4JKXLzFbOAdtqv/a+j8aewx2Y8lAjAydhbaScPF8=
 github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso=
 github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
 github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
+github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4 h1:gQz4mCbXsO+nc9n1hCxHcGA3Zx3Eo+UHZoInFGUIXNM=
 github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
 github.com/prometheus/common v0.0.0-20181113130724-41aa239b4cce/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro=
+github.com/prometheus/common v0.4.0 h1:7etb9YClo3a6HjLzfl6rIQaU+FDfi0VSX39io3aQ+DM=
 github.com/prometheus/common v0.4.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
 github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
+github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084 h1:sofwID9zm4tzrgykg80hfFph1mryUeLRsUfoocVVmRY=
 github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
 github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
 github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=
diff --git a/oauthproxy.go b/oauthproxy.go
index 0cfa1f93..68aba61a 100644
--- a/oauthproxy.go
+++ b/oauthproxy.go
@@ -276,10 +276,11 @@ func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionSt
 	}
 
 	chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{
-		SessionStore:           sessionStore,
-		RefreshPeriod:          opts.Cookie.Refresh,
-		RefreshSessionIfNeeded: opts.GetProvider().RefreshSessionIfNeeded,
-		ValidateSessionState:   opts.GetProvider().ValidateSession,
+		SessionStore:         sessionStore,
+		RefreshPeriod:        opts.Cookie.Refresh,
+		RefreshSession:       opts.GetProvider().RefreshSession,
+		IsRefreshNeeded:      opts.GetProvider().IsRefreshNeeded,
+		ValidateSessionState: opts.GetProvider().ValidateSession,
 	}))
 
 	return chain
diff --git a/pkg/apis/sessions/interfaces.go b/pkg/apis/sessions/interfaces.go
index bd02eaf2..9980d89d 100644
--- a/pkg/apis/sessions/interfaces.go
+++ b/pkg/apis/sessions/interfaces.go
@@ -2,13 +2,13 @@ package sessions
 
 import (
 	"net/http"
-	"time"
 )
 
 // SessionStore is an interface to storing user sessions in the proxy
 type SessionStore interface {
 	Save(rw http.ResponseWriter, req *http.Request, s *SessionState) error
 	Load(req *http.Request) (*SessionState, error)
-	Lock(req *http.Request, expiration time.Duration) error
+	LoadWithLock(req *http.Request) (*SessionState, error)
+	ReleaseLock(req *http.Request) error
 	Clear(rw http.ResponseWriter, req *http.Request) error
 }
diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go
index 1bd0a9a4..b5abc1fd 100644
--- a/pkg/middleware/stored_session.go
+++ b/pkg/middleware/stored_session.go
@@ -24,10 +24,13 @@ type StoredSessionLoaderOptions struct {
 	RefreshPeriod time.Duration
 
 	// Provider based sesssion refreshing
-	RefreshSessionIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error)
+	RefreshSession func(context.Context, *sessionsapi.SessionState) (bool, error)
+
+	// Provider based session refresh check
+	IsRefreshNeeded func(*sessionsapi.SessionState) bool
 
 	// Provider based session validation.
-	// If the sesssion is older than `RefreshPeriod` but the provider doesn't
+	// If the session is older than `RefreshPeriod` but the provider doesn't
 	// refresh it, we must re-validate using this validation.
 	ValidateSessionState func(context.Context, *sessionsapi.SessionState) bool
 }
@@ -38,10 +41,11 @@ type StoredSessionLoaderOptions struct {
 // If a session was loader by a previous handler, it will not be replaced.
 func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor {
 	ss := &storedSessionLoader{
-		store:                              opts.SessionStore,
-		refreshPeriod:                      opts.RefreshPeriod,
-		refreshSessionWithProviderIfNeeded: opts.RefreshSessionIfNeeded,
-		validateSessionState:               opts.ValidateSessionState,
+		store:                       opts.SessionStore,
+		refreshPeriod:               opts.RefreshPeriod,
+		refreshSessionWithProvider:  opts.RefreshSession,
+		isRefreshNeededWithProvider: opts.IsRefreshNeeded,
+		validateSessionState:        opts.ValidateSessionState,
 	}
 	return ss.loadSession
 }
@@ -49,10 +53,11 @@ func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor
 // storedSessionLoader is responsible for loading sessions from cookie
 // identified sessions in the session store.
 type storedSessionLoader struct {
-	store                              sessionsapi.SessionStore
-	refreshPeriod                      time.Duration
-	refreshSessionWithProviderIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error)
-	validateSessionState               func(context.Context, *sessionsapi.SessionState) bool
+	store                       sessionsapi.SessionStore
+	refreshPeriod               time.Duration
+	refreshSessionWithProvider  func(context.Context, *sessionsapi.SessionState) (bool, error)
+	isRefreshNeededWithProvider func(*sessionsapi.SessionState) bool
+	validateSessionState        func(context.Context, *sessionsapi.SessionState) bool
 }
 
 // loadSession attempts to load a session as identified by the request cookies.
@@ -98,44 +103,54 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h
 		return nil, nil
 	}
 
-	err = s.refreshSessionIfNeeded(rw, req, session)
+	if !s.isSessionRefreshNeeded(session) {
+		return session, nil
+	}
+
+	session, err = s.store.LoadWithLock(req)
+	if err != nil {
+		return nil, err
+	}
+	if session == nil {
+		// No session was found in the storage, nothing more to do
+		return nil, nil
+	}
+
+	if !s.isSessionRefreshNeeded(session) {
+		_ = s.store.ReleaseLock(req)
+		return session, nil
+	}
+	logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod)
+	refreshed, err := s.refreshSession(rw, req, session)
+	_ = s.store.ReleaseLock(req)
 	if err != nil {
 		return nil, fmt.Errorf("error refreshing access token for session (%s): %v", session, err)
 	}
 
+	if refreshed {
+		return session, nil
+	}
+
+	// Session wasn't refreshed, so make sure it's still valid
+	err = s.validateSession(req.Context(), session)
+	if err != nil {
+		return nil, err
+	}
 	return session, nil
 }
 
-// refreshSessionIfNeeded will attempt to refresh a session if the session
-// is older than the refresh period.
-// It is assumed that if the provider refreshes the session, the session is now
-// valid.
-// If the session requires refreshing but the provider does not refresh it,
-// we must validate the session to ensure that the returned session is still
-// valid.
-func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error {
-	if s.refreshPeriod <= time.Duration(0) || session.Age() < s.refreshPeriod {
-		// Refresh is disabled or the session is not old enough, do nothing
-		return nil
+// isSessionRefreshNeeded will check if the session need to be refreshed
+func (s *storedSessionLoader) isSessionRefreshNeeded(session *sessionsapi.SessionState) bool {
+	if s.refreshPeriod > time.Duration(0) && session.Age() >= s.refreshPeriod {
+		return s.isRefreshNeededWithProvider(session)
 	}
-
-	logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod)
-	refreshed, err := s.refreshSessionWithProvider(rw, req, session)
-	if err != nil {
-		return err
-	}
-
-	if !refreshed {
-		// Session wasn't refreshed, so make sure it's still valid
-		return s.validateSession(req.Context(), session)
-	}
-	return nil
+	return false
 }
 
-// refreshSessionWithProvider attempts to refresh the sessinon with the provider
+// refreshSession attempts to refresh the sessinon with the provider
 // and will save the session if it was updated.
-func (s *storedSessionLoader) refreshSessionWithProvider(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) (bool, error) {
-	refreshed, err := s.refreshSessionWithProviderIfNeeded(req.Context(), session)
+func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) (bool, error) {
+	refreshed, err := s.refreshSessionWithProvider(req.Context(), session)
 	if err != nil {
 		return false, fmt.Errorf("error refreshing access token: %v", err)
 	}
diff --git a/pkg/middleware/stored_session_test.go b/pkg/middleware/stored_session_test.go
index 4a8fd9da..10a612f1 100644
--- a/pkg/middleware/stored_session_test.go
+++ b/pkg/middleware/stored_session_test.go
@@ -39,6 +39,17 @@ var _ = Describe("Stored Session Suite", func() {
 			}
 		}
 
+		var defaultIsRefreshNeededFunc = func(ss *sessionsapi.SessionState) bool {
+			switch ss.RefreshToken {
+			case refresh:
+				return true
+			case noRefresh:
+				return false
+			default:
+				return false
+			}
+		}
+
 		var defaultValidateFunc = func(_ context.Context, ss *sessionsapi.SessionState) bool {
 			return ss.AccessToken != "Invalid"
 		}
@@ -86,13 +97,14 @@ var _ = Describe("Stored Session Suite", func() {
 		}
 
 		type storedSessionLoaderTableInput struct {
-			requestHeaders  http.Header
-			existingSession *sessionsapi.SessionState
-			expectedSession *sessionsapi.SessionState
-			store           sessionsapi.SessionStore
-			refreshPeriod   time.Duration
-			refreshSession  func(context.Context, *sessionsapi.SessionState) (bool, error)
-			validateSession func(context.Context, *sessionsapi.SessionState) bool
+			requestHeaders         http.Header
+			existingSession        *sessionsapi.SessionState
+			expectedSession        *sessionsapi.SessionState
+			store                  sessionsapi.SessionStore
+			refreshPeriod          time.Duration
+			refreshSession         func(context.Context, *sessionsapi.SessionState) (bool, error)
+			isRefreshSessionNeeded func(*sessionsapi.SessionState) bool
+			validateSession        func(context.Context, *sessionsapi.SessionState) bool
 		}
 
 		DescribeTable("when serving a request",
@@ -109,10 +121,11 @@ var _ = Describe("Stored Session Suite", func() {
 				rw := httptest.NewRecorder()
 
 				opts := &StoredSessionLoaderOptions{
-					SessionStore:           in.store,
-					RefreshPeriod:          in.refreshPeriod,
-					RefreshSessionIfNeeded: in.refreshSession,
-					ValidateSessionState:   in.validateSession,
+					SessionStore:         in.store,
+					RefreshPeriod:        in.refreshPeriod,
+					RefreshSession:       in.refreshSession,
+					IsRefreshNeeded:      in.isRefreshSessionNeeded,
+					ValidateSessionState: in.validateSession,
 				}
 
 				// Create the handler with a next handler that will capture the session
@@ -126,24 +139,26 @@ var _ = Describe("Stored Session Suite", func() {
 				Expect(gotSession).To(Equal(in.expectedSession))
 			},
 			Entry("with no cookie", storedSessionLoaderTableInput{
-				requestHeaders:  http.Header{},
-				existingSession: nil,
-				expectedSession: nil,
-				store:           defaultSessionStore,
-				refreshPeriod:   1 * time.Minute,
-				refreshSession:  defaultRefreshFunc,
-				validateSession: defaultValidateFunc,
+				requestHeaders:         http.Header{},
+				existingSession:        nil,
+				expectedSession:        nil,
+				store:                  defaultSessionStore,
+				refreshPeriod:          1 * time.Minute,
+				refreshSession:         defaultRefreshFunc,
+				isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
+				validateSession:        defaultValidateFunc,
 			}),
 			Entry("with an invalid cookie", storedSessionLoaderTableInput{
 				requestHeaders: http.Header{
 					"Cookie": []string{"_oauth2_proxy=NonExistent"},
 				},
-				existingSession: nil,
-				expectedSession: nil,
-				store:           defaultSessionStore,
-				refreshPeriod:   1 * time.Minute,
-				refreshSession:  defaultRefreshFunc,
-				validateSession: defaultValidateFunc,
+				existingSession:        nil,
+				expectedSession:        nil,
+				store:                  defaultSessionStore,
+				refreshPeriod:          1 * time.Minute,
+				refreshSession:         defaultRefreshFunc,
+				isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
+				validateSession:        defaultValidateFunc,
 			}),
 			Entry("with an existing session", storedSessionLoaderTableInput{
 				requestHeaders: http.Header{
@@ -155,10 +170,11 @@ var _ = Describe("Stored Session Suite", func() {
 				expectedSession: &sessionsapi.SessionState{
 					RefreshToken: "Existing",
 				},
-				store:           defaultSessionStore,
-				refreshPeriod:   1 * time.Minute,
-				refreshSession:  defaultRefreshFunc,
-				validateSession: defaultValidateFunc,
+				store:                  defaultSessionStore,
+				refreshPeriod:          1 * time.Minute,
+				refreshSession:         defaultRefreshFunc,
+				isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
+				validateSession:        defaultValidateFunc,
 			}),
 			Entry("with a session that has not expired", storedSessionLoaderTableInput{
 				requestHeaders: http.Header{
@@ -170,21 +186,23 @@ var _ = Describe("Stored Session Suite", func() {
 					CreatedAt:    &createdPast,
 					ExpiresOn:    &createdFuture,
 				},
-				store:           defaultSessionStore,
-				refreshPeriod:   1 * time.Minute,
-				refreshSession:  defaultRefreshFunc,
-				validateSession: defaultValidateFunc,
+				store:                  defaultSessionStore,
+				refreshPeriod:          1 * time.Minute,
+				refreshSession:         defaultRefreshFunc,
+				isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
+				validateSession:        defaultValidateFunc,
 			}),
 			Entry("with a session that cannot refresh and has expired", storedSessionLoaderTableInput{
 				requestHeaders: http.Header{
 					"Cookie": []string{"_oauth2_proxy=ExpiredNoRefreshSession"},
 				},
-				existingSession: nil,
-				expectedSession: nil,
-				store:           defaultSessionStore,
-				refreshPeriod:   1 * time.Minute,
-				refreshSession:  defaultRefreshFunc,
-				validateSession: defaultValidateFunc,
+				existingSession:        nil,
+				expectedSession:        nil,
+				store:                  defaultSessionStore,
+				refreshPeriod:          1 * time.Minute,
+				refreshSession:         defaultRefreshFunc,
+				isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
+				validateSession:        defaultValidateFunc,
 			}),
 			Entry("with a session that can refresh, but is younger than refresh period", storedSessionLoaderTableInput{
 				requestHeaders: http.Header{
@@ -196,10 +214,11 @@ var _ = Describe("Stored Session Suite", func() {
 					CreatedAt:    &createdPast,
 					ExpiresOn:    &createdFuture,
 				},
-				store:           defaultSessionStore,
-				refreshPeriod:   10 * time.Minute,
-				refreshSession:  defaultRefreshFunc,
-				validateSession: defaultValidateFunc,
+				store:                  defaultSessionStore,
+				refreshPeriod:          10 * time.Minute,
+				refreshSession:         defaultRefreshFunc,
+				isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
+				validateSession:        defaultValidateFunc,
 			}),
 			Entry("with a session that can refresh and is older than the refresh period", storedSessionLoaderTableInput{
 				requestHeaders: http.Header{
@@ -211,37 +230,40 @@ var _ = Describe("Stored Session Suite", func() {
 					CreatedAt:    &createdPast,
 					ExpiresOn:    &createdFuture,
 				},
-				store:           defaultSessionStore,
-				refreshPeriod:   1 * time.Minute,
-				refreshSession:  defaultRefreshFunc,
-				validateSession: defaultValidateFunc,
+				store:                  defaultSessionStore,
+				refreshPeriod:          1 * time.Minute,
+				refreshSession:         defaultRefreshFunc,
+				isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
+				validateSession:        defaultValidateFunc,
 			}),
 			Entry("when the provider refresh fails", storedSessionLoaderTableInput{
 				requestHeaders: http.Header{
 					"Cookie": []string{"_oauth2_proxy=RefreshError"},
 				},
-				existingSession: nil,
-				expectedSession: nil,
-				store:           defaultSessionStore,
-				refreshPeriod:   1 * time.Minute,
-				refreshSession:  defaultRefreshFunc,
-				validateSession: defaultValidateFunc,
+				existingSession:        nil,
+				expectedSession:        nil,
+				store:                  defaultSessionStore,
+				refreshPeriod:          1 * time.Minute,
+				refreshSession:         defaultRefreshFunc,
+				isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
+				validateSession:        defaultValidateFunc,
 			}),
 			Entry("when the session is not refreshed and is no longer valid", storedSessionLoaderTableInput{
 				requestHeaders: http.Header{
 					"Cookie": []string{"_oauth2_proxy=InvalidNoRefreshSession"},
 				},
-				existingSession: nil,
-				expectedSession: nil,
-				store:           defaultSessionStore,
-				refreshPeriod:   1 * time.Minute,
-				refreshSession:  defaultRefreshFunc,
-				validateSession: defaultValidateFunc,
+				existingSession:        nil,
+				expectedSession:        nil,
+				store:                  defaultSessionStore,
+				refreshPeriod:          1 * time.Minute,
+				refreshSession:         defaultRefreshFunc,
+				isRefreshSessionNeeded: defaultIsRefreshNeededFunc,
+				validateSession:        defaultValidateFunc,
 			}),
 		)
 	})
 
-	Context("refreshSessionIfNeeded", func() {
+	Context("isSessionRefreshNeeded", func() {
 		type refreshSessionIfNeededTableInput struct {
 			refreshPeriod   time.Duration
 			session         *sessionsapi.SessionState
@@ -261,7 +283,7 @@ var _ = Describe("Stored Session Suite", func() {
 				s := &storedSessionLoader{
 					refreshPeriod: in.refreshPeriod,
 					store:         &fakeSessionStore{},
-					refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
+					refreshSessionWithProvider: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
 						refreshed = true
 						switch ss.RefreshToken {
 						case refresh:
@@ -272,6 +294,17 @@ var _ = Describe("Stored Session Suite", func() {
 							return false, errors.New("error refreshing session")
 						}
 					},
+					isRefreshNeededWithProvider: func(ss *sessionsapi.SessionState) bool {
+						refreshed = true
+						switch ss.RefreshToken {
+						case refresh:
+							return true
+						case noRefresh:
+							return false
+						default:
+							return false
+						}
+					},
 					validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool {
 						validated = true
 						return ss.AccessToken != "Invalid"
@@ -279,7 +312,10 @@ var _ = Describe("Stored Session Suite", func() {
 				}
 
 				req := httptest.NewRequest("", "/", nil)
-				err := s.refreshSessionIfNeeded(nil, req, in.session)
+				var err error
+				if s.isSessionRefreshNeeded(in.session) {
+					refreshed, err = s.refreshSession(nil, req, in.session)
+				}
 				if in.expectedErr != nil {
 					Expect(err).To(MatchError(in.expectedErr))
 				} else {
@@ -364,7 +400,7 @@ var _ = Describe("Stored Session Suite", func() {
 		)
 	})
 
-	Context("refreshSessionWithProvider", func() {
+	Context("refreshSession", func() {
 		type refreshSessionWithProviderTableInput struct {
 			session         *sessionsapi.SessionState
 			expectedErr     error
@@ -388,7 +424,7 @@ var _ = Describe("Stored Session Suite", func() {
 							return nil
 						},
 					},
-					refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
+					refreshSessionWithProvider: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
 						switch ss.RefreshToken {
 						case refresh:
 							return true, nil
@@ -401,7 +437,7 @@ var _ = Describe("Stored Session Suite", func() {
 				}
 
 				req := httptest.NewRequest("", "/", nil)
-				refreshed, err := s.refreshSessionWithProvider(nil, req, in.session)
+				refreshed, err := s.refreshSession(nil, req, in.session)
 				if in.expectedErr != nil {
 					Expect(err).To(MatchError(in.expectedErr))
 				} else {
@@ -515,6 +551,17 @@ func (f *fakeSessionStore) Load(req *http.Request) (*sessionsapi.SessionState, e
 	return nil, nil
 }
 
+func (f *fakeSessionStore) LoadWithLock(req *http.Request) (*sessionsapi.SessionState, error) {
+	if f.LoadFunc != nil {
+		return f.LoadFunc(req)
+	}
+	return nil, nil
+}
+
+func (f *fakeSessionStore) ReleaseLock(req *http.Request) error {
+	return nil
+}
+
 func (f *fakeSessionStore) Clear(rw http.ResponseWriter, req *http.Request) error {
 	if f.ClearFunc != nil {
 		return f.ClearFunc(rw, req)
diff --git a/pkg/sessions/cookie/session_store.go b/pkg/sessions/cookie/session_store.go
index cf1a079f..4ec7920e 100644
--- a/pkg/sessions/cookie/session_store.go
+++ b/pkg/sessions/cookie/session_store.go
@@ -66,7 +66,14 @@ func (s *SessionStore) Load(req *http.Request) (*sessions.SessionState, error) {
 	return session, nil
 }
 
-func (s *SessionStore) Lock(req *http.Request, expirationTime time.Duration) error {
+// Load reads sessions.SessionState information from Cookies within the
+// HTTP request object
+func (s *SessionStore) LoadWithLock(req *http.Request) (*sessions.SessionState, error) {
+	return s.Load(req)
+}
+
+// Release the session lock
+func (s *SessionStore) ReleaseLock(req *http.Request) error {
 	return nil
 }
 
diff --git a/pkg/sessions/persistence/interfaces.go b/pkg/sessions/persistence/interfaces.go
index 12c94626..dd3aa6b1 100644
--- a/pkg/sessions/persistence/interfaces.go
+++ b/pkg/sessions/persistence/interfaces.go
@@ -11,6 +11,7 @@ import (
 type Store interface {
 	Save(context.Context, string, []byte, time.Duration) error
 	Load(context.Context, string) ([]byte, error)
-	Lock(context.Context, string, time.Duration) error
+	LoadWithLock(context.Context, string) ([]byte, error)
+	ReleaseLock(context.Context, string) error
 	Clear(context.Context, string) error
 }
diff --git a/pkg/sessions/persistence/manager.go b/pkg/sessions/persistence/manager.go
index 54e50163..9bfda742 100644
--- a/pkg/sessions/persistence/manager.go
+++ b/pkg/sessions/persistence/manager.go
@@ -65,16 +65,28 @@ func (m *Manager) Load(req *http.Request) (*sessions.SessionState, error) {
 	})
 }
 
-// Lock reads sessions.SessionState in a session store. It will
+// Load reads sessions.SessionState information from a session store. It will
 // use the session ticket from the http.Request's cookie.
-func (m *Manager) Lock(req *http.Request, expiration time.Duration) error {
+func (m *Manager) LoadWithLock(req *http.Request) (*sessions.SessionState, error) {
+	tckt, err := decodeTicketFromRequest(req, m.Options)
+	if err != nil {
+		return nil, err
+	}
+
+	return tckt.loadSession(func(key string) ([]byte, error) {
+		return m.Store.LoadWithLock(req.Context(), key)
+	})
+}
+
+// Release the session lock
+func (m *Manager) ReleaseLock(req *http.Request) error {
 	tckt, err := decodeTicketFromRequest(req, m.Options)
 	if err != nil {
 		return err
 	}
 
-	return tckt.lockSession(func(key string) error {
-		return m.Store.Lock(req.Context(), key, expiration)
+	return tckt.releaseSession(func(key string) error {
+		return m.Store.ReleaseLock(req.Context(), key)
 	})
 }
 
diff --git a/pkg/sessions/persistence/ticket.go b/pkg/sessions/persistence/ticket.go
index 26aae47b..144aac93 100644
--- a/pkg/sessions/persistence/ticket.go
+++ b/pkg/sessions/persistence/ticket.go
@@ -139,13 +139,11 @@ func (t *ticket) loadSession(loader loadFunc) (*sessions.SessionState, error) {
 	return sessions.DecodeSessionState(ciphertext, c, false)
 }
 
-// lockSession loads a session from the disk store via the passed loadFunc
-// using the ticket.id as the key. It then decodes the SessionState using
-// ticket.secret to make the AES-GCM cipher.
-func (t *ticket) lockSession(loader lockFunc) error {
+// releaseSession releases a potential locked session
+func (t *ticket) releaseSession(loader lockFunc) error {
 	err := loader(t.id)
 	if err != nil {
-		return fmt.Errorf("failed to lock the session state with the ticket: %v", err)
+		return fmt.Errorf("failed to release session state with the ticket: %v", err)
 	}
 	return nil
 }
diff --git a/pkg/sessions/redis/client.go b/pkg/sessions/redis/client.go
index 1b43761b..2510dbbb 100644
--- a/pkg/sessions/redis/client.go
+++ b/pkg/sessions/redis/client.go
@@ -2,6 +2,7 @@ package redis
 
 import (
 	"context"
+	"fmt"
 	"time"
 
 	"github.com/bsm/redislock"
@@ -12,6 +13,7 @@ import (
 type Client interface {
 	Get(ctx context.Context, key string) ([]byte, error)
 	Lock(ctx context.Context, key string, expiration time.Duration) error
+	Unlock(ctx context.Context, key string) error
 	Set(ctx context.Context, key string, value []byte, expiration time.Duration) error
 	Del(ctx context.Context, key string) error
 }
@@ -21,20 +23,21 @@ var _ Client = (*client)(nil)
 type client struct {
 	*redis.Client
 	locker *redislock.Client
-	lock   *redislock.Lock
+	locks  map[string]*redislock.Lock
 }
 
 func newClient(c *redis.Client) Client {
 	return &client{
 		Client: c,
 		locker: redislock.New(c),
+		locks:  map[string]*redislock.Lock{},
 	}
 }
 
 func (c *client) Get(ctx context.Context, key string) ([]byte, error) {
-	if c.lock != nil {
+	if c.locks[key] != nil {
 		for {
-			ttl, err := c.lock.TTL(ctx)
+			ttl, err := c.locks[key].TTL(ctx)
 			if err != nil {
 				return nil, err
 			}
@@ -47,27 +50,37 @@ func (c *client) Get(ctx context.Context, key string) ([]byte, error) {
 }
 
 func (c *client) Lock(ctx context.Context, key string, expiration time.Duration) error {
+	if c.locks[key] != nil {
+		return fmt.Errorf("locks for key %s already exists", key)
+	}
 	lock, err := c.locker.Obtain(ctx, key, expiration, nil)
 	if err != nil {
 		return err
 	}
-	c.lock = lock
+	c.locks[key] = lock
 	return nil
 }
 
+func (c *client) Unlock(ctx context.Context, key string) error {
+	if c.locks[key] == nil {
+		return nil
+	}
+	return c.locks[key].Release(ctx)
+}
+
 func (c *client) Set(ctx context.Context, key string, value []byte, expiration time.Duration) error {
 	err := c.Client.Set(ctx, key, value, expiration).Err()
 	if err != nil {
 		return err
 	}
-	if c.lock == nil {
+	if c.locks[key] == nil {
 		return nil
 	}
-	err = c.lock.Release(ctx)
+	err = c.locks[key].Release(ctx)
 	if err != nil {
 		return err
 	}
-	c.lock = nil
+	c.locks = nil
 	return nil
 }
 
@@ -80,13 +93,14 @@ var _ Client = (*clusterClient)(nil)
 type clusterClient struct {
 	*redis.ClusterClient
 	locker *redislock.Client
-	lock   *redislock.Lock
+	locks  map[string]*redislock.Lock
 }
 
 func newClusterClient(c *redis.ClusterClient) Client {
 	return &clusterClient{
 		ClusterClient: c,
 		locker:        redislock.New(c),
+		locks:         map[string]*redislock.Lock{},
 	}
 }
 
@@ -95,14 +109,24 @@ func (c *clusterClient) Get(ctx context.Context, key string) ([]byte, error) {
 }
 
 func (c *clusterClient) Lock(ctx context.Context, key string, expiration time.Duration) error {
+	if c.locks[key] != nil {
+		return fmt.Errorf("locks for key %s already exists", key)
+	}
 	lock, err := c.locker.Obtain(ctx, key, expiration, nil)
 	if err != nil {
 		return err
 	}
-	c.lock = lock
+	c.locks[key] = lock
 	return nil
 }
 
+func (c *clusterClient) Unlock(ctx context.Context, key string) error {
+	if c.locks[key] == nil {
+		return nil
+	}
+	return c.locks[key].Release(ctx)
+}
+
 func (c *clusterClient) Set(ctx context.Context, key string, value []byte, expiration time.Duration) error {
 	return c.ClusterClient.Set(ctx, key, value, expiration).Err()
 }
diff --git a/pkg/sessions/redis/redis_store.go b/pkg/sessions/redis/redis_store.go
index 9294be7b..b5fe662f 100644
--- a/pkg/sessions/redis/redis_store.go
+++ b/pkg/sessions/redis/redis_store.go
@@ -54,11 +54,25 @@ func (store *SessionStore) Load(ctx context.Context, key string) ([]byte, error)
 	return value, nil
 }
 
-// Lock sessions.SessionState information from a persistence
-func (store *SessionStore) Lock(ctx context.Context, key string, expiration time.Duration) error {
-	err := store.Client.Lock(ctx, key, expiration)
+// ReleaseLock sessions.SessionState information from a persistence
+func (store *SessionStore) LoadWithLock(ctx context.Context, key string) ([]byte, error) {
+	value, err := store.Client.Get(ctx, key)
 	if err != nil {
-		return fmt.Errorf("error setting redis lock: %v", err)
+		return nil, fmt.Errorf("error loading redis session: %v", err)
+	}
+
+	err = store.Client.Lock(ctx, key, 200*time.Millisecond)
+	if err != nil {
+		return nil, fmt.Errorf("error setting redis locks: %v", err)
+	}
+	return value, nil
+}
+
+// ReleaseLock sessions.SessionState information from a persistence
+func (store *SessionStore) ReleaseLock(ctx context.Context, key string) error {
+	err := store.Client.Unlock(ctx, key)
+	if err != nil {
+		return fmt.Errorf("error releasing redis locks: %v", err)
 	}
 	return nil
 }
diff --git a/providers/azure.go b/providers/azure.go
index 92974540..3a409f2e 100644
--- a/providers/azure.go
+++ b/providers/azure.go
@@ -160,8 +160,8 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (*
 
 // RefreshSessionIfNeeded checks if the session has expired and uses the
 // RefreshToken to fetch a new ID token if required
-func (p *AzureProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
-	if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
+func (p *AzureProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
+	if s == nil {
 		return false, nil
 	}
 
@@ -176,6 +176,11 @@ func (p *AzureProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.
 	return true, nil
 }
 
+// IsRefreshNeeded checks if the session has expired
+func (p *AzureProvider) IsRefreshNeeded(s *sessions.SessionState) bool {
+	return !(s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "")
+}
+
 func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) {
 	params := url.Values{}
 	params.Add("client_id", p.ClientID)
diff --git a/providers/azure_test.go b/providers/azure_test.go
index 9e3cabf7..ab63e478 100644
--- a/providers/azure_test.go
+++ b/providers/azure_test.go
@@ -244,7 +244,7 @@ func TestAzureProviderNotRefreshWhenNotExpired(t *testing.T) {
 
 	expires := time.Now().Add(time.Duration(1) * time.Hour)
 	session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
-	refreshNeeded, err := p.RefreshSessionIfNeeded(context.Background(), session)
+	refreshNeeded, err := p.RefreshSession(context.Background(), session)
 	assert.Equal(t, nil, err)
 	assert.False(t, refreshNeeded)
 }
@@ -258,7 +258,7 @@ func TestAzureProviderRefreshWhenExpired(t *testing.T) {
 
 	expires := time.Now().Add(time.Duration(-1) * time.Hour)
 	session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
-	_, err := p.RefreshSessionIfNeeded(context.Background(), session)
+	_, err := p.RefreshSession(context.Background(), session)
 	assert.Equal(t, nil, err)
 	assert.NotEqual(t, session, nil)
 	assert.Equal(t, "new_some_access_token", session.AccessToken)
diff --git a/providers/gitlab.go b/providers/gitlab.go
index f54430fc..370b0182 100644
--- a/providers/gitlab.go
+++ b/providers/gitlab.go
@@ -123,7 +123,7 @@ func (p *GitLabProvider) SetProjectScope() {
 
 // RefreshSessionIfNeeded checks if the session has expired and uses the
 // RefreshToken to fetch a new ID token if required
-func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
+func (p *GitLabProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
 	if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
 		return false, nil
 	}
@@ -139,6 +139,11 @@ func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions
 	return true, nil
 }
 
+// IsRefreshNeeded checks if the session has expired
+func (p *GitLabProvider) IsRefreshNeeded(s *sessions.SessionState) bool {
+	return !(s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "")
+}
+
 func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) {
 	clientSecret, err := p.GetClientSecret()
 	if err != nil {
diff --git a/providers/google.go b/providers/google.go
index b669156d..74e06114 100644
--- a/providers/google.go
+++ b/providers/google.go
@@ -268,8 +268,8 @@ func userInGroup(service *admin.Service, group string, email string) bool {
 
 // RefreshSessionIfNeeded checks if the session has expired and uses the
 // RefreshToken to fetch a new ID token if required
-func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
-	if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
+func (p *GoogleProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
+	if s == nil {
 		return false, nil
 	}
 
@@ -295,6 +295,11 @@ func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions
 	return true, nil
 }
 
+// IsRefreshNeeded checks if the session has expired
+func (p *GoogleProvider) IsRefreshNeeded(s *sessions.SessionState) bool {
+	return !(s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "")
+}
+
 func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken string) (token string, idToken string, expires time.Duration, err error) {
 	// https://developers.google.com/identity/protocols/OAuth2WebServer#refresh
 	clientSecret, err := p.GetClientSecret()
diff --git a/providers/oidc.go b/providers/oidc.go
index df133f4d..20f62e95 100644
--- a/providers/oidc.go
+++ b/providers/oidc.go
@@ -115,8 +115,8 @@ func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionS
 
 // RefreshSessionIfNeeded checks if the session has expired and uses the
 // RefreshToken to fetch a new Access Token (and optional ID token) if required
-func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
-	if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
+func (p *OIDCProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
+	if s == nil {
 		return false, nil
 	}
 
@@ -129,6 +129,11 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.S
 	return true, nil
 }
 
+// IsRefreshNeeded checks if the session has expired
+func (p *OIDCProvider) IsRefreshNeeded(s *sessions.SessionState) bool {
+	return !(s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "")
+}
+
 // redeemRefreshToken uses a RefreshToken with the RedeemURL to refresh the
 // Access Token and (probably) the ID Token.
 func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error {
diff --git a/providers/oidc_test.go b/providers/oidc_test.go
index 7ac98634..84e19a3d 100644
--- a/providers/oidc_test.go
+++ b/providers/oidc_test.go
@@ -467,7 +467,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) {
 		User:         "11223344",
 	}
 
-	refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession)
+	refreshed, err := provider.RefreshSession(context.Background(), existingSession)
 	assert.Equal(t, nil, err)
 	assert.Equal(t, refreshed, true)
 	assert.Equal(t, "janedoe@example.com", existingSession.Email)
@@ -500,7 +500,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) {
 		Email:        "changeit",
 		User:         "changeit",
 	}
-	refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession)
+	refreshed, err := provider.RefreshSession(context.Background(), existingSession)
 	assert.Equal(t, nil, err)
 	assert.Equal(t, refreshed, true)
 	assert.Equal(t, defaultIDToken.Email, existingSession.Email)
diff --git a/providers/provider_default.go b/providers/provider_default.go
index d3c6d113..9ce6e511 100644
--- a/providers/provider_default.go
+++ b/providers/provider_default.go
@@ -128,10 +128,15 @@ func (p *ProviderData) ValidateSession(ctx context.Context, s *sessions.SessionS
 
 // RefreshSessionIfNeeded should refresh the user's session if required and
 // do nothing if a refresh is not required
-func (p *ProviderData) RefreshSessionIfNeeded(_ context.Context, _ *sessions.SessionState) (bool, error) {
+func (p *ProviderData) RefreshSession(_ context.Context, _ *sessions.SessionState) (bool, error) {
 	return false, nil
 }
 
+// IsRefreshNeeded should return true if the user's session need to be refreshed
+func (p *ProviderData) IsRefreshNeeded(_ *sessions.SessionState) bool {
+	return false
+}
+
 // CreateSessionFromToken converts Bearer IDTokens into sessions
 func (p *ProviderData) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) {
 	if p.Verifier != nil {
diff --git a/providers/provider_default_test.go b/providers/provider_default_test.go
index df1525cf..e3e65d3b 100644
--- a/providers/provider_default_test.go
+++ b/providers/provider_default_test.go
@@ -15,7 +15,7 @@ func TestRefresh(t *testing.T) {
 	p := &ProviderData{}
 
 	expires := time.Now().Add(time.Duration(-11) * time.Minute)
-	refreshed, err := p.RefreshSessionIfNeeded(context.Background(), &sessions.SessionState{
+	refreshed, err := p.RefreshSession(context.Background(), &sessions.SessionState{
 		ExpiresOn: &expires,
 	})
 	assert.Equal(t, false, refreshed)
diff --git a/providers/providers.go b/providers/providers.go
index 6aeb5426..0d1609d9 100644
--- a/providers/providers.go
+++ b/providers/providers.go
@@ -16,7 +16,8 @@ type Provider interface {
 	Authorize(ctx context.Context, s *sessions.SessionState) (bool, error)
 	ValidateSession(ctx context.Context, s *sessions.SessionState) bool
 	GetLoginURL(redirectURI, finalRedirect string) string
-	RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error)
+	RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error)
+	IsRefreshNeeded(s *sessions.SessionState) bool
 	CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error)
 }