From 6fb3274ca3649d8e5b263e0742879002608f1455 Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Sat, 2 Jan 2021 13:16:01 -0800 Subject: [PATCH] Refactor organization of scope aware request utils Reorganized the structure of the Request Utils due to their widespread use resulting in circular imports issues (mostly because of middleware & logger). --- pkg/apis/middleware/middleware_suite_test.go | 19 +++ pkg/apis/middleware/scope.go | 25 ++++ pkg/apis/middleware/scope_test.go | 56 ++++++++ pkg/cookies/cookies.go | 8 +- pkg/logger/logger.go | 6 +- pkg/middleware/basic_session.go | 3 +- pkg/middleware/basic_session_test.go | 6 +- pkg/middleware/headers.go | 5 +- pkg/middleware/headers_test.go | 7 +- pkg/middleware/jwt_session.go | 2 +- pkg/middleware/jwt_session_test.go | 5 +- pkg/middleware/scope.go | 27 +--- pkg/middleware/scope_test.go | 89 +++++-------- pkg/middleware/stored_session.go | 3 +- pkg/middleware/stored_session_test.go | 5 +- pkg/requests/util/util.go | 48 +++++++ pkg/requests/util/util_suite_test.go | 19 +++ pkg/requests/util/util_test.go | 131 +++++++++++++++++++ pkg/util/util.go | 37 ------ pkg/util/util_test.go | 41 ------ 20 files changed, 357 insertions(+), 185 deletions(-) create mode 100644 pkg/apis/middleware/middleware_suite_test.go create mode 100644 pkg/apis/middleware/scope_test.go create mode 100644 pkg/requests/util/util.go create mode 100644 pkg/requests/util/util_suite_test.go create mode 100644 pkg/requests/util/util_test.go diff --git a/pkg/apis/middleware/middleware_suite_test.go b/pkg/apis/middleware/middleware_suite_test.go new file mode 100644 index 00000000..f2f48cfd --- /dev/null +++ b/pkg/apis/middleware/middleware_suite_test.go @@ -0,0 +1,19 @@ +package middleware_test + +import ( + "testing" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +// TestMiddlewareSuite and related tests are in a *_test package +// to prevent circular imports with the `logger` package which uses +// this functionality +func TestMiddlewareSuite(t *testing.T) { + logger.SetOutput(GinkgoWriter) + + RegisterFailHandler(Fail) + RunSpecs(t, "Middleware API") +} diff --git a/pkg/apis/middleware/scope.go b/pkg/apis/middleware/scope.go index cb6fe4b8..c54a33d1 100644 --- a/pkg/apis/middleware/scope.go +++ b/pkg/apis/middleware/scope.go @@ -1,9 +1,18 @@ package middleware import ( + "context" + "net/http" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" ) +type scopeKey string + +// RequestScopeKey uses a typed string to reduce likelihood of clashing +// with other context keys +const RequestScopeKey scopeKey = "request-scope" + // RequestScope contains information regarding the request that is being made. // The RequestScope is used to pass information between different middlewares // within the chain. @@ -26,3 +35,19 @@ type RequestScope struct { // it was loaded or not. SessionRevalidated bool } + +// GetRequestScope returns the current request scope from the given request +func GetRequestScope(req *http.Request) *RequestScope { + scope := req.Context().Value(RequestScopeKey) + if scope == nil { + return nil + } + + return scope.(*RequestScope) +} + +// AddRequestScope adds a RequestScope to a request +func AddRequestScope(req *http.Request, scope *RequestScope) *http.Request { + ctx := context.WithValue(req.Context(), RequestScopeKey, scope) + return req.WithContext(ctx) +} diff --git a/pkg/apis/middleware/scope_test.go b/pkg/apis/middleware/scope_test.go new file mode 100644 index 00000000..355365bf --- /dev/null +++ b/pkg/apis/middleware/scope_test.go @@ -0,0 +1,56 @@ +package middleware_test + +import ( + "net/http" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Scope Suite", func() { + Context("GetRequestScope", func() { + var request *http.Request + + BeforeEach(func() { + var err error + request, err = http.NewRequest("", "http://127.0.0.1/", nil) + Expect(err).ToNot(HaveOccurred()) + }) + + Context("with a scope", func() { + var scope *middleware.RequestScope + + BeforeEach(func() { + scope = &middleware.RequestScope{} + request = middleware.AddRequestScope(request, scope) + }) + + It("returns the scope", func() { + s := middleware.GetRequestScope(request) + Expect(s).ToNot(BeNil()) + Expect(s).To(Equal(scope)) + }) + + Context("if the scope is then modified", func() { + BeforeEach(func() { + Expect(scope.SaveSession).To(BeFalse()) + scope.SaveSession = true + }) + + It("returns the updated session", func() { + s := middleware.GetRequestScope(request) + Expect(s).ToNot(BeNil()) + Expect(s).To(Equal(scope)) + Expect(s.SaveSession).To(BeTrue()) + }) + }) + }) + + Context("without a scope", func() { + It("returns nil", func() { + Expect(middleware.GetRequestScope(request)).To(BeNil()) + }) + }) + }) +}) diff --git a/pkg/cookies/cookies.go b/pkg/cookies/cookies.go index 9b6dc03d..c590de38 100644 --- a/pkg/cookies/cookies.go +++ b/pkg/cookies/cookies.go @@ -9,14 +9,14 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util" + requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util" ) // MakeCookie constructs a cookie from the given parameters, // discovering the domain from the request if not specified. func MakeCookie(req *http.Request, name string, value string, path string, domain string, httpOnly bool, secure bool, expiration time.Duration, now time.Time, sameSite http.SameSite) *http.Cookie { if domain != "" { - host := util.GetRequestHost(req) + host := requestutil.GetRequestHost(req) if h, _, err := net.SplitHostPort(host); err == nil { host = h } @@ -48,7 +48,7 @@ func MakeCookieFromOptions(req *http.Request, name string, value string, cookieO // If nothing matches, create the cookie with the shortest domain defaultDomain := "" if len(cookieOpts.Domains) > 0 { - logger.Errorf("Warning: request host %q did not match any of the specific cookie domains of %q", util.GetRequestHost(req), strings.Join(cookieOpts.Domains, ",")) + logger.Errorf("Warning: request host %q did not match any of the specific cookie domains of %q", requestutil.GetRequestHost(req), strings.Join(cookieOpts.Domains, ",")) defaultDomain = cookieOpts.Domains[len(cookieOpts.Domains)-1] } return MakeCookie(req, name, value, cookieOpts.Path, defaultDomain, cookieOpts.HTTPOnly, cookieOpts.Secure, expiration, now, ParseSameSite(cookieOpts.SameSite)) @@ -57,7 +57,7 @@ func MakeCookieFromOptions(req *http.Request, name string, value string, cookieO // GetCookieDomain returns the correct cookie domain given a list of domains // by checking the X-Fowarded-Host and host header of an an http request func GetCookieDomain(req *http.Request, cookieDomains []string) string { - host := util.GetRequestHost(req) + host := requestutil.GetRequestHost(req) for _, domain := range cookieDomains { if strings.HasSuffix(host, domain) { return domain diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index 23696765..86ad720e 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -12,7 +12,7 @@ import ( "text/template" "time" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util" + requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util" ) // AuthStatus defines the different types of auth logging that occur @@ -197,7 +197,7 @@ func (l *Logger) PrintAuthf(username string, req *http.Request, status AuthStatu err := l.authTemplate.Execute(l.writer, authLogMessageData{ Client: client, - Host: util.GetRequestHost(req), + Host: requestutil.GetRequestHost(req), Protocol: req.Proto, RequestMethod: req.Method, Timestamp: FormatTimestamp(now), @@ -251,7 +251,7 @@ func (l *Logger) PrintReq(username, upstream string, req *http.Request, url url. err := l.reqTemplate.Execute(l.writer, reqLogMessageData{ Client: client, - Host: util.GetRequestHost(req), + Host: requestutil.GetRequestHost(req), Protocol: req.Proto, RequestDuration: fmt.Sprintf("%0.3f", duration), RequestMethod: req.Method, diff --git a/pkg/middleware/basic_session.go b/pkg/middleware/basic_session.go index 5a7b77f9..7de1bf2b 100644 --- a/pkg/middleware/basic_session.go +++ b/pkg/middleware/basic_session.go @@ -5,6 +5,7 @@ import ( "net/http" "github.com/justinas/alice" + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" @@ -23,7 +24,7 @@ func NewBasicAuthSessionLoader(validator basic.Validator) alice.Constructor { // If a session was loaded by a previous handler, it will not be replaced. func loadBasicAuthSession(validator basic.Validator, next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - scope := GetRequestScope(req) + scope := middlewareapi.GetRequestScope(req) // If scope is nil, this will panic. // A scope should always be injected before this handler is called. if scope.Session != nil { diff --git a/pkg/middleware/basic_session_test.go b/pkg/middleware/basic_session_test.go index 35e4f804..14c49c43 100644 --- a/pkg/middleware/basic_session_test.go +++ b/pkg/middleware/basic_session_test.go @@ -1,7 +1,6 @@ package middleware import ( - "context" "fmt" "net/http" "net/http/httptest" @@ -40,8 +39,7 @@ var _ = Describe("Basic Auth Session Suite", func() { // Set up the request with the authorization header and a request scope req := httptest.NewRequest("", "/", nil) req.Header.Set("Authorization", in.authorizationHeader) - contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) - req = req.WithContext(contextWithScope) + req = middlewareapi.AddRequestScope(req, scope) rw := httptest.NewRecorder() @@ -57,7 +55,7 @@ var _ = Describe("Basic Auth Session Suite", func() { // from the scope var gotSession *sessionsapi.SessionState handler := NewBasicAuthSessionLoader(validator)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session + gotSession = middlewareapi.GetRequestScope(r).Session })) handler.ServeHTTP(rw, req) diff --git a/pkg/middleware/headers.go b/pkg/middleware/headers.go index 6786c2eb..b79b547b 100644 --- a/pkg/middleware/headers.go +++ b/pkg/middleware/headers.go @@ -5,6 +5,7 @@ import ( "net/http" "github.com/justinas/alice" + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/header" ) @@ -61,7 +62,7 @@ func newRequestHeaderInjector(headers []options.Header) (alice.Constructor, erro func injectRequestHeaders(injector header.Injector, next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - scope := GetRequestScope(req) + scope := middlewareapi.GetRequestScope(req) // If scope is nil, this will panic. // A scope should always be injected before this handler is called. @@ -92,7 +93,7 @@ func newResponseHeaderInjector(headers []options.Header) (alice.Constructor, err func injectResponseHeaders(injector header.Injector, next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - scope := GetRequestScope(req) + scope := middlewareapi.GetRequestScope(req) // If scope is nil, this will panic. // A scope should always be injected before this handler is called. diff --git a/pkg/middleware/headers_test.go b/pkg/middleware/headers_test.go index 15006b1d..a9c6d73e 100644 --- a/pkg/middleware/headers_test.go +++ b/pkg/middleware/headers_test.go @@ -1,7 +1,6 @@ package middleware import ( - "context" "encoding/base64" "net/http" "net/http/httptest" @@ -31,8 +30,7 @@ var _ = Describe("Headers Suite", func() { // Set up the request with a request scope req := httptest.NewRequest("", "/", nil) - contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) - req = req.WithContext(contextWithScope) + req = middlewareapi.AddRequestScope(req, scope) req.Header = in.initialHeaders.Clone() rw := httptest.NewRecorder() @@ -218,8 +216,7 @@ var _ = Describe("Headers Suite", func() { // Set up the request with a request scope req := httptest.NewRequest("", "/", nil) - contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) - req = req.WithContext(contextWithScope) + req = middlewareapi.AddRequestScope(req, scope) rw := httptest.NewRecorder() for key, values := range in.initialHeaders { diff --git a/pkg/middleware/jwt_session.go b/pkg/middleware/jwt_session.go index 0510c72a..78ef5400 100644 --- a/pkg/middleware/jwt_session.go +++ b/pkg/middleware/jwt_session.go @@ -37,7 +37,7 @@ type jwtSessionLoader struct { // If a session was loaded by a previous handler, it will not be replaced. func (j *jwtSessionLoader) loadSession(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - scope := GetRequestScope(req) + scope := middlewareapi.GetRequestScope(req) // If scope is nil, this will panic. // A scope should always be injected before this handler is called. if scope.Session != nil { diff --git a/pkg/middleware/jwt_session_test.go b/pkg/middleware/jwt_session_test.go index cd34c5ad..7786d00a 100644 --- a/pkg/middleware/jwt_session_test.go +++ b/pkg/middleware/jwt_session_test.go @@ -103,8 +103,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` // Set up the request with the authorization header and a request scope req := httptest.NewRequest("", "/", nil) req.Header.Set("Authorization", in.authorizationHeader) - contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) - req = req.WithContext(contextWithScope) + req = middlewareapi.AddRequestScope(req, scope) rw := httptest.NewRecorder() @@ -116,7 +115,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` // from the scope var gotSession *sessionsapi.SessionState handler := NewJwtSessionLoader(sessionLoaders)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session + gotSession = middlewareapi.GetRequestScope(r).Session })) handler.ServeHTTP(rw, req) diff --git a/pkg/middleware/scope.go b/pkg/middleware/scope.go index 6485cc4f..9218faa0 100644 --- a/pkg/middleware/scope.go +++ b/pkg/middleware/scope.go @@ -1,39 +1,20 @@ package middleware import ( - "context" "net/http" "github.com/justinas/alice" middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" ) -type scopeKey string - -// requestScopeKey uses a typed string to reduce likelihood of clashing -// with other context keys -const requestScopeKey scopeKey = "request-scope" - -func NewScope(opts *options.Options) alice.Constructor { +func NewScope(reverseProxy bool) alice.Constructor { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { scope := &middlewareapi.RequestScope{ - ReverseProxy: opts.ReverseProxy, + ReverseProxy: reverseProxy, } - contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) - requestWithScope := req.WithContext(contextWithScope) - next.ServeHTTP(rw, requestWithScope) + req = middlewareapi.AddRequestScope(req, scope) + next.ServeHTTP(rw, req) }) } } - -// GetRequestScope returns the current request scope from the given request -func GetRequestScope(req *http.Request) *middlewareapi.RequestScope { - scope := req.Context().Value(requestScopeKey) - if scope == nil { - return nil - } - - return scope.(*middlewareapi.RequestScope) -} diff --git a/pkg/middleware/scope_test.go b/pkg/middleware/scope_test.go index e9533a8d..3432d148 100644 --- a/pkg/middleware/scope_test.go +++ b/pkg/middleware/scope_test.go @@ -1,7 +1,6 @@ package middleware import ( - "context" "net/http" "net/http/httptest" @@ -21,73 +20,49 @@ var _ = Describe("Scope Suite", func() { Expect(err).ToNot(HaveOccurred()) rw = httptest.NewRecorder() - - handler := NewScope()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - nextRequest = r - w.WriteHeader(200) - })) - handler.ServeHTTP(rw, request) }) - It("does not add a scope to the original request", func() { - Expect(request.Context().Value(requestScopeKey)).To(BeNil()) - }) - - It("cannot load a scope from the original request using GetRequestScope", func() { - Expect(GetRequestScope(request)).To(BeNil()) - }) - - It("adds a scope to the request for the next handler", func() { - Expect(nextRequest.Context().Value(requestScopeKey)).ToNot(BeNil()) - }) - - It("can load a scope from the next handler's request using GetRequestScope", func() { - Expect(GetRequestScope(nextRequest)).ToNot(BeNil()) - }) - }) - - Context("GetRequestScope", func() { - var request *http.Request - - BeforeEach(func() { - var err error - request, err = http.NewRequest("", "http://127.0.0.1/", nil) - Expect(err).ToNot(HaveOccurred()) - }) - - Context("with a scope", func() { - var scope *middlewareapi.RequestScope - + Context("ReverseProxy is false", func() { BeforeEach(func() { - scope = &middlewareapi.RequestScope{} - contextWithScope := context.WithValue(request.Context(), requestScopeKey, scope) - request = request.WithContext(contextWithScope) + handler := NewScope(false)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextRequest = r + w.WriteHeader(200) + })) + handler.ServeHTTP(rw, request) }) - It("returns the scope", func() { - s := GetRequestScope(request) - Expect(s).ToNot(BeNil()) - Expect(s).To(Equal(scope)) + It("does not add a scope to the original request", func() { + Expect(request.Context().Value(middlewareapi.RequestScopeKey)).To(BeNil()) }) - Context("if the scope is then modified", func() { - BeforeEach(func() { - Expect(scope.SaveSession).To(BeFalse()) - scope.SaveSession = true - }) + It("cannot load a scope from the original request using GetRequestScope", func() { + Expect(middlewareapi.GetRequestScope(request)).To(BeNil()) + }) - It("returns the updated session", func() { - s := GetRequestScope(request) - Expect(s).ToNot(BeNil()) - Expect(s).To(Equal(scope)) - Expect(s.SaveSession).To(BeTrue()) - }) + It("adds a scope to the request for the next handler", func() { + Expect(nextRequest.Context().Value(middlewareapi.RequestScopeKey)).ToNot(BeNil()) + }) + + It("can load a scope from the next handler's request using GetRequestScope", func() { + scope := middlewareapi.GetRequestScope(nextRequest) + Expect(scope).ToNot(BeNil()) + Expect(scope.ReverseProxy).To(BeFalse()) }) }) - Context("without a scope", func() { - It("returns nil", func() { - Expect(GetRequestScope(request)).To(BeNil()) + Context("ReverseProxy is true", func() { + BeforeEach(func() { + handler := NewScope(true)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextRequest = r + w.WriteHeader(200) + })) + handler.ServeHTTP(rw, request) + }) + + It("return a scope where the ReverseProxy field is true", func() { + scope := middlewareapi.GetRequestScope(nextRequest) + Expect(scope).ToNot(BeNil()) + Expect(scope.ReverseProxy).To(BeTrue()) }) }) }) diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go index 6d86e613..1bd0a9a4 100644 --- a/pkg/middleware/stored_session.go +++ b/pkg/middleware/stored_session.go @@ -8,6 +8,7 @@ import ( "time" "github.com/justinas/alice" + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" ) @@ -59,7 +60,7 @@ type storedSessionLoader struct { // If a session was loader by a previous handler, it will not be replaced. func (s *storedSessionLoader) loadSession(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - scope := GetRequestScope(req) + scope := middlewareapi.GetRequestScope(req) // If scope is nil, this will panic. // A scope should always be injected before this handler is called. if scope.Session != nil { diff --git a/pkg/middleware/stored_session_test.go b/pkg/middleware/stored_session_test.go index 89eadc5d..4a8fd9da 100644 --- a/pkg/middleware/stored_session_test.go +++ b/pkg/middleware/stored_session_test.go @@ -104,8 +104,7 @@ var _ = Describe("Stored Session Suite", func() { // Set up the request with the request headesr and a request scope req := httptest.NewRequest("", "/", nil) req.Header = in.requestHeaders - contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) - req = req.WithContext(contextWithScope) + req = middlewareapi.AddRequestScope(req, scope) rw := httptest.NewRecorder() @@ -120,7 +119,7 @@ var _ = Describe("Stored Session Suite", func() { // from the scope var gotSession *sessionsapi.SessionState handler := NewStoredSessionLoader(opts)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session + gotSession = middlewareapi.GetRequestScope(r).Session })) handler.ServeHTTP(rw, req) diff --git a/pkg/requests/util/util.go b/pkg/requests/util/util.go new file mode 100644 index 00000000..08c9c2c1 --- /dev/null +++ b/pkg/requests/util/util.go @@ -0,0 +1,48 @@ +package util + +import ( + "net/http" + + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" +) + +// GetRequestProto returns the request scheme or X-Forwarded-Proto if present +// and the request is proxied. +func GetRequestProto(req *http.Request) string { + proto := req.Header.Get("X-Forwarded-Proto") + if !IsProxied(req) || proto == "" { + proto = req.URL.Scheme + } + return proto +} + +// GetRequestHost returns the request host header or X-Forwarded-Host if +// present and the request is proxied. +func GetRequestHost(req *http.Request) string { + host := req.Header.Get("X-Forwarded-Host") + if !IsProxied(req) || host == "" { + host = req.Host + } + return host +} + +// GetRequestURI return the request URI or X-Forwarded-Uri if present and the +// request is proxied. +func GetRequestURI(req *http.Request) string { + uri := req.Header.Get("X-Forwarded-Uri") + if !IsProxied(req) || uri == "" { + // Use RequestURI to preserve ?query + uri = req.URL.RequestURI() + } + return uri +} + +// IsProxied determines if a request was from a proxy based on the RequestScope +// ReverseProxy tracker. +func IsProxied(req *http.Request) bool { + scope := middlewareapi.GetRequestScope(req) + if scope == nil { + return false + } + return scope.ReverseProxy +} diff --git a/pkg/requests/util/util_suite_test.go b/pkg/requests/util/util_suite_test.go new file mode 100644 index 00000000..a03f943f --- /dev/null +++ b/pkg/requests/util/util_suite_test.go @@ -0,0 +1,19 @@ +package util_test + +import ( + "testing" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +// TestRequestUtilSuite and related tests are in a *_test package +// to prevent circular imports with the `logger` package which uses +// this functionality +func TestRequestUtilSuite(t *testing.T) { + logger.SetOutput(GinkgoWriter) + + RegisterFailHandler(Fail) + RunSpecs(t, "Request Utils") +} diff --git a/pkg/requests/util/util_test.go b/pkg/requests/util/util_test.go new file mode 100644 index 00000000..595f93f6 --- /dev/null +++ b/pkg/requests/util/util_test.go @@ -0,0 +1,131 @@ +package util_test + +import ( + "fmt" + "net/http" + "net/http/httptest" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Util Suite", func() { + const ( + proto = "http" + host = "www.oauth2proxy.test" + uri = "/test/endpoint" + ) + var req *http.Request + + BeforeEach(func() { + req = httptest.NewRequest( + http.MethodGet, + fmt.Sprintf("%s://%s%s", proto, host, uri), + nil, + ) + }) + + Context("GetRequestHost", func() { + Context("IsProxied is false", func() { + BeforeEach(func() { + req = middleware.AddRequestScope(req, &middleware.RequestScope{}) + }) + + It("returns the host", func() { + Expect(util.GetRequestHost(req)).To(Equal(host)) + }) + + It("ignores X-Forwarded-Host and returns the host", func() { + req.Header.Add("X-Forwarded-Host", "external.oauth2proxy.text") + Expect(util.GetRequestHost(req)).To(Equal(host)) + }) + }) + + Context("IsProxied is true", func() { + BeforeEach(func() { + req = middleware.AddRequestScope(req, &middleware.RequestScope{ + ReverseProxy: true, + }) + }) + + It("returns the host if X-Forwarded-Host is not present", func() { + Expect(util.GetRequestHost(req)).To(Equal(host)) + }) + + It("returns the X-Forwarded-Host when present", func() { + req.Header.Add("X-Forwarded-Host", "external.oauth2proxy.text") + Expect(util.GetRequestHost(req)).To(Equal("external.oauth2proxy.text")) + }) + }) + }) + + Context("GetRequestProto", func() { + Context("IsProxied is false", func() { + BeforeEach(func() { + req = middleware.AddRequestScope(req, &middleware.RequestScope{}) + }) + + It("returns the scheme", func() { + Expect(util.GetRequestProto(req)).To(Equal(proto)) + }) + + It("ignores X-Forwarded-Proto and returns the scheme", func() { + req.Header.Add("X-Forwarded-Proto", "https") + Expect(util.GetRequestProto(req)).To(Equal(proto)) + }) + }) + + Context("IsProxied is true", func() { + BeforeEach(func() { + req = middleware.AddRequestScope(req, &middleware.RequestScope{ + ReverseProxy: true, + }) + }) + + It("returns the scheme if X-Forwarded-Proto is not present", func() { + Expect(util.GetRequestProto(req)).To(Equal(proto)) + }) + + It("returns the X-Forwarded-Proto when present", func() { + req.Header.Add("X-Forwarded-Proto", "https") + Expect(util.GetRequestProto(req)).To(Equal("https")) + }) + }) + }) + + Context("GetRequestURI", func() { + Context("IsProxied is false", func() { + BeforeEach(func() { + req = middleware.AddRequestScope(req, &middleware.RequestScope{}) + }) + + It("returns the URI", func() { + Expect(util.GetRequestURI(req)).To(Equal(uri)) + }) + + It("ignores X-Forwarded-Uri and returns the URI", func() { + req.Header.Add("X-Forwarded-Uri", "/some/other/path") + Expect(util.GetRequestURI(req)).To(Equal(uri)) + }) + }) + + Context("IsProxied is true", func() { + BeforeEach(func() { + req = middleware.AddRequestScope(req, &middleware.RequestScope{ + ReverseProxy: true, + }) + }) + + It("returns the URI if X-Forwarded-Uri is not present", func() { + Expect(util.GetRequestURI(req)).To(Equal(uri)) + }) + + It("returns the X-Forwarded-Uri when present", func() { + req.Header.Add("X-Forwarded-Uri", "/some/other/path") + Expect(util.GetRequestURI(req)).To(Equal("/some/other/path")) + }) + }) + }) +}) diff --git a/pkg/util/util.go b/pkg/util/util.go index 452e14f1..4519fdb8 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -4,9 +4,6 @@ import ( "crypto/x509" "fmt" "io/ioutil" - "net/http" - - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" ) func GetCertPool(paths []string) (*x509.CertPool, error) { @@ -26,37 +23,3 @@ func GetCertPool(paths []string) (*x509.CertPool, error) { } return pool, nil } - -// GetRequestProto return the request host header or X-Forwarded-Proto if present -func GetRequestProto(req *http.Request) string { - proto := req.Header.Get("X-Forwarded-Proto") - if !isProxied(req) || proto == "" { - proto = req.URL.Scheme - } - return proto -} - -// GetRequestHost return the request host header or X-Forwarded-Host if present -// and reverse proxy mode is enabled. -func GetRequestHost(req *http.Request) string { - host := req.Header.Get("X-Forwarded-Host") - if !isProxied(req) || host == "" { - host = req.Host - } - return host -} - -// GetRequestURI return the request host header or X-Forwarded-Uri if present -func GetRequestURI(req *http.Request) string { - uri := req.Header.Get("X-Forwarded-Uri") - if !isProxied(req) || uri == "" { - // Use RequestURI to preserve ?query - uri = req.URL.RequestURI() - } - return uri -} - -func isProxied(req *http.Request) bool { - scope := middleware.GetRequestScope(req) - return scope.ReverseProxy -} diff --git a/pkg/util/util_test.go b/pkg/util/util_test.go index d032025e..347f41bb 100644 --- a/pkg/util/util_test.go +++ b/pkg/util/util_test.go @@ -4,11 +4,9 @@ import ( "crypto/x509/pkix" "encoding/asn1" "io/ioutil" - "net/http/httptest" "os" "testing" - . "github.com/onsi/gomega" "github.com/stretchr/testify/assert" ) @@ -97,42 +95,3 @@ func TestGetCertPool(t *testing.T) { expectedSubjects := []string{testCA1Subj, testCA2Subj} assert.Equal(t, expectedSubjects, got) } - -func TestGetRequestHost(t *testing.T) { - g := NewWithT(t) - - req := httptest.NewRequest("GET", "https://example.com", nil) - host := GetRequestHost(req) - g.Expect(host).To(Equal("example.com")) - - proxyReq := httptest.NewRequest("GET", "http://internal.example.com", nil) - proxyReq.Header.Add("X-Forwarded-Host", "external.example.com") - extHost := GetRequestHost(proxyReq) - g.Expect(extHost).To(Equal("external.example.com")) -} - -func TestGetRequestProto(t *testing.T) { - g := NewWithT(t) - - req := httptest.NewRequest("GET", "https://example.com", nil) - proto := GetRequestProto(req) - g.Expect(proto).To(Equal("https")) - - proxyReq := httptest.NewRequest("GET", "https://internal.example.com", nil) - proxyReq.Header.Add("X-Forwarded-Proto", "http") - extProto := GetRequestProto(proxyReq) - g.Expect(extProto).To(Equal("http")) -} - -func TestGetRequestURI(t *testing.T) { - g := NewWithT(t) - - req := httptest.NewRequest("GET", "https://example.com/ping", nil) - uri := GetRequestURI(req) - g.Expect(uri).To(Equal("/ping")) - - proxyReq := httptest.NewRequest("GET", "http://internal.example.com/bong", nil) - proxyReq.Header.Add("X-Forwarded-Uri", "/ping") - extURI := GetRequestURI(proxyReq) - g.Expect(extURI).To(Equal("/ping")) -}