From 602dac78526336068c03f20b6037daa8f0e3003d Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Sat, 6 Mar 2021 09:27:16 -0800 Subject: [PATCH] Move Logging to Middleware Package (#1070) * Use a specialized ResponseWriter in middleware * Track User & Upstream in RequestScope * Wrap responses in our custom ResponseWriter * Add tests for logging middleware * Inject upstream metadata into request scope * Use custom ResponseWriter only in logging middleware * Assume RequestScope is never nil --- CHANGELOG.md | 1 + logging_handler.go | 108 --------------------- logging_handler_test.go | 116 ----------------------- oauthproxy.go | 10 +- pkg/apis/middleware/scope.go | 3 + pkg/middleware/middleware_suite_test.go | 12 +++ pkg/middleware/request_logger.go | 110 +++++++++++++++++++++ pkg/middleware/request_logger_test.go | 121 ++++++++++++++++++++++++ pkg/upstream/file.go | 8 +- pkg/upstream/file_test.go | 7 +- pkg/upstream/http.go | 8 +- pkg/upstream/http_test.go | 36 ++++--- pkg/upstream/proxy_test.go | 41 ++++---- pkg/upstream/static.go | 14 ++- pkg/upstream/static_test.go | 7 +- pkg/upstream/upstream_suite_test.go | 1 - 16 files changed, 337 insertions(+), 266 deletions(-) delete mode 100644 logging_handler.go delete mode 100644 logging_handler_test.go create mode 100644 pkg/middleware/request_logger.go create mode 100644 pkg/middleware/request_logger_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 349127dd..7ba0dd4e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ ## Changes since v7.0.1 +- [#1070](https://github.com/oauth2-proxy/oauth2-proxy/pull/1070) Refactor logging middleware to middleware package (@NickMeves) - [#1064](https://github.com/oauth2-proxy/oauth2-proxy/pull/1064) Add support for setting groups on session when using basic auth (@stefansedich) - [#1056](https://github.com/oauth2-proxy/oauth2-proxy/pull/1056) Add option for custom logos on the sign in page (@JoelSpeed) - [#1054](https://github.com/oauth2-proxy/oauth2-proxy/pull/1054) Update to Go 1.16 (@JoelSpeed) diff --git a/logging_handler.go b/logging_handler.go deleted file mode 100644 index 6da38c0a..00000000 --- a/logging_handler.go +++ /dev/null @@ -1,108 +0,0 @@ -// largely adapted from https://github.com/gorilla/handlers/blob/master/handlers.go -// to add logging of request duration as last value (and drop referrer) - -package main - -import ( - "bufio" - "errors" - "net" - "net/http" - "time" - - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" -) - -// responseLogger is wrapper of http.ResponseWriter that keeps track of its HTTP status -// code and body size -type responseLogger struct { - w http.ResponseWriter - status int - size int - upstream string - authInfo string -} - -// Header returns the ResponseWriter's Header -func (l *responseLogger) Header() http.Header { - return l.w.Header() -} - -// Support Websocket -func (l *responseLogger) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { - if hj, ok := l.w.(http.Hijacker); ok { - return hj.Hijack() - } - return nil, nil, errors.New("http.Hijacker is not available on writer") -} - -// ExtractGAPMetadata extracts and removes GAP headers from the ResponseWriter's -// Header -func (l *responseLogger) ExtractGAPMetadata() { - upstream := l.w.Header().Get("GAP-Upstream-Address") - if upstream != "" { - l.upstream = upstream - l.w.Header().Del("GAP-Upstream-Address") - } - authInfo := l.w.Header().Get("GAP-Auth") - if authInfo != "" { - l.authInfo = authInfo - l.w.Header().Del("GAP-Auth") - } -} - -// Write writes the response using the ResponseWriter -func (l *responseLogger) Write(b []byte) (int, error) { - if l.status == 0 { - // The status will be StatusOK if WriteHeader has not been called yet - l.status = http.StatusOK - } - l.ExtractGAPMetadata() - size, err := l.w.Write(b) - l.size += size - return size, err -} - -// WriteHeader writes the status code for the Response -func (l *responseLogger) WriteHeader(s int) { - l.ExtractGAPMetadata() - l.w.WriteHeader(s) - l.status = s -} - -// Status returns the response status code -func (l *responseLogger) Status() int { - return l.status -} - -// Size returns the response size -func (l *responseLogger) Size() int { - return l.size -} - -// Flush sends any buffered data to the client -func (l *responseLogger) Flush() { - if flusher, ok := l.w.(http.Flusher); ok { - flusher.Flush() - } -} - -// loggingHandler is the http.Handler implementation for LoggingHandler -type loggingHandler struct { - handler http.Handler -} - -// LoggingHandler provides an http.Handler which logs requests to the HTTP server -func LoggingHandler(h http.Handler) http.Handler { - return loggingHandler{ - handler: h, - } -} - -func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { - t := time.Now() - url := *req.URL - responseLogger := &responseLogger{w: w} - h.handler.ServeHTTP(responseLogger, req) - logger.PrintReq(responseLogger.authInfo, responseLogger.upstream, req, url, t, responseLogger.Status(), responseLogger.Size()) -} diff --git a/logging_handler_test.go b/logging_handler_test.go deleted file mode 100644 index 1938c54b..00000000 --- a/logging_handler_test.go +++ /dev/null @@ -1,116 +0,0 @@ -package main - -import ( - "bytes" - "net/http" - "net/http/httptest" - "testing" - - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" - "github.com/stretchr/testify/assert" -) - -const RequestLoggingFormatWithoutTime = "{{.Client}} - {{.Username}} [TIMELESS] {{.Host}} {{.RequestMethod}} {{.Upstream}} {{.RequestURI}} {{.Protocol}} {{.UserAgent}} {{.StatusCode}} {{.ResponseSize}} {{.RequestDuration}}" - -func TestLoggingHandler_ServeHTTP(t *testing.T) { - tests := []struct { - Format string - ExpectedLogMessage string - Path string - ExcludePaths []string - }{ - { - Format: RequestLoggingFormatWithoutTime, - ExpectedLogMessage: "127.0.0.1 - - [TIMELESS] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", - Path: "/foo/bar", - ExcludePaths: []string{}, - }, - { - Format: RequestLoggingFormatWithoutTime, - ExpectedLogMessage: "127.0.0.1 - - [TIMELESS] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", - Path: "/foo/bar", - ExcludePaths: []string{}, - }, - { - Format: RequestLoggingFormatWithoutTime, - ExpectedLogMessage: "127.0.0.1 - - [TIMELESS] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", - Path: "/foo/bar", - ExcludePaths: []string{"/ping"}, - }, - { - Format: RequestLoggingFormatWithoutTime, - ExpectedLogMessage: "", - Path: "/foo/bar", - ExcludePaths: []string{"/foo/bar"}, - }, - { - Format: RequestLoggingFormatWithoutTime, - ExpectedLogMessage: "127.0.0.1 - - [TIMELESS] test-server GET - \"/ping\" HTTP/1.1 \"\" 200 4 0.000\n", - Path: "/ping", - ExcludePaths: []string{}, - }, - { - Format: RequestLoggingFormatWithoutTime, - ExpectedLogMessage: "", - Path: "/ping", - ExcludePaths: []string{"/ping"}, - }, - { - Format: RequestLoggingFormatWithoutTime, - ExpectedLogMessage: "", - Path: "/ping", - ExcludePaths: []string{"/foo/bar", "/ping"}, - }, - { - Format: "{{.RequestMethod}}", - ExpectedLogMessage: "GET\n", - Path: "/foo/bar", - ExcludePaths: []string{""}, - }, - { - Format: "{{.RequestMethod}}", - ExpectedLogMessage: "GET\n", - Path: "/foo/bar", - ExcludePaths: []string{"/ping"}, - }, - { - Format: "{{.RequestMethod}}", - ExpectedLogMessage: "GET\n", - Path: "/ping", - ExcludePaths: []string{""}, - }, - { - Format: "{{.RequestMethod}}", - ExpectedLogMessage: "", - Path: "/ping", - ExcludePaths: []string{"/ping"}, - }, - } - - for _, test := range tests { - buf := bytes.NewBuffer(nil) - handler := func(w http.ResponseWriter, req *http.Request) { - _, ok := w.(http.Hijacker) - if !ok { - t.Error("http.Hijacker is not available") - } - - _, err := w.Write([]byte("test")) - assert.NoError(t, err) - } - - logger.SetOutput(buf) - logger.SetReqTemplate(test.Format) - logger.SetExcludePaths(test.ExcludePaths) - h := LoggingHandler(http.HandlerFunc(handler)) - - r, _ := http.NewRequest("GET", test.Path, nil) - r.RemoteAddr = "127.0.0.1" - r.Host = "test-server" - - h.ServeHTTP(httptest.NewRecorder(), r) - - actual := buf.String() - assert.Equal(t, test.ExpectedLogMessage, actual) - } -} diff --git a/oauthproxy.go b/oauthproxy.go index 8041b4f7..43c64525 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -250,9 +250,15 @@ func buildPreAuthChain(opts *options.Options) (alice.Chain, error) { // To silence logging of health checks, register the health check handler before // the logging handler if opts.Logging.SilencePing { - chain = chain.Append(middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents), LoggingHandler) + chain = chain.Append( + middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents), + middleware.NewRequestLogger(), + ) } else { - chain = chain.Append(LoggingHandler, middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents)) + chain = chain.Append( + middleware.NewRequestLogger(), + middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents), + ) } chain = chain.Append(middleware.NewRequestMetricsWithDefaultRegistry()) diff --git a/pkg/apis/middleware/scope.go b/pkg/apis/middleware/scope.go index c54a33d1..b3693cc6 100644 --- a/pkg/apis/middleware/scope.go +++ b/pkg/apis/middleware/scope.go @@ -34,6 +34,9 @@ type RequestScope struct { // SessionRevalidated indicates whether the session has been revalidated since // it was loaded or not. SessionRevalidated bool + + // Upstream tracks which upstream was used for this request + Upstream string } // GetRequestScope returns the current request scope from the given request diff --git a/pkg/middleware/middleware_suite_test.go b/pkg/middleware/middleware_suite_test.go index 494c6090..ff9c1ef6 100644 --- a/pkg/middleware/middleware_suite_test.go +++ b/pkg/middleware/middleware_suite_test.go @@ -4,6 +4,7 @@ import ( "net/http" "testing" + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -19,6 +20,17 @@ func TestMiddlewareSuite(t *testing.T) { func testHandler() http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(200) + rw.Write([]byte("test")) + }) +} + +func testUpstreamHandler(upstream string) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + scope := middlewareapi.GetRequestScope(req) + scope.Upstream = upstream + + rw.WriteHeader(200) rw.Write([]byte("test")) }) } diff --git a/pkg/middleware/request_logger.go b/pkg/middleware/request_logger.go new file mode 100644 index 00000000..e6ed9f21 --- /dev/null +++ b/pkg/middleware/request_logger.go @@ -0,0 +1,110 @@ +package middleware + +import ( + "bufio" + "errors" + "net" + "net/http" + "time" + + "github.com/justinas/alice" + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" +) + +// NewRequestLogger returns middleware which logs requests +// It uses a custom ResponseWriter to track status code & response size details +func NewRequestLogger() alice.Constructor { + return requestLogger +} + +func requestLogger(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + startTime := time.Now() + url := *req.URL + + responseLogger := &loggingResponse{ResponseWriter: rw} + next.ServeHTTP(responseLogger, req) + + scope := middlewareapi.GetRequestScope(req) + // If scope is nil, this will panic. + // A scope should always be injected before this handler is called. + logger.PrintReq( + getUser(scope), + scope.Upstream, + req, + url, + startTime, + responseLogger.Status(), + responseLogger.Size(), + ) + }) +} + +func getUser(scope *middlewareapi.RequestScope) string { + session := scope.Session + if session != nil { + if session.Email != "" { + return session.Email + } + return session.User + } + return "" +} + +// loggingResponse is a custom http.ResponseWriter that allows tracking certain +// details for request logging. +type loggingResponse struct { + http.ResponseWriter + + status int + size int +} + +// Write writes the response using the ResponseWriter +func (r *loggingResponse) Write(b []byte) (int, error) { + if r.status == 0 { + // The status will be StatusOK if WriteHeader has not been called yet + r.status = http.StatusOK + } + size, err := r.ResponseWriter.Write(b) + r.size += size + return size, err +} + +// WriteHeader writes the status code for the Response +func (r *loggingResponse) WriteHeader(s int) { + r.ResponseWriter.WriteHeader(s) + r.status = s +} + +// Hijack implements the `http.Hijacker` interface that actual ResponseWriters +// implement to support websockets +func (r *loggingResponse) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if hj, ok := r.ResponseWriter.(http.Hijacker); ok { + return hj.Hijack() + } + return nil, nil, errors.New("http.Hijacker is not available on writer") +} + +// Flush sends any buffered data to the client. Implements the `http.Flusher` +// interface +func (r *loggingResponse) Flush() { + if flusher, ok := r.ResponseWriter.(http.Flusher); ok { + if r.status == 0 { + // The status will be StatusOK if WriteHeader has not been called yet + r.status = http.StatusOK + } + flusher.Flush() + } +} + +// Status returns the response status code +func (r *loggingResponse) Status() int { + return r.status +} + +// Size returns the response size +func (r *loggingResponse) Size() int { + return r.size +} diff --git a/pkg/middleware/request_logger_test.go b/pkg/middleware/request_logger_test.go new file mode 100644 index 00000000..0a1e35ff --- /dev/null +++ b/pkg/middleware/request_logger_test.go @@ -0,0 +1,121 @@ +package middleware + +import ( + "bytes" + "net/http" + "net/http/httptest" + + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +const RequestLoggingFormatWithoutTime = "{{.Client}} - {{.Username}} [TIMELESS] {{.Host}} {{.RequestMethod}} {{.Upstream}} {{.RequestURI}} {{.Protocol}} {{.UserAgent}} {{.StatusCode}} {{.ResponseSize}} {{.RequestDuration}}" + +var _ = Describe("Request logger suite", func() { + type requestLoggerTableInput struct { + Format string + ExpectedLogMessage string + Path string + ExcludePaths []string + Upstream string + Session *sessions.SessionState + } + + DescribeTable("when service a request", + func(in *requestLoggerTableInput) { + buf := bytes.NewBuffer(nil) + logger.SetOutput(buf) + logger.SetReqTemplate(in.Format) + logger.SetExcludePaths(in.ExcludePaths) + + req, err := http.NewRequest("GET", in.Path, nil) + Expect(err).ToNot(HaveOccurred()) + req.RemoteAddr = "127.0.0.1" + req.Host = "test-server" + + scope := &middlewareapi.RequestScope{Session: in.Session} + req = middlewareapi.AddRequestScope(req, scope) + + handler := NewRequestLogger()(testUpstreamHandler(in.Upstream)) + handler.ServeHTTP(httptest.NewRecorder(), req) + + Expect(buf.String()).To(Equal(in.ExpectedLogMessage)) + }, + Entry("standard request", &requestLoggerTableInput{ + Format: RequestLoggingFormatWithoutTime, + ExpectedLogMessage: "127.0.0.1 - standard.user [TIMELESS] test-server GET standard \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", + Path: "/foo/bar", + ExcludePaths: []string{}, + Upstream: "standard", + Session: &sessions.SessionState{User: "standard.user"}, + }), + Entry("with unrelated path excluded", &requestLoggerTableInput{ + Format: RequestLoggingFormatWithoutTime, + ExpectedLogMessage: "127.0.0.1 - unrelated.exclusion [TIMELESS] test-server GET unrelated \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", + Path: "/foo/bar", + ExcludePaths: []string{"/ping"}, + Upstream: "unrelated", + Session: &sessions.SessionState{User: "unrelated.exclusion"}, + }), + Entry("with path as the sole exclusion", &requestLoggerTableInput{ + Format: RequestLoggingFormatWithoutTime, + ExpectedLogMessage: "", + Path: "/foo/bar", + ExcludePaths: []string{"/foo/bar"}, + }), + Entry("ping path", &requestLoggerTableInput{ + Format: RequestLoggingFormatWithoutTime, + ExpectedLogMessage: "127.0.0.1 - mr.ping [TIMELESS] test-server GET - \"/ping\" HTTP/1.1 \"\" 200 4 0.000\n", + Path: "/ping", + ExcludePaths: []string{}, + Upstream: "", + Session: &sessions.SessionState{User: "mr.ping"}, + }), + Entry("ping path but excluded", &requestLoggerTableInput{ + Format: RequestLoggingFormatWithoutTime, + ExpectedLogMessage: "", + Path: "/ping", + ExcludePaths: []string{"/ping"}, + Upstream: "", + Session: &sessions.SessionState{User: "mr.ping"}, + }), + Entry("ping path and excluded in list", &requestLoggerTableInput{ + Format: RequestLoggingFormatWithoutTime, + ExpectedLogMessage: "", + Path: "/ping", + ExcludePaths: []string{"/foo/bar", "/ping"}, + }), + Entry("custom format", &requestLoggerTableInput{ + Format: "{{.RequestMethod}} {{.Username}} {{.Upstream}}", + ExpectedLogMessage: "GET custom.format custom\n", + Path: "/foo/bar", + ExcludePaths: []string{""}, + Upstream: "custom", + Session: &sessions.SessionState{User: "custom.format"}, + }), + Entry("custom format with unrelated exclusion", &requestLoggerTableInput{ + Format: "{{.RequestMethod}} {{.Username}} {{.Upstream}}", + ExpectedLogMessage: "GET custom.format custom\n", + Path: "/foo/bar", + ExcludePaths: []string{"/ping"}, + Upstream: "custom", + Session: &sessions.SessionState{User: "custom.format"}, + }), + Entry("custom format ping path", &requestLoggerTableInput{ + Format: "{{.RequestMethod}}", + ExpectedLogMessage: "GET\n", + Path: "/ping", + ExcludePaths: []string{""}, + }), + Entry("custom format ping path excluded", &requestLoggerTableInput{ + Format: "{{.RequestMethod}}", + ExpectedLogMessage: "", + Path: "/ping", + ExcludePaths: []string{"/ping"}, + }), + ) +}) diff --git a/pkg/upstream/file.go b/pkg/upstream/file.go index 7f67edb0..26b1c0b9 100644 --- a/pkg/upstream/file.go +++ b/pkg/upstream/file.go @@ -4,6 +4,8 @@ import ( "net/http" "runtime" "strings" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" ) const fileScheme = "file" @@ -37,6 +39,10 @@ type fileServer struct { // ServeHTTP proxies requests to the upstream provider while signing the // request headers func (u *fileServer) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("GAP-Upstream-Address", u.upstream) + scope := middleware.GetRequestScope(req) + // If scope is nil, this will panic. + // A scope should always be injected before this handler is called. + scope.Upstream = u.upstream + u.handler.ServeHTTP(rw, req) } diff --git a/pkg/upstream/file_test.go b/pkg/upstream/file_test.go index 2da1f078..feff3261 100644 --- a/pkg/upstream/file_test.go +++ b/pkg/upstream/file_test.go @@ -7,6 +7,7 @@ import ( "net/http/httptest" "os" + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" . "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo/extensions/table" . "github.com/onsi/gomega" @@ -42,10 +43,14 @@ var _ = Describe("File Server Suite", func() { DescribeTable("fileServer ServeHTTP", func(requestPath string, expectedResponseCode int, expectedBody string) { req := httptest.NewRequest("", requestPath, nil) + req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{}) + rw := httptest.NewRecorder() handler.ServeHTTP(rw, req) - Expect(rw.Header().Get("GAP-Upstream-Address")).To(Equal(id)) + scope := middlewareapi.GetRequestScope(req) + Expect(scope.Upstream).To(Equal(id)) + Expect(rw.Code).To(Equal(expectedResponseCode)) Expect(rw.Body.String()).To(Equal(expectedBody)) }, diff --git a/pkg/upstream/http.go b/pkg/upstream/http.go index a6e948c3..718bcba6 100644 --- a/pkg/upstream/http.go +++ b/pkg/upstream/http.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/mbland/hmacauth" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/yhat/wsutil" ) @@ -76,7 +77,12 @@ type httpUpstreamProxy struct { // ServeHTTP proxies requests to the upstream provider while signing the // request headers func (h *httpUpstreamProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("GAP-Upstream-Address", h.upstream) + scope := middleware.GetRequestScope(req) + // If scope is nil, this will panic. + // A scope should always be injected before this handler is called. + scope.Upstream = h.upstream + + // TODO (@NickMeves) - Deprecate GAP-Signature & remove GAP-Auth if h.auth != nil { req.Header.Set("GAP-Auth", rw.Header().Get("GAP-Auth")) h.auth.SignRequest(req) diff --git a/pkg/upstream/http_test.go b/pkg/upstream/http_test.go index 3ce5bd19..c406e1c7 100644 --- a/pkg/upstream/http_test.go +++ b/pkg/upstream/http_test.go @@ -13,7 +13,9 @@ import ( "strings" "time" + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" . "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo/extensions/table" . "github.com/onsi/gomega" @@ -36,6 +38,7 @@ var _ = Describe("HTTP Upstream Suite", func() { signatureData *options.SignatureData existingHeaders map[string]string expectedResponse testHTTPResponse + expectedUpstream string errorHandler ProxyErrorHandler } @@ -50,6 +53,7 @@ var _ = Describe("HTTP Upstream Suite", func() { req.Header.Add(key, value) } + req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{}) rw := httptest.NewRecorder() flush := options.Duration(1 * time.Second) @@ -71,6 +75,9 @@ var _ = Describe("HTTP Upstream Suite", func() { Expect(rw.Code).To(Equal(in.expectedResponse.code)) + scope := middlewareapi.GetRequestScope(req) + Expect(scope.Upstream).To(Equal(in.expectedUpstream)) + // Delete extra headers that aren't relevant to tests testSanitizeResponseHeader(rw.Header()) Expect(rw.Header()).To(Equal(in.expectedResponse.header)) @@ -97,7 +104,6 @@ var _ = Describe("HTTP Upstream Suite", func() { expectedResponse: testHTTPResponse{ code: 200, header: map[string][]string{ - gapUpstream: {"default"}, contentType: {applicationJSON}, }, request: testHTTPRequest{ @@ -109,6 +115,7 @@ var _ = Describe("HTTP Upstream Suite", func() { RequestURI: "http://example.localhost/foo", }, }, + expectedUpstream: "default", }), Entry("request a path with encoded slashes", &httpUpstreamTableInput{ id: "encodedSlashes", @@ -120,7 +127,6 @@ var _ = Describe("HTTP Upstream Suite", func() { expectedResponse: testHTTPResponse{ code: 200, header: map[string][]string{ - gapUpstream: {"encodedSlashes"}, contentType: {applicationJSON}, }, request: testHTTPRequest{ @@ -132,6 +138,7 @@ var _ = Describe("HTTP Upstream Suite", func() { RequestURI: "http://example.localhost/foo%2fbar/?baz=1", }, }, + expectedUpstream: "encodedSlashes", }), Entry("when the request has a body", &httpUpstreamTableInput{ id: "requestWithBody", @@ -143,7 +150,6 @@ var _ = Describe("HTTP Upstream Suite", func() { expectedResponse: testHTTPResponse{ code: 200, header: map[string][]string{ - gapUpstream: {"requestWithBody"}, contentType: {applicationJSON}, }, request: testHTTPRequest{ @@ -157,6 +163,7 @@ var _ = Describe("HTTP Upstream Suite", func() { RequestURI: "http://example.localhost/withBody", }, }, + expectedUpstream: "requestWithBody", }), Entry("when the upstream is unavailable", &httpUpstreamTableInput{ id: "unavailableUpstream", @@ -166,12 +173,11 @@ var _ = Describe("HTTP Upstream Suite", func() { body: []byte{}, errorHandler: nil, expectedResponse: testHTTPResponse{ - code: 502, - header: map[string][]string{ - gapUpstream: {"unavailableUpstream"}, - }, + code: 502, + header: map[string][]string{}, request: testHTTPRequest{}, }, + expectedUpstream: "unavailableUpstream", }), Entry("when the upstream is unavailable and an error handler is set", &httpUpstreamTableInput{ id: "withErrorHandler", @@ -184,13 +190,12 @@ var _ = Describe("HTTP Upstream Suite", func() { rw.Write([]byte("error")) }, expectedResponse: testHTTPResponse{ - code: 502, - header: map[string][]string{ - gapUpstream: {"withErrorHandler"}, - }, + code: 502, + header: map[string][]string{}, raw: "error", request: testHTTPRequest{}, }, + expectedUpstream: "withErrorHandler", }), Entry("with a signature", &httpUpstreamTableInput{ id: "withSignature", @@ -207,7 +212,6 @@ var _ = Describe("HTTP Upstream Suite", func() { code: 200, header: map[string][]string{ contentType: {applicationJSON}, - gapUpstream: {"withSignature"}, }, request: testHTTPRequest{ Method: "GET", @@ -221,6 +225,7 @@ var _ = Describe("HTTP Upstream Suite", func() { RequestURI: "http://example.localhost/withSignature", }, }, + expectedUpstream: "withSignature", }), Entry("with existing headers", &httpUpstreamTableInput{ id: "existingHeaders", @@ -236,7 +241,6 @@ var _ = Describe("HTTP Upstream Suite", func() { expectedResponse: testHTTPResponse{ code: 200, header: map[string][]string{ - gapUpstream: {"existingHeaders"}, contentType: {applicationJSON}, }, request: testHTTPRequest{ @@ -251,11 +255,13 @@ var _ = Describe("HTTP Upstream Suite", func() { RequestURI: "http://example.localhost/existingHeaders", }, }, + expectedUpstream: "existingHeaders", }), ) It("ServeHTTP, when not passing a host header", func() { req := httptest.NewRequest("", "http://example.localhost/foo", nil) + req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{}) rw := httptest.NewRecorder() flush := options.Duration(1 * time.Second) @@ -383,7 +389,8 @@ var _ = Describe("HTTP Upstream Suite", func() { Expect(err).ToNot(HaveOccurred()) handler := newHTTPUpstreamProxy(upstream, u, nil, nil) - proxyServer = httptest.NewServer(handler) + + proxyServer = httptest.NewServer(middleware.NewScope(false)(handler)) }) AfterEach(func() { @@ -414,7 +421,6 @@ var _ = Describe("HTTP Upstream Suite", func() { response, err := http.Get(fmt.Sprintf("http://%s", proxyServer.Listener.Addr().String())) Expect(err).ToNot(HaveOccurred()) Expect(response.StatusCode).To(Equal(200)) - Expect(response.Header.Get(gapUpstream)).To(Equal("websocketProxy")) }) }) }) diff --git a/pkg/upstream/proxy_test.go b/pkg/upstream/proxy_test.go index e834bc60..44483585 100644 --- a/pkg/upstream/proxy_test.go +++ b/pkg/upstream/proxy_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" . "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo/extensions/table" @@ -64,17 +65,24 @@ var _ = Describe("Proxy Suite", func() { type proxyTableInput struct { target string response testHTTPResponse + upstream string } - DescribeTable("Proxy ServerHTTP", + DescribeTable("Proxy ServeHTTP", func(in *proxyTableInput) { - req := httptest.NewRequest("", in.target, nil) + req := middlewareapi.AddRequestScope( + httptest.NewRequest("", in.target, nil), + &middlewareapi.RequestScope{}, + ) rw := httptest.NewRecorder() // Don't mock the remote Address req.RemoteAddr = "" upstreamServer.ServeHTTP(rw, req) + scope := middlewareapi.GetRequestScope(req) + Expect(scope.Upstream).To(Equal(in.upstream)) + Expect(rw.Code).To(Equal(in.response.code)) // Delete extra headers that aren't relevant to tests @@ -99,7 +107,6 @@ var _ = Describe("Proxy Suite", func() { response: testHTTPResponse{ code: 200, header: map[string][]string{ - gapUpstream: {"http-backend"}, contentType: {applicationJSON}, }, request: testHTTPRequest{ @@ -114,6 +121,7 @@ var _ = Describe("Proxy Suite", func() { RequestURI: "http://example.localhost/http/1234", }, }, + upstream: "http-backend", }), Entry("with a request to the File backend", &proxyTableInput{ target: "http://example.localhost/files/foo", @@ -121,31 +129,29 @@ var _ = Describe("Proxy Suite", func() { code: 200, header: map[string][]string{ contentType: {textPlainUTF8}, - gapUpstream: {"file-backend"}, }, raw: "foo", }, + upstream: "file-backend", }), Entry("with a request to the Static backend", &proxyTableInput{ target: "http://example.localhost/static/bar", response: testHTTPResponse{ - code: 200, - header: map[string][]string{ - gapUpstream: {"static-backend"}, - }, - raw: "Authenticated", + code: 200, + header: map[string][]string{}, + raw: "Authenticated", }, + upstream: "static-backend", }), Entry("with a request to the bad HTTP backend", &proxyTableInput{ target: "http://example.localhost/bad-http/bad", response: testHTTPResponse{ - code: 502, - header: map[string][]string{ - gapUpstream: {"bad-http-backend"}, - }, + code: 502, + header: map[string][]string{}, // This tests the error handler raw: "Proxy Error", }, + upstream: "bad-http-backend", }), Entry("with a request to the to an unregistered path", &proxyTableInput{ target: "http://example.localhost/unregistered", @@ -161,12 +167,11 @@ var _ = Describe("Proxy Suite", func() { Entry("with a request to the to backend registered to a single path", &proxyTableInput{ target: "http://example.localhost/single-path", response: testHTTPResponse{ - code: 200, - header: map[string][]string{ - gapUpstream: {"single-path-backend"}, - }, - raw: "Authenticated", + code: 200, + header: map[string][]string{}, + raw: "Authenticated", }, + upstream: "single-path-backend", }), Entry("with a request to the to a subpath of a backend registered to a single path", &proxyTableInput{ target: "http://example.localhost/single-path/unregistered", diff --git a/pkg/upstream/static.go b/pkg/upstream/static.go index d53d3d09..027f3e74 100644 --- a/pkg/upstream/static.go +++ b/pkg/upstream/static.go @@ -3,6 +3,9 @@ package upstream import ( "fmt" "net/http" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" ) const defaultStaticResponseCode = 200 @@ -24,9 +27,16 @@ type staticResponseHandler struct { // ServeHTTP serves a static response. func (s *staticResponseHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("GAP-Upstream-Address", s.upstream) + scope := middleware.GetRequestScope(req) + // If scope is nil, this will panic. + // A scope should always be injected before this handler is called. + scope.Upstream = s.upstream + rw.WriteHeader(s.code) - fmt.Fprintf(rw, "Authenticated") + _, err := fmt.Fprintf(rw, "Authenticated") + if err != nil { + logger.Errorf("Error writing static response: %v", err) + } } // derefStaticCode returns the derefenced value, or the default if the value is nil diff --git a/pkg/upstream/static_test.go b/pkg/upstream/static_test.go index 1b7309f7..dc2f6e36 100644 --- a/pkg/upstream/static_test.go +++ b/pkg/upstream/static_test.go @@ -6,6 +6,7 @@ import ( "net/http" "net/http/httptest" + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" . "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo/extensions/table" . "github.com/onsi/gomega" @@ -40,10 +41,14 @@ var _ = Describe("Static Response Suite", func() { handler := newStaticResponseHandler(id, code) req := httptest.NewRequest("", in.requestPath, nil) + req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{}) + rw := httptest.NewRecorder() handler.ServeHTTP(rw, req) - Expect(rw.Header().Get("GAP-Upstream-Address")).To(Equal(id)) + scope := middlewareapi.GetRequestScope(req) + Expect(scope.Upstream).To(Equal(id)) + Expect(rw.Code).To(Equal(in.expectedCode)) Expect(rw.Body.String()).To(Equal(in.expectedBody)) }, diff --git a/pkg/upstream/upstream_suite_test.go b/pkg/upstream/upstream_suite_test.go index f585049c..5e9b9e32 100644 --- a/pkg/upstream/upstream_suite_test.go +++ b/pkg/upstream/upstream_suite_test.go @@ -59,7 +59,6 @@ const ( acceptEncoding = "Accept-Encoding" applicationJSON = "application/json" textPlainUTF8 = "text/plain; charset=utf-8" - gapUpstream = "Gap-Upstream-Address" gapAuth = "Gap-Auth" gapSignature = "Gap-Signature" )