package middleware import ( "context" "errors" "fmt" "net/http" "net/http/httptest" "time" middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" . "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo/extensions/table" . "github.com/onsi/gomega" ) var _ = Describe("Stored Session Suite", func() { const ( refresh = "Refresh" noRefresh = "NoRefresh" ) var ctx = context.Background() Context("StoredSessionLoader", func() { createdPast := time.Now().Add(-5 * time.Minute) createdFuture := time.Now().Add(5 * time.Minute) var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { switch ss.RefreshToken { case refresh: ss.RefreshToken = "Refreshed" return true, nil case noRefresh: return false, nil default: return false, errors.New("error refreshing session") } } var defaultValidateFunc = func(_ context.Context, ss *sessionsapi.SessionState) bool { return ss.AccessToken != "Invalid" } var defaultSessionStore = &fakeSessionStore{ LoadFunc: func(req *http.Request) (*sessionsapi.SessionState, error) { switch req.Header.Get("Cookie") { case "_oauth2_proxy=NoRefreshSession": return &sessionsapi.SessionState{ RefreshToken: noRefresh, CreatedAt: &createdPast, ExpiresOn: &createdFuture, }, nil case "_oauth2_proxy=InvalidNoRefreshSession": return &sessionsapi.SessionState{ AccessToken: "Invalid", RefreshToken: noRefresh, CreatedAt: &createdPast, ExpiresOn: &createdFuture, }, nil case "_oauth2_proxy=ExpiredNoRefreshSession": return &sessionsapi.SessionState{ RefreshToken: noRefresh, CreatedAt: &createdPast, ExpiresOn: &createdPast, }, nil case "_oauth2_proxy=RefreshSession": return &sessionsapi.SessionState{ RefreshToken: refresh, CreatedAt: &createdPast, ExpiresOn: &createdFuture, }, nil case "_oauth2_proxy=RefreshError": return &sessionsapi.SessionState{ RefreshToken: "RefreshError", CreatedAt: &createdPast, ExpiresOn: &createdFuture, }, nil case "_oauth2_proxy=NonExistent": return nil, fmt.Errorf("invalid cookie") default: return nil, nil } }, } type storedSessionLoaderTableInput struct { requestHeaders http.Header existingSession *sessionsapi.SessionState expectedSession *sessionsapi.SessionState store sessionsapi.SessionStore refreshPeriod time.Duration refreshSession func(context.Context, *sessionsapi.SessionState) (bool, error) validateSession func(context.Context, *sessionsapi.SessionState) bool } DescribeTable("when serving a request", func(in storedSessionLoaderTableInput) { scope := &middlewareapi.RequestScope{ Session: in.existingSession, } // Set up the request with the request header and a request scope req := httptest.NewRequest("", "/", nil) req.Header = in.requestHeaders req = middlewareapi.AddRequestScope(req, scope) rw := httptest.NewRecorder() opts := &StoredSessionLoaderOptions{ SessionStore: in.store, RefreshPeriod: in.refreshPeriod, RefreshSessionIfNeeded: in.refreshSession, ValidateSessionState: in.validateSession, } // Create the handler with a next handler that will capture the session // from the scope var gotSession *sessionsapi.SessionState handler := NewStoredSessionLoader(opts)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { gotSession = middlewareapi.GetRequestScope(r).Session })) handler.ServeHTTP(rw, req) Expect(gotSession).To(Equal(in.expectedSession)) }, Entry("with no cookie", storedSessionLoaderTableInput{ requestHeaders: http.Header{}, existingSession: nil, expectedSession: nil, store: defaultSessionStore, refreshPeriod: 1 * time.Minute, refreshSession: defaultRefreshFunc, validateSession: defaultValidateFunc, }), Entry("with an invalid cookie", storedSessionLoaderTableInput{ requestHeaders: http.Header{ "Cookie": []string{"_oauth2_proxy=NonExistent"}, }, existingSession: nil, expectedSession: nil, store: defaultSessionStore, refreshPeriod: 1 * time.Minute, refreshSession: defaultRefreshFunc, validateSession: defaultValidateFunc, }), Entry("with an existing session", storedSessionLoaderTableInput{ requestHeaders: http.Header{ "Cookie": []string{"_oauth2_proxy=RefreshSession"}, }, existingSession: &sessionsapi.SessionState{ RefreshToken: "Existing", }, expectedSession: &sessionsapi.SessionState{ RefreshToken: "Existing", }, store: defaultSessionStore, refreshPeriod: 1 * time.Minute, refreshSession: defaultRefreshFunc, validateSession: defaultValidateFunc, }), Entry("with a session that has not expired", storedSessionLoaderTableInput{ requestHeaders: http.Header{ "Cookie": []string{"_oauth2_proxy=NoRefreshSession"}, }, existingSession: nil, expectedSession: &sessionsapi.SessionState{ RefreshToken: noRefresh, CreatedAt: &createdPast, ExpiresOn: &createdFuture, }, store: defaultSessionStore, refreshPeriod: 1 * time.Minute, refreshSession: defaultRefreshFunc, validateSession: defaultValidateFunc, }), Entry("with a session that cannot refresh and has expired", storedSessionLoaderTableInput{ requestHeaders: http.Header{ "Cookie": []string{"_oauth2_proxy=ExpiredNoRefreshSession"}, }, existingSession: nil, expectedSession: nil, store: defaultSessionStore, refreshPeriod: 1 * time.Minute, refreshSession: defaultRefreshFunc, validateSession: defaultValidateFunc, }), Entry("with a session that can refresh, but is younger than refresh period", storedSessionLoaderTableInput{ requestHeaders: http.Header{ "Cookie": []string{"_oauth2_proxy=RefreshSession"}, }, existingSession: nil, expectedSession: &sessionsapi.SessionState{ RefreshToken: refresh, CreatedAt: &createdPast, ExpiresOn: &createdFuture, }, store: defaultSessionStore, refreshPeriod: 10 * time.Minute, refreshSession: defaultRefreshFunc, validateSession: defaultValidateFunc, }), Entry("with a session that can refresh and is older than the refresh period", storedSessionLoaderTableInput{ requestHeaders: http.Header{ "Cookie": []string{"_oauth2_proxy=RefreshSession"}, }, existingSession: nil, expectedSession: &sessionsapi.SessionState{ RefreshToken: "Refreshed", CreatedAt: &createdPast, ExpiresOn: &createdFuture, }, store: defaultSessionStore, refreshPeriod: 1 * time.Minute, refreshSession: defaultRefreshFunc, validateSession: defaultValidateFunc, }), Entry("when the provider refresh fails", storedSessionLoaderTableInput{ requestHeaders: http.Header{ "Cookie": []string{"_oauth2_proxy=RefreshError"}, }, existingSession: nil, expectedSession: nil, store: defaultSessionStore, refreshPeriod: 1 * time.Minute, refreshSession: defaultRefreshFunc, validateSession: defaultValidateFunc, }), Entry("when the session is not refreshed and is no longer valid", storedSessionLoaderTableInput{ requestHeaders: http.Header{ "Cookie": []string{"_oauth2_proxy=InvalidNoRefreshSession"}, }, existingSession: nil, expectedSession: nil, store: defaultSessionStore, refreshPeriod: 1 * time.Minute, refreshSession: defaultRefreshFunc, validateSession: defaultValidateFunc, }), ) }) Context("refreshSessionIfNeeded", func() { type refreshSessionIfNeededTableInput struct { refreshPeriod time.Duration session *sessionsapi.SessionState expectedErr error expectRefreshed bool expectValidated bool } createdPast := time.Now().Add(-5 * time.Minute) createdFuture := time.Now().Add(5 * time.Minute) DescribeTable("with a session", func(in refreshSessionIfNeededTableInput) { refreshed := false validated := false s := &storedSessionLoader{ refreshPeriod: in.refreshPeriod, store: &fakeSessionStore{}, refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { refreshed = true switch ss.RefreshToken { case refresh: return true, nil case noRefresh: return false, nil default: return false, errors.New("error refreshing session") } }, validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool { validated = true return ss.AccessToken != "Invalid" }, } req := httptest.NewRequest("", "/", nil) err := s.refreshSessionIfNeeded(nil, req, in.session) if in.expectedErr != nil { Expect(err).To(MatchError(in.expectedErr)) } else { Expect(err).ToNot(HaveOccurred()) } Expect(refreshed).To(Equal(in.expectRefreshed)) Expect(validated).To(Equal(in.expectValidated)) }, Entry("when the refresh period is 0, and the session does not need refreshing", refreshSessionIfNeededTableInput{ refreshPeriod: time.Duration(0), session: &sessionsapi.SessionState{ RefreshToken: refresh, CreatedAt: &createdFuture, }, expectedErr: nil, expectRefreshed: false, expectValidated: false, }), Entry("when the refresh period is 0, and the session needs refreshing", refreshSessionIfNeededTableInput{ refreshPeriod: time.Duration(0), session: &sessionsapi.SessionState{ RefreshToken: refresh, CreatedAt: &createdPast, }, expectedErr: nil, expectRefreshed: false, expectValidated: false, }), Entry("when the session does not need refreshing", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, session: &sessionsapi.SessionState{ RefreshToken: refresh, CreatedAt: &createdFuture, }, expectedErr: nil, expectRefreshed: false, expectValidated: false, }), Entry("when the session is refreshed by the provider", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, session: &sessionsapi.SessionState{ RefreshToken: refresh, CreatedAt: &createdPast, }, expectedErr: nil, expectRefreshed: true, expectValidated: false, }), Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, session: &sessionsapi.SessionState{ RefreshToken: noRefresh, CreatedAt: &createdPast, ExpiresOn: &createdFuture, }, expectedErr: nil, expectRefreshed: true, expectValidated: true, }), Entry("when the provider refresh fails", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, session: &sessionsapi.SessionState{ RefreshToken: "RefreshError", CreatedAt: &createdPast, }, expectedErr: errors.New("error refreshing access token: error refreshing session"), expectRefreshed: true, expectValidated: false, }), Entry("when the session is not refreshed by the provider and validation fails", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, session: &sessionsapi.SessionState{ AccessToken: "Invalid", RefreshToken: noRefresh, CreatedAt: &createdPast, ExpiresOn: &createdFuture, }, expectedErr: errors.New("session is invalid"), expectRefreshed: true, expectValidated: true, }), ) }) Context("refreshSessionWithProvider", func() { type refreshSessionWithProviderTableInput struct { session *sessionsapi.SessionState expectedErr error expectRefreshed bool expectSaved bool } now := time.Now() DescribeTable("when refreshing with the provider", func(in refreshSessionWithProviderTableInput) { saved := false s := &storedSessionLoader{ store: &fakeSessionStore{ SaveFunc: func(_ http.ResponseWriter, _ *http.Request, ss *sessionsapi.SessionState) error { saved = true if ss.AccessToken == "NoSave" { return errors.New("unable to save session") } return nil }, }, refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { switch ss.RefreshToken { case refresh: return true, nil case noRefresh: return false, nil default: return false, errors.New("error refreshing session") } }, } req := httptest.NewRequest("", "/", nil) req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{}) refreshed, err := s.refreshSessionWithProvider(nil, req, in.session) if in.expectedErr != nil { Expect(err).To(MatchError(in.expectedErr)) } else { Expect(err).ToNot(HaveOccurred()) } Expect(refreshed).To(Equal(in.expectRefreshed)) Expect(saved).To(Equal(in.expectSaved)) }, Entry("when the provider does not refresh the session", refreshSessionWithProviderTableInput{ session: &sessionsapi.SessionState{ RefreshToken: noRefresh, }, expectedErr: nil, expectRefreshed: false, expectSaved: false, }), Entry("when the provider refreshes the session", refreshSessionWithProviderTableInput{ session: &sessionsapi.SessionState{ RefreshToken: refresh, }, expectedErr: nil, expectRefreshed: true, expectSaved: true, }), Entry("when the provider returns an error", refreshSessionWithProviderTableInput{ session: &sessionsapi.SessionState{ RefreshToken: "RefreshError", CreatedAt: &now, ExpiresOn: &now, }, expectedErr: errors.New("error refreshing access token: error refreshing session"), expectRefreshed: false, expectSaved: false, }), Entry("when the saving the session returns an error", refreshSessionWithProviderTableInput{ session: &sessionsapi.SessionState{ RefreshToken: refresh, AccessToken: "NoSave", }, expectedErr: errors.New("error saving session: unable to save session"), expectRefreshed: false, expectSaved: true, }), ) }) Context("validateSession", func() { var s *storedSessionLoader BeforeEach(func() { s = &storedSessionLoader{ validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool { return ss.AccessToken == "Valid" }, } }) Context("with a valid session", func() { It("does not return an error", func() { expires := time.Now().Add(1 * time.Minute) session := &sessionsapi.SessionState{ AccessToken: "Valid", ExpiresOn: &expires, } Expect(s.validateSession(ctx, session)).To(Succeed()) }) }) Context("with an expired session", func() { It("returns an error", func() { created := time.Now().Add(-5 * time.Minute) expires := time.Now().Add(-1 * time.Minute) session := &sessionsapi.SessionState{ AccessToken: "Valid", CreatedAt: &created, ExpiresOn: &expires, } Expect(s.validateSession(ctx, session)).To(MatchError("session is expired")) }) }) Context("with an invalid session", func() { It("returns an error", func() { expires := time.Now().Add(1 * time.Minute) session := &sessionsapi.SessionState{ AccessToken: "Invalid", ExpiresOn: &expires, } Expect(s.validateSession(ctx, session)).To(MatchError("session is invalid")) }) }) }) }) type fakeSessionStore struct { SaveFunc func(http.ResponseWriter, *http.Request, *sessionsapi.SessionState) error LoadFunc func(req *http.Request) (*sessionsapi.SessionState, error) ClearFunc func(rw http.ResponseWriter, req *http.Request) error } func (f *fakeSessionStore) Save(rw http.ResponseWriter, req *http.Request, s *sessionsapi.SessionState) error { if f.SaveFunc != nil { return f.SaveFunc(rw, req, s) } return nil } func (f *fakeSessionStore) Load(req *http.Request) (*sessionsapi.SessionState, error) { if f.LoadFunc != nil { return f.LoadFunc(req) } return nil, nil } func (f *fakeSessionStore) Clear(rw http.ResponseWriter, req *http.Request) error { if f.ClearFunc != nil { return f.ClearFunc(rw, req) } return nil }