diff --git a/go.mod b/go.mod index 7411accd..2a81284e 100644 --- a/go.mod +++ b/go.mod @@ -22,6 +22,7 @@ require ( github.com/onsi/gomega v1.10.2 github.com/pierrec/lz4 v2.5.2+incompatible github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect + github.com/prometheus/client_golang v0.9.3 github.com/spf13/pflag v1.0.5 github.com/spf13/viper v1.6.3 github.com/stretchr/testify v1.6.1 diff --git a/go.sum b/go.sum index 7c0170dd..9e3b52dc 100644 --- a/go.sum +++ b/go.sum @@ -24,6 +24,7 @@ github.com/alicebob/miniredis/v2 v2.13.0 h1:QPosMaxm+r6Qs+YcCtL2Z2a2RSdC9VfXJLpd github.com/alicebob/miniredis/v2 v2.13.0/go.mod h1:0UIBNuf97uxrWhdVBpJvPtafKyGpL2NS2pYe0tYM97k= github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= +github.com/beorn7/perks v1.0.0 h1:HWo1m869IqiPhD389kmkxeTalrjNbbJTC8LXupb+sl0= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/bitly/go-simplejson v0.5.0 h1:6IH+V8/tVMab511d5bn4M7EwGXZf9Hj6i2xSwkNEM+Y= github.com/bitly/go-simplejson v0.5.0/go.mod h1:cXHtHw4XUPsvGaxgjIAn8PhEWG9NfngEKAMDJEczWVA= @@ -168,6 +169,7 @@ github.com/magiconair/properties v1.8.1/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czP github.com/mailru/easyjson v0.0.0-20160728113105-d5b7844b561a/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/matryer/is v1.2.0 h1:92UTHpy8CDwaJ08GqLDzhhuixiBUUD1p3AU6PHddz4A= github.com/matryer/is v1.2.0/go.mod h1:2fLPjFQM9rhQ15aVEtbuwhJinnOqrmgXPNdZsdwlWXA= +github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/mbland/hmacauth v0.0.0-20170912233209-44256dfd4bfa h1:hI1uC2A3vJFjwvBn0G0a7QBRdBUp6Y048BtLAHRTKPo= github.com/mbland/hmacauth v0.0.0-20170912233209-44256dfd4bfa/go.mod h1:8vxFeeg++MqgCHwehSuwTlYCF0ALyDJbYJ1JsKi7v6s= @@ -208,13 +210,17 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU= github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= +github.com/prometheus/client_golang v0.9.3 h1:9iH4JKXLzFbOAdtqv/a+j8aewx2Y8lAjAydhbaScPF8= github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4 h1:gQz4mCbXsO+nc9n1hCxHcGA3Zx3Eo+UHZoInFGUIXNM= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/common v0.0.0-20181113130724-41aa239b4cce/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= +github.com/prometheus/common v0.4.0 h1:7etb9YClo3a6HjLzfl6rIQaU+FDfi0VSX39io3aQ+DM= github.com/prometheus/common v0.4.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084 h1:sofwID9zm4tzrgykg80hfFph1mryUeLRsUfoocVVmRY= github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= diff --git a/oauthproxy.go b/oauthproxy.go index 0cfa1f93..68aba61a 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -276,10 +276,11 @@ func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionSt } chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{ - SessionStore: sessionStore, - RefreshPeriod: opts.Cookie.Refresh, - RefreshSessionIfNeeded: opts.GetProvider().RefreshSessionIfNeeded, - ValidateSessionState: opts.GetProvider().ValidateSession, + SessionStore: sessionStore, + RefreshPeriod: opts.Cookie.Refresh, + RefreshSession: opts.GetProvider().RefreshSession, + IsRefreshNeeded: opts.GetProvider().IsRefreshNeeded, + ValidateSessionState: opts.GetProvider().ValidateSession, })) return chain diff --git a/pkg/apis/sessions/interfaces.go b/pkg/apis/sessions/interfaces.go index bd02eaf2..9980d89d 100644 --- a/pkg/apis/sessions/interfaces.go +++ b/pkg/apis/sessions/interfaces.go @@ -2,13 +2,13 @@ package sessions import ( "net/http" - "time" ) // SessionStore is an interface to storing user sessions in the proxy type SessionStore interface { Save(rw http.ResponseWriter, req *http.Request, s *SessionState) error Load(req *http.Request) (*SessionState, error) - Lock(req *http.Request, expiration time.Duration) error + LoadWithLock(req *http.Request) (*SessionState, error) + ReleaseLock(req *http.Request) error Clear(rw http.ResponseWriter, req *http.Request) error } diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go index 1bd0a9a4..b5abc1fd 100644 --- a/pkg/middleware/stored_session.go +++ b/pkg/middleware/stored_session.go @@ -24,10 +24,13 @@ type StoredSessionLoaderOptions struct { RefreshPeriod time.Duration // Provider based sesssion refreshing - RefreshSessionIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error) + RefreshSession func(context.Context, *sessionsapi.SessionState) (bool, error) + + // Provider based session refresh check + IsRefreshNeeded func(*sessionsapi.SessionState) bool // Provider based session validation. - // If the sesssion is older than `RefreshPeriod` but the provider doesn't + // If the session is older than `RefreshPeriod` but the provider doesn't // refresh it, we must re-validate using this validation. ValidateSessionState func(context.Context, *sessionsapi.SessionState) bool } @@ -38,10 +41,11 @@ type StoredSessionLoaderOptions struct { // If a session was loader by a previous handler, it will not be replaced. func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor { ss := &storedSessionLoader{ - store: opts.SessionStore, - refreshPeriod: opts.RefreshPeriod, - refreshSessionWithProviderIfNeeded: opts.RefreshSessionIfNeeded, - validateSessionState: opts.ValidateSessionState, + store: opts.SessionStore, + refreshPeriod: opts.RefreshPeriod, + refreshSessionWithProvider: opts.RefreshSession, + isRefreshNeededWithProvider: opts.IsRefreshNeeded, + validateSessionState: opts.ValidateSessionState, } return ss.loadSession } @@ -49,10 +53,11 @@ func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor // storedSessionLoader is responsible for loading sessions from cookie // identified sessions in the session store. type storedSessionLoader struct { - store sessionsapi.SessionStore - refreshPeriod time.Duration - refreshSessionWithProviderIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error) - validateSessionState func(context.Context, *sessionsapi.SessionState) bool + store sessionsapi.SessionStore + refreshPeriod time.Duration + refreshSessionWithProvider func(context.Context, *sessionsapi.SessionState) (bool, error) + isRefreshNeededWithProvider func(*sessionsapi.SessionState) bool + validateSessionState func(context.Context, *sessionsapi.SessionState) bool } // loadSession attempts to load a session as identified by the request cookies. @@ -98,44 +103,54 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h return nil, nil } - err = s.refreshSessionIfNeeded(rw, req, session) + if !s.isSessionRefreshNeeded(session) { + return session, nil + } + + session, err = s.store.LoadWithLock(req) + if err != nil { + return nil, err + } + if session == nil { + // No session was found in the storage, nothing more to do + return nil, nil + } + + if !s.isSessionRefreshNeeded(session) { + _ = s.store.ReleaseLock(req) + return session, nil + } + logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod) + refreshed, err := s.refreshSession(rw, req, session) + _ = s.store.ReleaseLock(req) if err != nil { return nil, fmt.Errorf("error refreshing access token for session (%s): %v", session, err) } + if refreshed { + return session, nil + } + + // Session wasn't refreshed, so make sure it's still valid + err = s.validateSession(req.Context(), session) + if err != nil { + return nil, err + } return session, nil } -// refreshSessionIfNeeded will attempt to refresh a session if the session -// is older than the refresh period. -// It is assumed that if the provider refreshes the session, the session is now -// valid. -// If the session requires refreshing but the provider does not refresh it, -// we must validate the session to ensure that the returned session is still -// valid. -func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { - if s.refreshPeriod <= time.Duration(0) || session.Age() < s.refreshPeriod { - // Refresh is disabled or the session is not old enough, do nothing - return nil +// isSessionRefreshNeeded will check if the session need to be refreshed +func (s *storedSessionLoader) isSessionRefreshNeeded(session *sessionsapi.SessionState) bool { + if s.refreshPeriod > time.Duration(0) && session.Age() >= s.refreshPeriod { + return s.isRefreshNeededWithProvider(session) } - - logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod) - refreshed, err := s.refreshSessionWithProvider(rw, req, session) - if err != nil { - return err - } - - if !refreshed { - // Session wasn't refreshed, so make sure it's still valid - return s.validateSession(req.Context(), session) - } - return nil + return false } -// refreshSessionWithProvider attempts to refresh the sessinon with the provider +// refreshSession attempts to refresh the sessinon with the provider // and will save the session if it was updated. -func (s *storedSessionLoader) refreshSessionWithProvider(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) (bool, error) { - refreshed, err := s.refreshSessionWithProviderIfNeeded(req.Context(), session) +func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) (bool, error) { + refreshed, err := s.refreshSessionWithProvider(req.Context(), session) if err != nil { return false, fmt.Errorf("error refreshing access token: %v", err) } diff --git a/pkg/middleware/stored_session_test.go b/pkg/middleware/stored_session_test.go index 4a8fd9da..10a612f1 100644 --- a/pkg/middleware/stored_session_test.go +++ b/pkg/middleware/stored_session_test.go @@ -39,6 +39,17 @@ var _ = Describe("Stored Session Suite", func() { } } + var defaultIsRefreshNeededFunc = func(ss *sessionsapi.SessionState) bool { + switch ss.RefreshToken { + case refresh: + return true + case noRefresh: + return false + default: + return false + } + } + var defaultValidateFunc = func(_ context.Context, ss *sessionsapi.SessionState) bool { return ss.AccessToken != "Invalid" } @@ -86,13 +97,14 @@ var _ = Describe("Stored Session Suite", func() { } type storedSessionLoaderTableInput struct { - requestHeaders http.Header - existingSession *sessionsapi.SessionState - expectedSession *sessionsapi.SessionState - store sessionsapi.SessionStore - refreshPeriod time.Duration - refreshSession func(context.Context, *sessionsapi.SessionState) (bool, error) - validateSession func(context.Context, *sessionsapi.SessionState) bool + requestHeaders http.Header + existingSession *sessionsapi.SessionState + expectedSession *sessionsapi.SessionState + store sessionsapi.SessionStore + refreshPeriod time.Duration + refreshSession func(context.Context, *sessionsapi.SessionState) (bool, error) + isRefreshSessionNeeded func(*sessionsapi.SessionState) bool + validateSession func(context.Context, *sessionsapi.SessionState) bool } DescribeTable("when serving a request", @@ -109,10 +121,11 @@ var _ = Describe("Stored Session Suite", func() { rw := httptest.NewRecorder() opts := &StoredSessionLoaderOptions{ - SessionStore: in.store, - RefreshPeriod: in.refreshPeriod, - RefreshSessionIfNeeded: in.refreshSession, - ValidateSessionState: in.validateSession, + SessionStore: in.store, + RefreshPeriod: in.refreshPeriod, + RefreshSession: in.refreshSession, + IsRefreshNeeded: in.isRefreshSessionNeeded, + ValidateSessionState: in.validateSession, } // Create the handler with a next handler that will capture the session @@ -126,24 +139,26 @@ var _ = Describe("Stored Session Suite", func() { Expect(gotSession).To(Equal(in.expectedSession)) }, Entry("with no cookie", storedSessionLoaderTableInput{ - requestHeaders: http.Header{}, - existingSession: nil, - expectedSession: nil, - store: defaultSessionStore, - refreshPeriod: 1 * time.Minute, - refreshSession: defaultRefreshFunc, - validateSession: defaultValidateFunc, + requestHeaders: http.Header{}, + existingSession: nil, + expectedSession: nil, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + isRefreshSessionNeeded: defaultIsRefreshNeededFunc, + validateSession: defaultValidateFunc, }), Entry("with an invalid cookie", storedSessionLoaderTableInput{ requestHeaders: http.Header{ "Cookie": []string{"_oauth2_proxy=NonExistent"}, }, - existingSession: nil, - expectedSession: nil, - store: defaultSessionStore, - refreshPeriod: 1 * time.Minute, - refreshSession: defaultRefreshFunc, - validateSession: defaultValidateFunc, + existingSession: nil, + expectedSession: nil, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + isRefreshSessionNeeded: defaultIsRefreshNeededFunc, + validateSession: defaultValidateFunc, }), Entry("with an existing session", storedSessionLoaderTableInput{ requestHeaders: http.Header{ @@ -155,10 +170,11 @@ var _ = Describe("Stored Session Suite", func() { expectedSession: &sessionsapi.SessionState{ RefreshToken: "Existing", }, - store: defaultSessionStore, - refreshPeriod: 1 * time.Minute, - refreshSession: defaultRefreshFunc, - validateSession: defaultValidateFunc, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + isRefreshSessionNeeded: defaultIsRefreshNeededFunc, + validateSession: defaultValidateFunc, }), Entry("with a session that has not expired", storedSessionLoaderTableInput{ requestHeaders: http.Header{ @@ -170,21 +186,23 @@ var _ = Describe("Stored Session Suite", func() { CreatedAt: &createdPast, ExpiresOn: &createdFuture, }, - store: defaultSessionStore, - refreshPeriod: 1 * time.Minute, - refreshSession: defaultRefreshFunc, - validateSession: defaultValidateFunc, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + isRefreshSessionNeeded: defaultIsRefreshNeededFunc, + validateSession: defaultValidateFunc, }), Entry("with a session that cannot refresh and has expired", storedSessionLoaderTableInput{ requestHeaders: http.Header{ "Cookie": []string{"_oauth2_proxy=ExpiredNoRefreshSession"}, }, - existingSession: nil, - expectedSession: nil, - store: defaultSessionStore, - refreshPeriod: 1 * time.Minute, - refreshSession: defaultRefreshFunc, - validateSession: defaultValidateFunc, + existingSession: nil, + expectedSession: nil, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + isRefreshSessionNeeded: defaultIsRefreshNeededFunc, + validateSession: defaultValidateFunc, }), Entry("with a session that can refresh, but is younger than refresh period", storedSessionLoaderTableInput{ requestHeaders: http.Header{ @@ -196,10 +214,11 @@ var _ = Describe("Stored Session Suite", func() { CreatedAt: &createdPast, ExpiresOn: &createdFuture, }, - store: defaultSessionStore, - refreshPeriod: 10 * time.Minute, - refreshSession: defaultRefreshFunc, - validateSession: defaultValidateFunc, + store: defaultSessionStore, + refreshPeriod: 10 * time.Minute, + refreshSession: defaultRefreshFunc, + isRefreshSessionNeeded: defaultIsRefreshNeededFunc, + validateSession: defaultValidateFunc, }), Entry("with a session that can refresh and is older than the refresh period", storedSessionLoaderTableInput{ requestHeaders: http.Header{ @@ -211,37 +230,40 @@ var _ = Describe("Stored Session Suite", func() { CreatedAt: &createdPast, ExpiresOn: &createdFuture, }, - store: defaultSessionStore, - refreshPeriod: 1 * time.Minute, - refreshSession: defaultRefreshFunc, - validateSession: defaultValidateFunc, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + isRefreshSessionNeeded: defaultIsRefreshNeededFunc, + validateSession: defaultValidateFunc, }), Entry("when the provider refresh fails", storedSessionLoaderTableInput{ requestHeaders: http.Header{ "Cookie": []string{"_oauth2_proxy=RefreshError"}, }, - existingSession: nil, - expectedSession: nil, - store: defaultSessionStore, - refreshPeriod: 1 * time.Minute, - refreshSession: defaultRefreshFunc, - validateSession: defaultValidateFunc, + existingSession: nil, + expectedSession: nil, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + isRefreshSessionNeeded: defaultIsRefreshNeededFunc, + validateSession: defaultValidateFunc, }), Entry("when the session is not refreshed and is no longer valid", storedSessionLoaderTableInput{ requestHeaders: http.Header{ "Cookie": []string{"_oauth2_proxy=InvalidNoRefreshSession"}, }, - existingSession: nil, - expectedSession: nil, - store: defaultSessionStore, - refreshPeriod: 1 * time.Minute, - refreshSession: defaultRefreshFunc, - validateSession: defaultValidateFunc, + existingSession: nil, + expectedSession: nil, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + isRefreshSessionNeeded: defaultIsRefreshNeededFunc, + validateSession: defaultValidateFunc, }), ) }) - Context("refreshSessionIfNeeded", func() { + Context("isSessionRefreshNeeded", func() { type refreshSessionIfNeededTableInput struct { refreshPeriod time.Duration session *sessionsapi.SessionState @@ -261,7 +283,7 @@ var _ = Describe("Stored Session Suite", func() { s := &storedSessionLoader{ refreshPeriod: in.refreshPeriod, store: &fakeSessionStore{}, - refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { + refreshSessionWithProvider: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { refreshed = true switch ss.RefreshToken { case refresh: @@ -272,6 +294,17 @@ var _ = Describe("Stored Session Suite", func() { return false, errors.New("error refreshing session") } }, + isRefreshNeededWithProvider: func(ss *sessionsapi.SessionState) bool { + refreshed = true + switch ss.RefreshToken { + case refresh: + return true + case noRefresh: + return false + default: + return false + } + }, validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool { validated = true return ss.AccessToken != "Invalid" @@ -279,7 +312,10 @@ var _ = Describe("Stored Session Suite", func() { } req := httptest.NewRequest("", "/", nil) - err := s.refreshSessionIfNeeded(nil, req, in.session) + var err error + if s.isSessionRefreshNeeded(in.session) { + refreshed, err = s.refreshSession(nil, req, in.session) + } if in.expectedErr != nil { Expect(err).To(MatchError(in.expectedErr)) } else { @@ -364,7 +400,7 @@ var _ = Describe("Stored Session Suite", func() { ) }) - Context("refreshSessionWithProvider", func() { + Context("refreshSession", func() { type refreshSessionWithProviderTableInput struct { session *sessionsapi.SessionState expectedErr error @@ -388,7 +424,7 @@ var _ = Describe("Stored Session Suite", func() { return nil }, }, - refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { + refreshSessionWithProvider: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { switch ss.RefreshToken { case refresh: return true, nil @@ -401,7 +437,7 @@ var _ = Describe("Stored Session Suite", func() { } req := httptest.NewRequest("", "/", nil) - refreshed, err := s.refreshSessionWithProvider(nil, req, in.session) + refreshed, err := s.refreshSession(nil, req, in.session) if in.expectedErr != nil { Expect(err).To(MatchError(in.expectedErr)) } else { @@ -515,6 +551,17 @@ func (f *fakeSessionStore) Load(req *http.Request) (*sessionsapi.SessionState, e return nil, nil } +func (f *fakeSessionStore) LoadWithLock(req *http.Request) (*sessionsapi.SessionState, error) { + if f.LoadFunc != nil { + return f.LoadFunc(req) + } + return nil, nil +} + +func (f *fakeSessionStore) ReleaseLock(req *http.Request) error { + return nil +} + func (f *fakeSessionStore) Clear(rw http.ResponseWriter, req *http.Request) error { if f.ClearFunc != nil { return f.ClearFunc(rw, req) diff --git a/pkg/sessions/cookie/session_store.go b/pkg/sessions/cookie/session_store.go index cf1a079f..4ec7920e 100644 --- a/pkg/sessions/cookie/session_store.go +++ b/pkg/sessions/cookie/session_store.go @@ -66,7 +66,14 @@ func (s *SessionStore) Load(req *http.Request) (*sessions.SessionState, error) { return session, nil } -func (s *SessionStore) Lock(req *http.Request, expirationTime time.Duration) error { +// Load reads sessions.SessionState information from Cookies within the +// HTTP request object +func (s *SessionStore) LoadWithLock(req *http.Request) (*sessions.SessionState, error) { + return s.Load(req) +} + +// Release the session lock +func (s *SessionStore) ReleaseLock(req *http.Request) error { return nil } diff --git a/pkg/sessions/persistence/interfaces.go b/pkg/sessions/persistence/interfaces.go index 12c94626..dd3aa6b1 100644 --- a/pkg/sessions/persistence/interfaces.go +++ b/pkg/sessions/persistence/interfaces.go @@ -11,6 +11,7 @@ import ( type Store interface { Save(context.Context, string, []byte, time.Duration) error Load(context.Context, string) ([]byte, error) - Lock(context.Context, string, time.Duration) error + LoadWithLock(context.Context, string) ([]byte, error) + ReleaseLock(context.Context, string) error Clear(context.Context, string) error } diff --git a/pkg/sessions/persistence/manager.go b/pkg/sessions/persistence/manager.go index 54e50163..9bfda742 100644 --- a/pkg/sessions/persistence/manager.go +++ b/pkg/sessions/persistence/manager.go @@ -65,16 +65,28 @@ func (m *Manager) Load(req *http.Request) (*sessions.SessionState, error) { }) } -// Lock reads sessions.SessionState in a session store. It will +// Load reads sessions.SessionState information from a session store. It will // use the session ticket from the http.Request's cookie. -func (m *Manager) Lock(req *http.Request, expiration time.Duration) error { +func (m *Manager) LoadWithLock(req *http.Request) (*sessions.SessionState, error) { + tckt, err := decodeTicketFromRequest(req, m.Options) + if err != nil { + return nil, err + } + + return tckt.loadSession(func(key string) ([]byte, error) { + return m.Store.LoadWithLock(req.Context(), key) + }) +} + +// Release the session lock +func (m *Manager) ReleaseLock(req *http.Request) error { tckt, err := decodeTicketFromRequest(req, m.Options) if err != nil { return err } - return tckt.lockSession(func(key string) error { - return m.Store.Lock(req.Context(), key, expiration) + return tckt.releaseSession(func(key string) error { + return m.Store.ReleaseLock(req.Context(), key) }) } diff --git a/pkg/sessions/persistence/ticket.go b/pkg/sessions/persistence/ticket.go index 26aae47b..144aac93 100644 --- a/pkg/sessions/persistence/ticket.go +++ b/pkg/sessions/persistence/ticket.go @@ -139,13 +139,11 @@ func (t *ticket) loadSession(loader loadFunc) (*sessions.SessionState, error) { return sessions.DecodeSessionState(ciphertext, c, false) } -// lockSession loads a session from the disk store via the passed loadFunc -// using the ticket.id as the key. It then decodes the SessionState using -// ticket.secret to make the AES-GCM cipher. -func (t *ticket) lockSession(loader lockFunc) error { +// releaseSession releases a potential locked session +func (t *ticket) releaseSession(loader lockFunc) error { err := loader(t.id) if err != nil { - return fmt.Errorf("failed to lock the session state with the ticket: %v", err) + return fmt.Errorf("failed to release session state with the ticket: %v", err) } return nil } diff --git a/pkg/sessions/redis/client.go b/pkg/sessions/redis/client.go index 1b43761b..2510dbbb 100644 --- a/pkg/sessions/redis/client.go +++ b/pkg/sessions/redis/client.go @@ -2,6 +2,7 @@ package redis import ( "context" + "fmt" "time" "github.com/bsm/redislock" @@ -12,6 +13,7 @@ import ( type Client interface { Get(ctx context.Context, key string) ([]byte, error) Lock(ctx context.Context, key string, expiration time.Duration) error + Unlock(ctx context.Context, key string) error Set(ctx context.Context, key string, value []byte, expiration time.Duration) error Del(ctx context.Context, key string) error } @@ -21,20 +23,21 @@ var _ Client = (*client)(nil) type client struct { *redis.Client locker *redislock.Client - lock *redislock.Lock + locks map[string]*redislock.Lock } func newClient(c *redis.Client) Client { return &client{ Client: c, locker: redislock.New(c), + locks: map[string]*redislock.Lock{}, } } func (c *client) Get(ctx context.Context, key string) ([]byte, error) { - if c.lock != nil { + if c.locks[key] != nil { for { - ttl, err := c.lock.TTL(ctx) + ttl, err := c.locks[key].TTL(ctx) if err != nil { return nil, err } @@ -47,27 +50,37 @@ func (c *client) Get(ctx context.Context, key string) ([]byte, error) { } func (c *client) Lock(ctx context.Context, key string, expiration time.Duration) error { + if c.locks[key] != nil { + return fmt.Errorf("locks for key %s already exists", key) + } lock, err := c.locker.Obtain(ctx, key, expiration, nil) if err != nil { return err } - c.lock = lock + c.locks[key] = lock return nil } +func (c *client) Unlock(ctx context.Context, key string) error { + if c.locks[key] == nil { + return nil + } + return c.locks[key].Release(ctx) +} + func (c *client) Set(ctx context.Context, key string, value []byte, expiration time.Duration) error { err := c.Client.Set(ctx, key, value, expiration).Err() if err != nil { return err } - if c.lock == nil { + if c.locks[key] == nil { return nil } - err = c.lock.Release(ctx) + err = c.locks[key].Release(ctx) if err != nil { return err } - c.lock = nil + c.locks = nil return nil } @@ -80,13 +93,14 @@ var _ Client = (*clusterClient)(nil) type clusterClient struct { *redis.ClusterClient locker *redislock.Client - lock *redislock.Lock + locks map[string]*redislock.Lock } func newClusterClient(c *redis.ClusterClient) Client { return &clusterClient{ ClusterClient: c, locker: redislock.New(c), + locks: map[string]*redislock.Lock{}, } } @@ -95,14 +109,24 @@ func (c *clusterClient) Get(ctx context.Context, key string) ([]byte, error) { } func (c *clusterClient) Lock(ctx context.Context, key string, expiration time.Duration) error { + if c.locks[key] != nil { + return fmt.Errorf("locks for key %s already exists", key) + } lock, err := c.locker.Obtain(ctx, key, expiration, nil) if err != nil { return err } - c.lock = lock + c.locks[key] = lock return nil } +func (c *clusterClient) Unlock(ctx context.Context, key string) error { + if c.locks[key] == nil { + return nil + } + return c.locks[key].Release(ctx) +} + func (c *clusterClient) Set(ctx context.Context, key string, value []byte, expiration time.Duration) error { return c.ClusterClient.Set(ctx, key, value, expiration).Err() } diff --git a/pkg/sessions/redis/redis_store.go b/pkg/sessions/redis/redis_store.go index 9294be7b..b5fe662f 100644 --- a/pkg/sessions/redis/redis_store.go +++ b/pkg/sessions/redis/redis_store.go @@ -54,11 +54,25 @@ func (store *SessionStore) Load(ctx context.Context, key string) ([]byte, error) return value, nil } -// Lock sessions.SessionState information from a persistence -func (store *SessionStore) Lock(ctx context.Context, key string, expiration time.Duration) error { - err := store.Client.Lock(ctx, key, expiration) +// ReleaseLock sessions.SessionState information from a persistence +func (store *SessionStore) LoadWithLock(ctx context.Context, key string) ([]byte, error) { + value, err := store.Client.Get(ctx, key) if err != nil { - return fmt.Errorf("error setting redis lock: %v", err) + return nil, fmt.Errorf("error loading redis session: %v", err) + } + + err = store.Client.Lock(ctx, key, 200*time.Millisecond) + if err != nil { + return nil, fmt.Errorf("error setting redis locks: %v", err) + } + return value, nil +} + +// ReleaseLock sessions.SessionState information from a persistence +func (store *SessionStore) ReleaseLock(ctx context.Context, key string) error { + err := store.Client.Unlock(ctx, key) + if err != nil { + return fmt.Errorf("error releasing redis locks: %v", err) } return nil } diff --git a/providers/azure.go b/providers/azure.go index 92974540..3a409f2e 100644 --- a/providers/azure.go +++ b/providers/azure.go @@ -160,8 +160,8 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (* // RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshToken to fetch a new ID token if required -func (p *AzureProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { - if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { +func (p *AzureProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { + if s == nil { return false, nil } @@ -176,6 +176,11 @@ func (p *AzureProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions. return true, nil } +// IsRefreshNeeded checks if the session has expired +func (p *AzureProvider) IsRefreshNeeded(s *sessions.SessionState) bool { + return !(s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "") +} + func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) { params := url.Values{} params.Add("client_id", p.ClientID) diff --git a/providers/azure_test.go b/providers/azure_test.go index 9e3cabf7..ab63e478 100644 --- a/providers/azure_test.go +++ b/providers/azure_test.go @@ -244,7 +244,7 @@ func TestAzureProviderNotRefreshWhenNotExpired(t *testing.T) { expires := time.Now().Add(time.Duration(1) * time.Hour) session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires} - refreshNeeded, err := p.RefreshSessionIfNeeded(context.Background(), session) + refreshNeeded, err := p.RefreshSession(context.Background(), session) assert.Equal(t, nil, err) assert.False(t, refreshNeeded) } @@ -258,7 +258,7 @@ func TestAzureProviderRefreshWhenExpired(t *testing.T) { expires := time.Now().Add(time.Duration(-1) * time.Hour) session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires} - _, err := p.RefreshSessionIfNeeded(context.Background(), session) + _, err := p.RefreshSession(context.Background(), session) assert.Equal(t, nil, err) assert.NotEqual(t, session, nil) assert.Equal(t, "new_some_access_token", session.AccessToken) diff --git a/providers/gitlab.go b/providers/gitlab.go index f54430fc..370b0182 100644 --- a/providers/gitlab.go +++ b/providers/gitlab.go @@ -123,7 +123,7 @@ func (p *GitLabProvider) SetProjectScope() { // RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshToken to fetch a new ID token if required -func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { +func (p *GitLabProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" { return false, nil } @@ -139,6 +139,11 @@ func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions return true, nil } +// IsRefreshNeeded checks if the session has expired +func (p *GitLabProvider) IsRefreshNeeded(s *sessions.SessionState) bool { + return !(s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "") +} + func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) { clientSecret, err := p.GetClientSecret() if err != nil { diff --git a/providers/google.go b/providers/google.go index b669156d..74e06114 100644 --- a/providers/google.go +++ b/providers/google.go @@ -268,8 +268,8 @@ func userInGroup(service *admin.Service, group string, email string) bool { // RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshToken to fetch a new ID token if required -func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { - if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" { +func (p *GoogleProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { + if s == nil { return false, nil } @@ -295,6 +295,11 @@ func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions return true, nil } +// IsRefreshNeeded checks if the session has expired +func (p *GoogleProvider) IsRefreshNeeded(s *sessions.SessionState) bool { + return !(s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "") +} + func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken string) (token string, idToken string, expires time.Duration, err error) { // https://developers.google.com/identity/protocols/OAuth2WebServer#refresh clientSecret, err := p.GetClientSecret() diff --git a/providers/oidc.go b/providers/oidc.go index df133f4d..20f62e95 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -115,8 +115,8 @@ func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionS // RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshToken to fetch a new Access Token (and optional ID token) if required -func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { - if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" { +func (p *OIDCProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { + if s == nil { return false, nil } @@ -129,6 +129,11 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.S return true, nil } +// IsRefreshNeeded checks if the session has expired +func (p *OIDCProvider) IsRefreshNeeded(s *sessions.SessionState) bool { + return !(s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "") +} + // redeemRefreshToken uses a RefreshToken with the RedeemURL to refresh the // Access Token and (probably) the ID Token. func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error { diff --git a/providers/oidc_test.go b/providers/oidc_test.go index 7ac98634..84e19a3d 100644 --- a/providers/oidc_test.go +++ b/providers/oidc_test.go @@ -467,7 +467,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) { User: "11223344", } - refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession) + refreshed, err := provider.RefreshSession(context.Background(), existingSession) assert.Equal(t, nil, err) assert.Equal(t, refreshed, true) assert.Equal(t, "janedoe@example.com", existingSession.Email) @@ -500,7 +500,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) { Email: "changeit", User: "changeit", } - refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession) + refreshed, err := provider.RefreshSession(context.Background(), existingSession) assert.Equal(t, nil, err) assert.Equal(t, refreshed, true) assert.Equal(t, defaultIDToken.Email, existingSession.Email) diff --git a/providers/provider_default.go b/providers/provider_default.go index d3c6d113..9ce6e511 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -128,10 +128,15 @@ func (p *ProviderData) ValidateSession(ctx context.Context, s *sessions.SessionS // RefreshSessionIfNeeded should refresh the user's session if required and // do nothing if a refresh is not required -func (p *ProviderData) RefreshSessionIfNeeded(_ context.Context, _ *sessions.SessionState) (bool, error) { +func (p *ProviderData) RefreshSession(_ context.Context, _ *sessions.SessionState) (bool, error) { return false, nil } +// IsRefreshNeeded should return true if the user's session need to be refreshed +func (p *ProviderData) IsRefreshNeeded(_ *sessions.SessionState) bool { + return false +} + // CreateSessionFromToken converts Bearer IDTokens into sessions func (p *ProviderData) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) { if p.Verifier != nil { diff --git a/providers/provider_default_test.go b/providers/provider_default_test.go index df1525cf..e3e65d3b 100644 --- a/providers/provider_default_test.go +++ b/providers/provider_default_test.go @@ -15,7 +15,7 @@ func TestRefresh(t *testing.T) { p := &ProviderData{} expires := time.Now().Add(time.Duration(-11) * time.Minute) - refreshed, err := p.RefreshSessionIfNeeded(context.Background(), &sessions.SessionState{ + refreshed, err := p.RefreshSession(context.Background(), &sessions.SessionState{ ExpiresOn: &expires, }) assert.Equal(t, false, refreshed) diff --git a/providers/providers.go b/providers/providers.go index 6aeb5426..0d1609d9 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -16,7 +16,8 @@ type Provider interface { Authorize(ctx context.Context, s *sessions.SessionState) (bool, error) ValidateSession(ctx context.Context, s *sessions.SessionState) bool GetLoginURL(redirectURI, finalRedirect string) string - RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) + RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) + IsRefreshNeeded(s *sessions.SessionState) bool CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) }