From 034f057b6032731f701a724e1037e83b699d7296 Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Fri, 17 Jul 2020 11:47:26 +0100 Subject: [PATCH] Add session loader from session storage --- pkg/middleware/stored_session.go | 165 ++++++++ pkg/middleware/stored_session_test.go | 524 ++++++++++++++++++++++++++ 2 files changed, 689 insertions(+) create mode 100644 pkg/middleware/stored_session.go create mode 100644 pkg/middleware/stored_session_test.go diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go new file mode 100644 index 00000000..dd4d1405 --- /dev/null +++ b/pkg/middleware/stored_session.go @@ -0,0 +1,165 @@ +package middleware + +import ( + "context" + "errors" + "fmt" + "net/http" + "time" + + "github.com/justinas/alice" + sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" + "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" +) + +// StoredSessionLoaderOptions cotnains all of the requirements to construct +// a stored session loader. +// All options must be provided. +type StoredSessionLoaderOptions struct { + // Session storage basckend + SessionStore sessionsapi.SessionStore + + // How often should sessions be refreshed + RefreshPeriod time.Duration + + // Provider based sesssion refreshing + RefreshSessionIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error) + + // Provider based session validation. + // If the sesssion is older than `RefreshPeriod` but the provider doesn't + // refresh it, we must re-validate using this validation. + ValidateSessionState func(context.Context, *sessionsapi.SessionState) bool +} + +// NewStoredSessionLoader creates a new storedSessionLoader which loads +// sessions from the session store. +// If no session is found, the request will be passed to the nex handler. +// If a session was loader by a previous handler, it will not be replaced. +func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor { + ss := &storedSessionLoader{ + store: opts.SessionStore, + refreshPeriod: opts.RefreshPeriod, + refreshSessionWithProviderIfNeeded: opts.RefreshSessionIfNeeded, + validateSessionState: opts.ValidateSessionState, + } + return ss.loadSession +} + +// storedSessionLoader is responsible for loading sessions from cookie +// identified sessions in the session store. +type storedSessionLoader struct { + store sessionsapi.SessionStore + refreshPeriod time.Duration + refreshSessionWithProviderIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error) + validateSessionState func(context.Context, *sessionsapi.SessionState) bool +} + +// loadSession attempts to load a session as identified by the request cookies. +// If no session is found, the request will be passed to the nex handler. +// If a session was loader by a previous handler, it will not be replaced. +func (s *storedSessionLoader) loadSession(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + scope := GetRequestScope(req) + // If scope is nil, this will panic. + // A scope should always be injected before this handler is called. + if scope.Session != nil { + // The session was already loaded, pass to the next handler + next.ServeHTTP(rw, req) + return + } + + session, err := s.getValidatedSession(rw, req) + if err != nil { + // In the case when there was an error loading the session, + // we should clear the session + logger.Printf("Error loading cookied session: %v, removing session", err) + s.store.Clear(rw, req) + } + + // Add the session to the scope if it was found + scope.Session = session + next.ServeHTTP(rw, req) + }) +} + +// getValidatedSession is responsible for loading a session and making sure +// that is is valid. +func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *http.Request) (*sessionsapi.SessionState, error) { + session, err := s.store.Load(req) + if err != nil { + return nil, err + } + if session == nil { + // No session was found in the storage, nothing more to do + return nil, nil + } + + err = s.refreshSessionIfNeeded(rw, req, session) + if err != nil { + return nil, fmt.Errorf("error refreshing access token for session (%s): %v", session, err) + } + + return session, nil +} + +// refreshSessionIfNeeded will attempt to refresh a session if the session +// is older than the refresh period. +// It is assumed that if the provider refreshes the session, the session is now +// valid. +// If the session requires refreshing but the provider does not refresh it, +// we must validate the session to ensure that the returned session is still +// valid. +func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { + if s.refreshPeriod <= time.Duration(0) || session.Age() < s.refreshPeriod { + // Refresh is disabled or the session is not old enough, do nothing + return nil + } + + logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod) + refreshed, err := s.refreshSessionWithProvider(rw, req, session) + if err != nil { + return err + } + + if !refreshed { + // Session wasn't refreshed, so make sure it's still valid + return s.validateSession(req.Context(), session) + } + return nil +} + +// refreshSessionWithProvider attempts to refresh the sessinon with the provider +// and will save the session if it was updated. +func (s *storedSessionLoader) refreshSessionWithProvider(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) (bool, error) { + refreshed, err := s.refreshSessionWithProviderIfNeeded(req.Context(), session) + if err != nil { + return false, fmt.Errorf("error refreshing access token: %v", err) + } + + if !refreshed { + return false, nil + } + + // Because the session was refreshed, make sure to save it + err = s.store.Save(rw, req, session) + if err != nil { + logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err) + return false, fmt.Errorf("error saving session: %v", err) + } + return true, nil +} + +// validateSession checks whether the session has expired and performs +// provider validation on the session. +// An error implies the session is not longer valid. +func (s *storedSessionLoader) validateSession(ctx context.Context, session *sessionsapi.SessionState) error { + if session.IsExpired() { + return errors.New("session is expired") + } + + if !s.validateSessionState(ctx, session) { + return errors.New("session is invalid") + } + + return nil +} diff --git a/pkg/middleware/stored_session_test.go b/pkg/middleware/stored_session_test.go new file mode 100644 index 00000000..1721b309 --- /dev/null +++ b/pkg/middleware/stored_session_test.go @@ -0,0 +1,524 @@ +package middleware + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "time" + + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/middleware" + sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +var _ = Describe("Stored Session Suite", func() { + const ( + refresh = "Refresh" + noRefresh = "NoRefresh" + ) + + var ctx = context.Background() + + Context("StoredSessionLoader", func() { + createdPast := time.Now().Add(-5 * time.Minute) + createdFuture := time.Now().Add(5 * time.Minute) + + var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { + switch ss.RefreshToken { + case refresh: + ss.RefreshToken = "Refreshed" + return true, nil + case noRefresh: + return false, nil + default: + return false, errors.New("error refreshing session") + } + } + + var defaultValidateFunc = func(_ context.Context, ss *sessionsapi.SessionState) bool { + return ss.AccessToken != "Invalid" + } + + var defaultSessionStore = &fakeSessionStore{ + LoadFunc: func(req *http.Request) (*sessionsapi.SessionState, error) { + switch req.Header.Get("Cookie") { + case "_oauth2_proxy=NoRefreshSession": + return &sessionsapi.SessionState{ + RefreshToken: noRefresh, + CreatedAt: &createdPast, + ExpiresOn: &createdFuture, + }, nil + case "_oauth2_proxy=InvalidNoRefreshSession": + return &sessionsapi.SessionState{ + AccessToken: "Invalid", + RefreshToken: noRefresh, + CreatedAt: &createdPast, + ExpiresOn: &createdFuture, + }, nil + case "_oauth2_proxy=ExpiredNoRefreshSession": + return &sessionsapi.SessionState{ + RefreshToken: noRefresh, + CreatedAt: &createdPast, + ExpiresOn: &createdPast, + }, nil + case "_oauth2_proxy=RefreshSession": + return &sessionsapi.SessionState{ + RefreshToken: refresh, + CreatedAt: &createdPast, + ExpiresOn: &createdFuture, + }, nil + case "_oauth2_proxy=RefreshError": + return &sessionsapi.SessionState{ + RefreshToken: "RefreshError", + CreatedAt: &createdPast, + ExpiresOn: &createdFuture, + }, nil + case "_oauth2_proxy=NonExistent": + return nil, fmt.Errorf("invalid cookie") + default: + return nil, nil + } + }, + } + + type storedSessionLoaderTableInput struct { + requestHeaders http.Header + existingSession *sessionsapi.SessionState + expectedSession *sessionsapi.SessionState + store sessionsapi.SessionStore + refreshPeriod time.Duration + refreshSession func(context.Context, *sessionsapi.SessionState) (bool, error) + validateSession func(context.Context, *sessionsapi.SessionState) bool + } + + DescribeTable("when serving a request", + func(in storedSessionLoaderTableInput) { + scope := &middlewareapi.RequestScope{ + Session: in.existingSession, + } + + // Set up the request with the request headesr and a request scope + req := httptest.NewRequest("", "/", nil) + req.Header = in.requestHeaders + contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) + req = req.WithContext(contextWithScope) + + rw := httptest.NewRecorder() + + opts := &StoredSessionLoaderOptions{ + SessionStore: in.store, + RefreshPeriod: in.refreshPeriod, + RefreshSessionIfNeeded: in.refreshSession, + ValidateSessionState: in.validateSession, + } + + // Create the handler with a next handler that will capture the session + // from the scope + var gotSession *sessionsapi.SessionState + handler := NewStoredSessionLoader(opts)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session + })) + handler.ServeHTTP(rw, req) + + Expect(gotSession).To(Equal(in.expectedSession)) + }, + Entry("with no cookie", storedSessionLoaderTableInput{ + requestHeaders: http.Header{}, + existingSession: nil, + expectedSession: nil, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("with an invalid cookie", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=NonExistent"}, + }, + existingSession: nil, + expectedSession: nil, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("with an existing session", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=RefreshSession"}, + }, + existingSession: &sessionsapi.SessionState{ + RefreshToken: "Existing", + }, + expectedSession: &sessionsapi.SessionState{ + RefreshToken: "Existing", + }, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("with a session that has not expired", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=NoRefreshSession"}, + }, + existingSession: nil, + expectedSession: &sessionsapi.SessionState{ + RefreshToken: noRefresh, + CreatedAt: &createdPast, + ExpiresOn: &createdFuture, + }, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("with a session that cannot refresh and has expired", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=ExpiredNoRefreshSession"}, + }, + existingSession: nil, + expectedSession: nil, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("with a session that can refresh, but is younger than refresh period", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=RefreshSession"}, + }, + existingSession: nil, + expectedSession: &sessionsapi.SessionState{ + RefreshToken: refresh, + CreatedAt: &createdPast, + ExpiresOn: &createdFuture, + }, + store: defaultSessionStore, + refreshPeriod: 10 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("with a session that can refresh and is older than the refresh period", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=RefreshSession"}, + }, + existingSession: nil, + expectedSession: &sessionsapi.SessionState{ + RefreshToken: "Refreshed", + CreatedAt: &createdPast, + ExpiresOn: &createdFuture, + }, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("when the provider refresh fails", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=RefreshError"}, + }, + existingSession: nil, + expectedSession: nil, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("when the session is not refreshed and is no longer valid", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=InvalidNoRefreshSession"}, + }, + existingSession: nil, + expectedSession: nil, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + ) + }) + + Context("refreshSessionIfNeeded", func() { + type refreshSessionIfNeededTableInput struct { + refreshPeriod time.Duration + session *sessionsapi.SessionState + expectedErr error + expectRefreshed bool + expectValidated bool + } + + createdPast := time.Now().Add(-5 * time.Minute) + createdFuture := time.Now().Add(5 * time.Minute) + + DescribeTable("with a session", + func(in refreshSessionIfNeededTableInput) { + refreshed := false + validated := false + + s := &storedSessionLoader{ + refreshPeriod: in.refreshPeriod, + store: &fakeSessionStore{}, + refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { + refreshed = true + switch ss.RefreshToken { + case refresh: + return true, nil + case noRefresh: + return false, nil + default: + return false, errors.New("error refreshing session") + } + }, + validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool { + validated = true + return ss.AccessToken != "Invalid" + }, + } + + req := httptest.NewRequest("", "/", nil) + err := s.refreshSessionIfNeeded(nil, req, in.session) + if in.expectedErr != nil { + Expect(err).To(MatchError(in.expectedErr)) + } else { + Expect(err).ToNot(HaveOccurred()) + } + Expect(refreshed).To(Equal(in.expectRefreshed)) + Expect(validated).To(Equal(in.expectValidated)) + }, + Entry("when the refresh period is 0, and the session does not need refreshing", refreshSessionIfNeededTableInput{ + refreshPeriod: time.Duration(0), + session: &sessionsapi.SessionState{ + RefreshToken: refresh, + CreatedAt: &createdFuture, + }, + expectedErr: nil, + expectRefreshed: false, + expectValidated: false, + }), + Entry("when the refresh period is 0, and the session needs refreshing", refreshSessionIfNeededTableInput{ + refreshPeriod: time.Duration(0), + session: &sessionsapi.SessionState{ + RefreshToken: refresh, + CreatedAt: &createdPast, + }, + expectedErr: nil, + expectRefreshed: false, + expectValidated: false, + }), + Entry("when the session does not need refreshing", refreshSessionIfNeededTableInput{ + refreshPeriod: 1 * time.Minute, + session: &sessionsapi.SessionState{ + RefreshToken: refresh, + CreatedAt: &createdFuture, + }, + expectedErr: nil, + expectRefreshed: false, + expectValidated: false, + }), + Entry("when the session is refreshed by the provider", refreshSessionIfNeededTableInput{ + refreshPeriod: 1 * time.Minute, + session: &sessionsapi.SessionState{ + RefreshToken: refresh, + CreatedAt: &createdPast, + }, + expectedErr: nil, + expectRefreshed: true, + expectValidated: false, + }), + Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{ + refreshPeriod: 1 * time.Minute, + session: &sessionsapi.SessionState{ + RefreshToken: noRefresh, + CreatedAt: &createdPast, + ExpiresOn: &createdFuture, + }, + expectedErr: nil, + expectRefreshed: true, + expectValidated: true, + }), + Entry("when the provider refresh fails", refreshSessionIfNeededTableInput{ + refreshPeriod: 1 * time.Minute, + session: &sessionsapi.SessionState{ + RefreshToken: "RefreshError", + CreatedAt: &createdPast, + }, + expectedErr: errors.New("error refreshing access token: error refreshing session"), + expectRefreshed: true, + expectValidated: false, + }), + Entry("when the session is not refreshed by the provider and validation fails", refreshSessionIfNeededTableInput{ + refreshPeriod: 1 * time.Minute, + session: &sessionsapi.SessionState{ + AccessToken: "Invalid", + RefreshToken: noRefresh, + CreatedAt: &createdPast, + ExpiresOn: &createdFuture, + }, + expectedErr: errors.New("session is invalid"), + expectRefreshed: true, + expectValidated: true, + }), + ) + }) + + Context("refreshSessionWithProvider", func() { + type refreshSessionWithProviderTableInput struct { + session *sessionsapi.SessionState + expectedErr error + expectRefreshed bool + expectSaved bool + } + + now := time.Now() + + DescribeTable("when refreshing with the provider", + func(in refreshSessionWithProviderTableInput) { + saved := false + + s := &storedSessionLoader{ + store: &fakeSessionStore{ + SaveFunc: func(_ http.ResponseWriter, _ *http.Request, ss *sessionsapi.SessionState) error { + saved = true + if ss.AccessToken == "NoSave" { + return errors.New("unable to save session") + } + return nil + }, + }, + refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { + switch ss.RefreshToken { + case refresh: + return true, nil + case noRefresh: + return false, nil + default: + return false, errors.New("error refreshing session") + } + }, + } + + req := httptest.NewRequest("", "/", nil) + refreshed, err := s.refreshSessionWithProvider(nil, req, in.session) + if in.expectedErr != nil { + Expect(err).To(MatchError(in.expectedErr)) + } else { + Expect(err).ToNot(HaveOccurred()) + } + Expect(refreshed).To(Equal(in.expectRefreshed)) + Expect(saved).To(Equal(in.expectSaved)) + }, + Entry("when the provider does not refresh the session", refreshSessionWithProviderTableInput{ + session: &sessionsapi.SessionState{ + RefreshToken: noRefresh, + }, + expectedErr: nil, + expectRefreshed: false, + expectSaved: false, + }), + Entry("when the provider refreshes the session", refreshSessionWithProviderTableInput{ + session: &sessionsapi.SessionState{ + RefreshToken: refresh, + }, + expectedErr: nil, + expectRefreshed: true, + expectSaved: true, + }), + Entry("when the provider returns an error", refreshSessionWithProviderTableInput{ + session: &sessionsapi.SessionState{ + RefreshToken: "RefreshError", + CreatedAt: &now, + ExpiresOn: &now, + }, + expectedErr: errors.New("error refreshing access token: error refreshing session"), + expectRefreshed: false, + expectSaved: false, + }), + Entry("when the saving the session returns an error", refreshSessionWithProviderTableInput{ + session: &sessionsapi.SessionState{ + RefreshToken: refresh, + AccessToken: "NoSave", + }, + expectedErr: errors.New("error saving session: unable to save session"), + expectRefreshed: false, + expectSaved: true, + }), + ) + }) + + Context("validateSession", func() { + var s *storedSessionLoader + + BeforeEach(func() { + s = &storedSessionLoader{ + validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool { + return ss.AccessToken == "Valid" + }, + } + }) + + Context("with a valid session", func() { + It("does not return an error", func() { + expires := time.Now().Add(1 * time.Minute) + session := &sessionsapi.SessionState{ + AccessToken: "Valid", + ExpiresOn: &expires, + } + Expect(s.validateSession(ctx, session)).To(Succeed()) + }) + }) + + Context("with an expired session", func() { + It("returns an error", func() { + created := time.Now().Add(-5 * time.Minute) + expires := time.Now().Add(-1 * time.Minute) + session := &sessionsapi.SessionState{ + AccessToken: "Valid", + CreatedAt: &created, + ExpiresOn: &expires, + } + Expect(s.validateSession(ctx, session)).To(MatchError("session is expired")) + }) + }) + + Context("with an invalid session", func() { + It("returns an error", func() { + expires := time.Now().Add(1 * time.Minute) + session := &sessionsapi.SessionState{ + AccessToken: "Invalid", + ExpiresOn: &expires, + } + Expect(s.validateSession(ctx, session)).To(MatchError("session is invalid")) + }) + }) + }) +}) + +type fakeSessionStore struct { + SaveFunc func(http.ResponseWriter, *http.Request, *sessionsapi.SessionState) error + LoadFunc func(req *http.Request) (*sessionsapi.SessionState, error) + ClearFunc func(rw http.ResponseWriter, req *http.Request) error +} + +func (f *fakeSessionStore) Save(rw http.ResponseWriter, req *http.Request, s *sessionsapi.SessionState) error { + if f.SaveFunc != nil { + return f.SaveFunc(rw, req, s) + } + return nil +} +func (f *fakeSessionStore) Load(req *http.Request) (*sessionsapi.SessionState, error) { + if f.LoadFunc != nil { + return f.LoadFunc(req) + } + return nil, nil +} + +func (f *fakeSessionStore) Clear(rw http.ResponseWriter, req *http.Request) error { + if f.ClearFunc != nil { + return f.ClearFunc(rw, req) + } + return nil +}