diff --git a/pkg/apis/middleware/scope.go b/pkg/apis/middleware/scope.go new file mode 100644 index 00000000..c8153d1a --- /dev/null +++ b/pkg/apis/middleware/scope.go @@ -0,0 +1,24 @@ +package middleware + +import ( + "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" +) + +// RequestScope contains information regarding the request that is being made. +// The RequestScope is used to pass information between different middlewares +// within the chain. +type RequestScope struct { + // Session details the authenticated users information (if it exists). + Session *sessions.SessionState + + // SaveSession indicates whether the session storage should attempt to save + // the session or not. + SaveSession bool + + // ClearSession indicates whether the user should be logged out or not. + ClearSession bool + + // SessionRevalidated indicates whether the session has been revalidated since + // it was loaded or not. + SessionRevalidated bool +} diff --git a/pkg/middleware/scope.go b/pkg/middleware/scope.go new file mode 100644 index 00000000..d5925ad4 --- /dev/null +++ b/pkg/middleware/scope.go @@ -0,0 +1,39 @@ +package middleware + +import ( + "context" + "net/http" + + "github.com/justinas/alice" + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/middleware" +) + +type scopeKey string + +// requestScopeKey uses a typed string to reduce likelihood of clasing +// with other context keys +const requestScopeKey scopeKey = "request-scope" + +func NewScope() alice.Constructor { + return addScope +} + +// addScope injects a new request scope into the request context. +func addScope(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + scope := &middlewareapi.RequestScope{} + contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) + requestWithScope := req.WithContext(contextWithScope) + next.ServeHTTP(rw, requestWithScope) + }) +} + +// GetRequestScope returns the current request scope from the given request +func GetRequestScope(req *http.Request) *middlewareapi.RequestScope { + scope := req.Context().Value(requestScopeKey) + if scope == nil { + return nil + } + + return scope.(*middlewareapi.RequestScope) +} diff --git a/pkg/middleware/scope_test.go b/pkg/middleware/scope_test.go new file mode 100644 index 00000000..5a998bb0 --- /dev/null +++ b/pkg/middleware/scope_test.go @@ -0,0 +1,94 @@ +package middleware + +import ( + "context" + "net/http" + "net/http/httptest" + + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/middleware" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Scope Suite", func() { + Context("NewScope", func() { + var request, nextRequest *http.Request + var rw http.ResponseWriter + + BeforeEach(func() { + var err error + request, err = http.NewRequest("", "http://127.0.0.1/", nil) + Expect(err).ToNot(HaveOccurred()) + + rw = httptest.NewRecorder() + + handler := NewScope()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextRequest = r + w.WriteHeader(200) + })) + handler.ServeHTTP(rw, request) + }) + + It("does not add a scope to the original request", func() { + Expect(request.Context().Value(requestScopeKey)).To(BeNil()) + }) + + It("cannot load a scope from the original request using GetRequestScope", func() { + Expect(GetRequestScope(request)).To(BeNil()) + }) + + It("adds a scope to the request for the next handler", func() { + Expect(nextRequest.Context().Value(requestScopeKey)).ToNot(BeNil()) + }) + + It("can load a scope from the next handler's request using GetRequestScope", func() { + Expect(GetRequestScope(nextRequest)).ToNot(BeNil()) + }) + }) + + Context("GetRequestScope", func() { + var request *http.Request + + BeforeEach(func() { + var err error + request, err = http.NewRequest("", "http://127.0.0.1/", nil) + Expect(err).ToNot(HaveOccurred()) + }) + + Context("with a scope", func() { + var scope *middlewareapi.RequestScope + + BeforeEach(func() { + scope = &middlewareapi.RequestScope{} + contextWithScope := context.WithValue(request.Context(), requestScopeKey, scope) + request = request.WithContext(contextWithScope) + }) + + It("returns the scope", func() { + s := GetRequestScope(request) + Expect(s).ToNot(BeNil()) + Expect(s).To(Equal(scope)) + }) + + Context("if the scope is then modified", func() { + BeforeEach(func() { + Expect(scope.SaveSession).To(BeFalse()) + scope.SaveSession = true + }) + + It("returns the updated session", func() { + s := GetRequestScope(request) + Expect(s).ToNot(BeNil()) + Expect(s).To(Equal(scope)) + Expect(s.SaveSession).To(BeTrue()) + }) + }) + }) + + Context("without a scope", func() { + It("returns nil", func() { + Expect(GetRequestScope(request)).To(BeNil()) + }) + }) + }) +})