You've already forked oauth2-proxy
							
							
				mirror of
				https://github.com/oauth2-proxy/oauth2-proxy.git
				synced 2025-10-30 23:47:52 +02:00 
			
		
		
		
	Move Logging to Middleware Package (#1070)
* Use a specialized ResponseWriter in middleware * Track User & Upstream in RequestScope * Wrap responses in our custom ResponseWriter * Add tests for logging middleware * Inject upstream metadata into request scope * Use custom ResponseWriter only in logging middleware * Assume RequestScope is never nil
This commit is contained in:
		| @@ -8,6 +8,7 @@ | ||||
|  | ||||
| ## Changes since v7.0.1 | ||||
|  | ||||
| - [#1070](https://github.com/oauth2-proxy/oauth2-proxy/pull/1070) Refactor logging middleware to middleware package (@NickMeves) | ||||
| - [#1064](https://github.com/oauth2-proxy/oauth2-proxy/pull/1064) Add support for setting groups on session when using basic auth (@stefansedich) | ||||
| - [#1056](https://github.com/oauth2-proxy/oauth2-proxy/pull/1056) Add option for custom logos on the sign in page (@JoelSpeed) | ||||
| - [#1054](https://github.com/oauth2-proxy/oauth2-proxy/pull/1054) Update to Go 1.16 (@JoelSpeed) | ||||
|   | ||||
| @@ -1,108 +0,0 @@ | ||||
| // largely adapted from https://github.com/gorilla/handlers/blob/master/handlers.go | ||||
| // to add logging of request duration as last value (and drop referrer) | ||||
|  | ||||
| package main | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"errors" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||
| ) | ||||
|  | ||||
| // responseLogger is wrapper of http.ResponseWriter that keeps track of its HTTP status | ||||
| // code and body size | ||||
| type responseLogger struct { | ||||
| 	w        http.ResponseWriter | ||||
| 	status   int | ||||
| 	size     int | ||||
| 	upstream string | ||||
| 	authInfo string | ||||
| } | ||||
|  | ||||
| // Header returns the ResponseWriter's Header | ||||
| func (l *responseLogger) Header() http.Header { | ||||
| 	return l.w.Header() | ||||
| } | ||||
|  | ||||
| // Support Websocket | ||||
| func (l *responseLogger) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { | ||||
| 	if hj, ok := l.w.(http.Hijacker); ok { | ||||
| 		return hj.Hijack() | ||||
| 	} | ||||
| 	return nil, nil, errors.New("http.Hijacker is not available on writer") | ||||
| } | ||||
|  | ||||
| // ExtractGAPMetadata extracts and removes GAP headers from the ResponseWriter's | ||||
| // Header | ||||
| func (l *responseLogger) ExtractGAPMetadata() { | ||||
| 	upstream := l.w.Header().Get("GAP-Upstream-Address") | ||||
| 	if upstream != "" { | ||||
| 		l.upstream = upstream | ||||
| 		l.w.Header().Del("GAP-Upstream-Address") | ||||
| 	} | ||||
| 	authInfo := l.w.Header().Get("GAP-Auth") | ||||
| 	if authInfo != "" { | ||||
| 		l.authInfo = authInfo | ||||
| 		l.w.Header().Del("GAP-Auth") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Write writes the response using the ResponseWriter | ||||
| func (l *responseLogger) Write(b []byte) (int, error) { | ||||
| 	if l.status == 0 { | ||||
| 		// The status will be StatusOK if WriteHeader has not been called yet | ||||
| 		l.status = http.StatusOK | ||||
| 	} | ||||
| 	l.ExtractGAPMetadata() | ||||
| 	size, err := l.w.Write(b) | ||||
| 	l.size += size | ||||
| 	return size, err | ||||
| } | ||||
|  | ||||
| // WriteHeader writes the status code for the Response | ||||
| func (l *responseLogger) WriteHeader(s int) { | ||||
| 	l.ExtractGAPMetadata() | ||||
| 	l.w.WriteHeader(s) | ||||
| 	l.status = s | ||||
| } | ||||
|  | ||||
| // Status returns the response status code | ||||
| func (l *responseLogger) Status() int { | ||||
| 	return l.status | ||||
| } | ||||
|  | ||||
| // Size returns the response size | ||||
| func (l *responseLogger) Size() int { | ||||
| 	return l.size | ||||
| } | ||||
|  | ||||
| // Flush sends any buffered data to the client | ||||
| func (l *responseLogger) Flush() { | ||||
| 	if flusher, ok := l.w.(http.Flusher); ok { | ||||
| 		flusher.Flush() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // loggingHandler is the http.Handler implementation for LoggingHandler | ||||
| type loggingHandler struct { | ||||
| 	handler http.Handler | ||||
| } | ||||
|  | ||||
| // LoggingHandler provides an http.Handler which logs requests to the HTTP server | ||||
| func LoggingHandler(h http.Handler) http.Handler { | ||||
| 	return loggingHandler{ | ||||
| 		handler: h, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { | ||||
| 	t := time.Now() | ||||
| 	url := *req.URL | ||||
| 	responseLogger := &responseLogger{w: w} | ||||
| 	h.handler.ServeHTTP(responseLogger, req) | ||||
| 	logger.PrintReq(responseLogger.authInfo, responseLogger.upstream, req, url, t, responseLogger.Status(), responseLogger.Size()) | ||||
| } | ||||
| @@ -1,116 +0,0 @@ | ||||
| package main | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
|  | ||||
| const RequestLoggingFormatWithoutTime = "{{.Client}} - {{.Username}} [TIMELESS] {{.Host}} {{.RequestMethod}} {{.Upstream}} {{.RequestURI}} {{.Protocol}} {{.UserAgent}} {{.StatusCode}} {{.ResponseSize}} {{.RequestDuration}}" | ||||
|  | ||||
| func TestLoggingHandler_ServeHTTP(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		Format             string | ||||
| 		ExpectedLogMessage string | ||||
| 		Path               string | ||||
| 		ExcludePaths       []string | ||||
| 	}{ | ||||
| 		{ | ||||
| 			Format:             RequestLoggingFormatWithoutTime, | ||||
| 			ExpectedLogMessage: "127.0.0.1 - - [TIMELESS] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", | ||||
| 			Path:               "/foo/bar", | ||||
| 			ExcludePaths:       []string{}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Format:             RequestLoggingFormatWithoutTime, | ||||
| 			ExpectedLogMessage: "127.0.0.1 - - [TIMELESS] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", | ||||
| 			Path:               "/foo/bar", | ||||
| 			ExcludePaths:       []string{}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Format:             RequestLoggingFormatWithoutTime, | ||||
| 			ExpectedLogMessage: "127.0.0.1 - - [TIMELESS] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", | ||||
| 			Path:               "/foo/bar", | ||||
| 			ExcludePaths:       []string{"/ping"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Format:             RequestLoggingFormatWithoutTime, | ||||
| 			ExpectedLogMessage: "", | ||||
| 			Path:               "/foo/bar", | ||||
| 			ExcludePaths:       []string{"/foo/bar"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Format:             RequestLoggingFormatWithoutTime, | ||||
| 			ExpectedLogMessage: "127.0.0.1 - - [TIMELESS] test-server GET - \"/ping\" HTTP/1.1 \"\" 200 4 0.000\n", | ||||
| 			Path:               "/ping", | ||||
| 			ExcludePaths:       []string{}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Format:             RequestLoggingFormatWithoutTime, | ||||
| 			ExpectedLogMessage: "", | ||||
| 			Path:               "/ping", | ||||
| 			ExcludePaths:       []string{"/ping"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Format:             RequestLoggingFormatWithoutTime, | ||||
| 			ExpectedLogMessage: "", | ||||
| 			Path:               "/ping", | ||||
| 			ExcludePaths:       []string{"/foo/bar", "/ping"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Format:             "{{.RequestMethod}}", | ||||
| 			ExpectedLogMessage: "GET\n", | ||||
| 			Path:               "/foo/bar", | ||||
| 			ExcludePaths:       []string{""}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Format:             "{{.RequestMethod}}", | ||||
| 			ExpectedLogMessage: "GET\n", | ||||
| 			Path:               "/foo/bar", | ||||
| 			ExcludePaths:       []string{"/ping"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Format:             "{{.RequestMethod}}", | ||||
| 			ExpectedLogMessage: "GET\n", | ||||
| 			Path:               "/ping", | ||||
| 			ExcludePaths:       []string{""}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Format:             "{{.RequestMethod}}", | ||||
| 			ExpectedLogMessage: "", | ||||
| 			Path:               "/ping", | ||||
| 			ExcludePaths:       []string{"/ping"}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, test := range tests { | ||||
| 		buf := bytes.NewBuffer(nil) | ||||
| 		handler := func(w http.ResponseWriter, req *http.Request) { | ||||
| 			_, ok := w.(http.Hijacker) | ||||
| 			if !ok { | ||||
| 				t.Error("http.Hijacker is not available") | ||||
| 			} | ||||
|  | ||||
| 			_, err := w.Write([]byte("test")) | ||||
| 			assert.NoError(t, err) | ||||
| 		} | ||||
|  | ||||
| 		logger.SetOutput(buf) | ||||
| 		logger.SetReqTemplate(test.Format) | ||||
| 		logger.SetExcludePaths(test.ExcludePaths) | ||||
| 		h := LoggingHandler(http.HandlerFunc(handler)) | ||||
|  | ||||
| 		r, _ := http.NewRequest("GET", test.Path, nil) | ||||
| 		r.RemoteAddr = "127.0.0.1" | ||||
| 		r.Host = "test-server" | ||||
|  | ||||
| 		h.ServeHTTP(httptest.NewRecorder(), r) | ||||
|  | ||||
| 		actual := buf.String() | ||||
| 		assert.Equal(t, test.ExpectedLogMessage, actual) | ||||
| 	} | ||||
| } | ||||
| @@ -250,9 +250,15 @@ func buildPreAuthChain(opts *options.Options) (alice.Chain, error) { | ||||
| 	// To silence logging of health checks, register the health check handler before | ||||
| 	// the logging handler | ||||
| 	if opts.Logging.SilencePing { | ||||
| 		chain = chain.Append(middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents), LoggingHandler) | ||||
| 		chain = chain.Append( | ||||
| 			middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents), | ||||
| 			middleware.NewRequestLogger(), | ||||
| 		) | ||||
| 	} else { | ||||
| 		chain = chain.Append(LoggingHandler, middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents)) | ||||
| 		chain = chain.Append( | ||||
| 			middleware.NewRequestLogger(), | ||||
| 			middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents), | ||||
| 		) | ||||
| 	} | ||||
|  | ||||
| 	chain = chain.Append(middleware.NewRequestMetricsWithDefaultRegistry()) | ||||
|   | ||||
| @@ -34,6 +34,9 @@ type RequestScope struct { | ||||
| 	// SessionRevalidated indicates whether the session has been revalidated since | ||||
| 	// it was loaded or not. | ||||
| 	SessionRevalidated bool | ||||
|  | ||||
| 	// Upstream tracks which upstream was used for this request | ||||
| 	Upstream string | ||||
| } | ||||
|  | ||||
| // GetRequestScope returns the current request scope from the given request | ||||
|   | ||||
| @@ -4,6 +4,7 @@ import ( | ||||
| 	"net/http" | ||||
| 	"testing" | ||||
|  | ||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/gomega" | ||||
| @@ -19,6 +20,17 @@ func TestMiddlewareSuite(t *testing.T) { | ||||
|  | ||||
| func testHandler() http.Handler { | ||||
| 	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||
| 		rw.WriteHeader(200) | ||||
| 		rw.Write([]byte("test")) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func testUpstreamHandler(upstream string) http.Handler { | ||||
| 	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||
| 		scope := middlewareapi.GetRequestScope(req) | ||||
| 		scope.Upstream = upstream | ||||
|  | ||||
| 		rw.WriteHeader(200) | ||||
| 		rw.Write([]byte("test")) | ||||
| 	}) | ||||
| } | ||||
|   | ||||
							
								
								
									
										110
									
								
								pkg/middleware/request_logger.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										110
									
								
								pkg/middleware/request_logger.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,110 @@ | ||||
| package middleware | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"errors" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/justinas/alice" | ||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||
| ) | ||||
|  | ||||
| // NewRequestLogger returns middleware which logs requests | ||||
| // It uses a custom ResponseWriter to track status code & response size details | ||||
| func NewRequestLogger() alice.Constructor { | ||||
| 	return requestLogger | ||||
| } | ||||
|  | ||||
| func requestLogger(next http.Handler) http.Handler { | ||||
| 	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||
| 		startTime := time.Now() | ||||
| 		url := *req.URL | ||||
|  | ||||
| 		responseLogger := &loggingResponse{ResponseWriter: rw} | ||||
| 		next.ServeHTTP(responseLogger, req) | ||||
|  | ||||
| 		scope := middlewareapi.GetRequestScope(req) | ||||
| 		// If scope is nil, this will panic. | ||||
| 		// A scope should always be injected before this handler is called. | ||||
| 		logger.PrintReq( | ||||
| 			getUser(scope), | ||||
| 			scope.Upstream, | ||||
| 			req, | ||||
| 			url, | ||||
| 			startTime, | ||||
| 			responseLogger.Status(), | ||||
| 			responseLogger.Size(), | ||||
| 		) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func getUser(scope *middlewareapi.RequestScope) string { | ||||
| 	session := scope.Session | ||||
| 	if session != nil { | ||||
| 		if session.Email != "" { | ||||
| 			return session.Email | ||||
| 		} | ||||
| 		return session.User | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
|  | ||||
| // loggingResponse is a custom http.ResponseWriter that allows tracking certain | ||||
| // details for request logging. | ||||
| type loggingResponse struct { | ||||
| 	http.ResponseWriter | ||||
|  | ||||
| 	status int | ||||
| 	size   int | ||||
| } | ||||
|  | ||||
| // Write writes the response using the ResponseWriter | ||||
| func (r *loggingResponse) Write(b []byte) (int, error) { | ||||
| 	if r.status == 0 { | ||||
| 		// The status will be StatusOK if WriteHeader has not been called yet | ||||
| 		r.status = http.StatusOK | ||||
| 	} | ||||
| 	size, err := r.ResponseWriter.Write(b) | ||||
| 	r.size += size | ||||
| 	return size, err | ||||
| } | ||||
|  | ||||
| // WriteHeader writes the status code for the Response | ||||
| func (r *loggingResponse) WriteHeader(s int) { | ||||
| 	r.ResponseWriter.WriteHeader(s) | ||||
| 	r.status = s | ||||
| } | ||||
|  | ||||
| // Hijack implements the `http.Hijacker` interface that actual ResponseWriters | ||||
| // implement to support websockets | ||||
| func (r *loggingResponse) Hijack() (net.Conn, *bufio.ReadWriter, error) { | ||||
| 	if hj, ok := r.ResponseWriter.(http.Hijacker); ok { | ||||
| 		return hj.Hijack() | ||||
| 	} | ||||
| 	return nil, nil, errors.New("http.Hijacker is not available on writer") | ||||
| } | ||||
|  | ||||
| // Flush sends any buffered data to the client. Implements the `http.Flusher` | ||||
| // interface | ||||
| func (r *loggingResponse) Flush() { | ||||
| 	if flusher, ok := r.ResponseWriter.(http.Flusher); ok { | ||||
| 		if r.status == 0 { | ||||
| 			// The status will be StatusOK if WriteHeader has not been called yet | ||||
| 			r.status = http.StatusOK | ||||
| 		} | ||||
| 		flusher.Flush() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Status returns the response status code | ||||
| func (r *loggingResponse) Status() int { | ||||
| 	return r.status | ||||
| } | ||||
|  | ||||
| // Size returns the response size | ||||
| func (r *loggingResponse) Size() int { | ||||
| 	return r.size | ||||
| } | ||||
							
								
								
									
										121
									
								
								pkg/middleware/request_logger_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										121
									
								
								pkg/middleware/request_logger_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,121 @@ | ||||
| package middleware | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
|  | ||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/ginkgo/extensions/table" | ||||
| 	. "github.com/onsi/gomega" | ||||
| ) | ||||
|  | ||||
| const RequestLoggingFormatWithoutTime = "{{.Client}} - {{.Username}} [TIMELESS] {{.Host}} {{.RequestMethod}} {{.Upstream}} {{.RequestURI}} {{.Protocol}} {{.UserAgent}} {{.StatusCode}} {{.ResponseSize}} {{.RequestDuration}}" | ||||
|  | ||||
| var _ = Describe("Request logger suite", func() { | ||||
| 	type requestLoggerTableInput struct { | ||||
| 		Format             string | ||||
| 		ExpectedLogMessage string | ||||
| 		Path               string | ||||
| 		ExcludePaths       []string | ||||
| 		Upstream           string | ||||
| 		Session            *sessions.SessionState | ||||
| 	} | ||||
|  | ||||
| 	DescribeTable("when service a request", | ||||
| 		func(in *requestLoggerTableInput) { | ||||
| 			buf := bytes.NewBuffer(nil) | ||||
| 			logger.SetOutput(buf) | ||||
| 			logger.SetReqTemplate(in.Format) | ||||
| 			logger.SetExcludePaths(in.ExcludePaths) | ||||
|  | ||||
| 			req, err := http.NewRequest("GET", in.Path, nil) | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
| 			req.RemoteAddr = "127.0.0.1" | ||||
| 			req.Host = "test-server" | ||||
|  | ||||
| 			scope := &middlewareapi.RequestScope{Session: in.Session} | ||||
| 			req = middlewareapi.AddRequestScope(req, scope) | ||||
|  | ||||
| 			handler := NewRequestLogger()(testUpstreamHandler(in.Upstream)) | ||||
| 			handler.ServeHTTP(httptest.NewRecorder(), req) | ||||
|  | ||||
| 			Expect(buf.String()).To(Equal(in.ExpectedLogMessage)) | ||||
| 		}, | ||||
| 		Entry("standard request", &requestLoggerTableInput{ | ||||
| 			Format:             RequestLoggingFormatWithoutTime, | ||||
| 			ExpectedLogMessage: "127.0.0.1 - standard.user [TIMELESS] test-server GET standard \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", | ||||
| 			Path:               "/foo/bar", | ||||
| 			ExcludePaths:       []string{}, | ||||
| 			Upstream:           "standard", | ||||
| 			Session:            &sessions.SessionState{User: "standard.user"}, | ||||
| 		}), | ||||
| 		Entry("with unrelated path excluded", &requestLoggerTableInput{ | ||||
| 			Format:             RequestLoggingFormatWithoutTime, | ||||
| 			ExpectedLogMessage: "127.0.0.1 - unrelated.exclusion [TIMELESS] test-server GET unrelated \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", | ||||
| 			Path:               "/foo/bar", | ||||
| 			ExcludePaths:       []string{"/ping"}, | ||||
| 			Upstream:           "unrelated", | ||||
| 			Session:            &sessions.SessionState{User: "unrelated.exclusion"}, | ||||
| 		}), | ||||
| 		Entry("with path as the sole exclusion", &requestLoggerTableInput{ | ||||
| 			Format:             RequestLoggingFormatWithoutTime, | ||||
| 			ExpectedLogMessage: "", | ||||
| 			Path:               "/foo/bar", | ||||
| 			ExcludePaths:       []string{"/foo/bar"}, | ||||
| 		}), | ||||
| 		Entry("ping path", &requestLoggerTableInput{ | ||||
| 			Format:             RequestLoggingFormatWithoutTime, | ||||
| 			ExpectedLogMessage: "127.0.0.1 - mr.ping [TIMELESS] test-server GET - \"/ping\" HTTP/1.1 \"\" 200 4 0.000\n", | ||||
| 			Path:               "/ping", | ||||
| 			ExcludePaths:       []string{}, | ||||
| 			Upstream:           "", | ||||
| 			Session:            &sessions.SessionState{User: "mr.ping"}, | ||||
| 		}), | ||||
| 		Entry("ping path but excluded", &requestLoggerTableInput{ | ||||
| 			Format:             RequestLoggingFormatWithoutTime, | ||||
| 			ExpectedLogMessage: "", | ||||
| 			Path:               "/ping", | ||||
| 			ExcludePaths:       []string{"/ping"}, | ||||
| 			Upstream:           "", | ||||
| 			Session:            &sessions.SessionState{User: "mr.ping"}, | ||||
| 		}), | ||||
| 		Entry("ping path and excluded in list", &requestLoggerTableInput{ | ||||
| 			Format:             RequestLoggingFormatWithoutTime, | ||||
| 			ExpectedLogMessage: "", | ||||
| 			Path:               "/ping", | ||||
| 			ExcludePaths:       []string{"/foo/bar", "/ping"}, | ||||
| 		}), | ||||
| 		Entry("custom format", &requestLoggerTableInput{ | ||||
| 			Format:             "{{.RequestMethod}} {{.Username}} {{.Upstream}}", | ||||
| 			ExpectedLogMessage: "GET custom.format custom\n", | ||||
| 			Path:               "/foo/bar", | ||||
| 			ExcludePaths:       []string{""}, | ||||
| 			Upstream:           "custom", | ||||
| 			Session:            &sessions.SessionState{User: "custom.format"}, | ||||
| 		}), | ||||
| 		Entry("custom format with unrelated exclusion", &requestLoggerTableInput{ | ||||
| 			Format:             "{{.RequestMethod}} {{.Username}} {{.Upstream}}", | ||||
| 			ExpectedLogMessage: "GET custom.format custom\n", | ||||
| 			Path:               "/foo/bar", | ||||
| 			ExcludePaths:       []string{"/ping"}, | ||||
| 			Upstream:           "custom", | ||||
| 			Session:            &sessions.SessionState{User: "custom.format"}, | ||||
| 		}), | ||||
| 		Entry("custom format ping path", &requestLoggerTableInput{ | ||||
| 			Format:             "{{.RequestMethod}}", | ||||
| 			ExpectedLogMessage: "GET\n", | ||||
| 			Path:               "/ping", | ||||
| 			ExcludePaths:       []string{""}, | ||||
| 		}), | ||||
| 		Entry("custom format ping path excluded", &requestLoggerTableInput{ | ||||
| 			Format:             "{{.RequestMethod}}", | ||||
| 			ExpectedLogMessage: "", | ||||
| 			Path:               "/ping", | ||||
| 			ExcludePaths:       []string{"/ping"}, | ||||
| 		}), | ||||
| 	) | ||||
| }) | ||||
| @@ -4,6 +4,8 @@ import ( | ||||
| 	"net/http" | ||||
| 	"runtime" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||
| ) | ||||
|  | ||||
| const fileScheme = "file" | ||||
| @@ -37,6 +39,10 @@ type fileServer struct { | ||||
| // ServeHTTP proxies requests to the upstream provider while signing the | ||||
| // request headers | ||||
| func (u *fileServer) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | ||||
| 	rw.Header().Set("GAP-Upstream-Address", u.upstream) | ||||
| 	scope := middleware.GetRequestScope(req) | ||||
| 	// If scope is nil, this will panic. | ||||
| 	// A scope should always be injected before this handler is called. | ||||
| 	scope.Upstream = u.upstream | ||||
|  | ||||
| 	u.handler.ServeHTTP(rw, req) | ||||
| } | ||||
|   | ||||
| @@ -7,6 +7,7 @@ import ( | ||||
| 	"net/http/httptest" | ||||
| 	"os" | ||||
|  | ||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/ginkgo/extensions/table" | ||||
| 	. "github.com/onsi/gomega" | ||||
| @@ -42,10 +43,14 @@ var _ = Describe("File Server Suite", func() { | ||||
| 	DescribeTable("fileServer ServeHTTP", | ||||
| 		func(requestPath string, expectedResponseCode int, expectedBody string) { | ||||
| 			req := httptest.NewRequest("", requestPath, nil) | ||||
| 			req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{}) | ||||
|  | ||||
| 			rw := httptest.NewRecorder() | ||||
| 			handler.ServeHTTP(rw, req) | ||||
|  | ||||
| 			Expect(rw.Header().Get("GAP-Upstream-Address")).To(Equal(id)) | ||||
| 			scope := middlewareapi.GetRequestScope(req) | ||||
| 			Expect(scope.Upstream).To(Equal(id)) | ||||
|  | ||||
| 			Expect(rw.Code).To(Equal(expectedResponseCode)) | ||||
| 			Expect(rw.Body.String()).To(Equal(expectedBody)) | ||||
| 		}, | ||||
|   | ||||
| @@ -8,6 +8,7 @@ import ( | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/mbland/hmacauth" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||
| 	"github.com/yhat/wsutil" | ||||
| ) | ||||
| @@ -76,7 +77,12 @@ type httpUpstreamProxy struct { | ||||
| // ServeHTTP proxies requests to the upstream provider while signing the | ||||
| // request headers | ||||
| func (h *httpUpstreamProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | ||||
| 	rw.Header().Set("GAP-Upstream-Address", h.upstream) | ||||
| 	scope := middleware.GetRequestScope(req) | ||||
| 	// If scope is nil, this will panic. | ||||
| 	// A scope should always be injected before this handler is called. | ||||
| 	scope.Upstream = h.upstream | ||||
|  | ||||
| 	// TODO (@NickMeves) - Deprecate GAP-Signature & remove GAP-Auth | ||||
| 	if h.auth != nil { | ||||
| 		req.Header.Set("GAP-Auth", rw.Header().Get("GAP-Auth")) | ||||
| 		h.auth.SignRequest(req) | ||||
|   | ||||
| @@ -13,7 +13,9 @@ import ( | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/ginkgo/extensions/table" | ||||
| 	. "github.com/onsi/gomega" | ||||
| @@ -36,6 +38,7 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||
| 		signatureData    *options.SignatureData | ||||
| 		existingHeaders  map[string]string | ||||
| 		expectedResponse testHTTPResponse | ||||
| 		expectedUpstream string | ||||
| 		errorHandler     ProxyErrorHandler | ||||
| 	} | ||||
|  | ||||
| @@ -50,6 +53,7 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||
| 				req.Header.Add(key, value) | ||||
| 			} | ||||
|  | ||||
| 			req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{}) | ||||
| 			rw := httptest.NewRecorder() | ||||
|  | ||||
| 			flush := options.Duration(1 * time.Second) | ||||
| @@ -71,6 +75,9 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||
|  | ||||
| 			Expect(rw.Code).To(Equal(in.expectedResponse.code)) | ||||
|  | ||||
| 			scope := middlewareapi.GetRequestScope(req) | ||||
| 			Expect(scope.Upstream).To(Equal(in.expectedUpstream)) | ||||
|  | ||||
| 			// Delete extra headers that aren't relevant to tests | ||||
| 			testSanitizeResponseHeader(rw.Header()) | ||||
| 			Expect(rw.Header()).To(Equal(in.expectedResponse.header)) | ||||
| @@ -97,7 +104,6 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||
| 			expectedResponse: testHTTPResponse{ | ||||
| 				code: 200, | ||||
| 				header: map[string][]string{ | ||||
| 					gapUpstream: {"default"}, | ||||
| 					contentType: {applicationJSON}, | ||||
| 				}, | ||||
| 				request: testHTTPRequest{ | ||||
| @@ -109,6 +115,7 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||
| 					RequestURI: "http://example.localhost/foo", | ||||
| 				}, | ||||
| 			}, | ||||
| 			expectedUpstream: "default", | ||||
| 		}), | ||||
| 		Entry("request a path with encoded slashes", &httpUpstreamTableInput{ | ||||
| 			id:           "encodedSlashes", | ||||
| @@ -120,7 +127,6 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||
| 			expectedResponse: testHTTPResponse{ | ||||
| 				code: 200, | ||||
| 				header: map[string][]string{ | ||||
| 					gapUpstream: {"encodedSlashes"}, | ||||
| 					contentType: {applicationJSON}, | ||||
| 				}, | ||||
| 				request: testHTTPRequest{ | ||||
| @@ -132,6 +138,7 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||
| 					RequestURI: "http://example.localhost/foo%2fbar/?baz=1", | ||||
| 				}, | ||||
| 			}, | ||||
| 			expectedUpstream: "encodedSlashes", | ||||
| 		}), | ||||
| 		Entry("when the request has a body", &httpUpstreamTableInput{ | ||||
| 			id:           "requestWithBody", | ||||
| @@ -143,7 +150,6 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||
| 			expectedResponse: testHTTPResponse{ | ||||
| 				code: 200, | ||||
| 				header: map[string][]string{ | ||||
| 					gapUpstream: {"requestWithBody"}, | ||||
| 					contentType: {applicationJSON}, | ||||
| 				}, | ||||
| 				request: testHTTPRequest{ | ||||
| @@ -157,6 +163,7 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||
| 					RequestURI: "http://example.localhost/withBody", | ||||
| 				}, | ||||
| 			}, | ||||
| 			expectedUpstream: "requestWithBody", | ||||
| 		}), | ||||
| 		Entry("when the upstream is unavailable", &httpUpstreamTableInput{ | ||||
| 			id:           "unavailableUpstream", | ||||
| @@ -166,12 +173,11 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||
| 			body:         []byte{}, | ||||
| 			errorHandler: nil, | ||||
| 			expectedResponse: testHTTPResponse{ | ||||
| 				code: 502, | ||||
| 				header: map[string][]string{ | ||||
| 					gapUpstream: {"unavailableUpstream"}, | ||||
| 				}, | ||||
| 				code:    502, | ||||
| 				header:  map[string][]string{}, | ||||
| 				request: testHTTPRequest{}, | ||||
| 			}, | ||||
| 			expectedUpstream: "unavailableUpstream", | ||||
| 		}), | ||||
| 		Entry("when the upstream is unavailable and an error handler is set", &httpUpstreamTableInput{ | ||||
| 			id:         "withErrorHandler", | ||||
| @@ -184,13 +190,12 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||
| 				rw.Write([]byte("error")) | ||||
| 			}, | ||||
| 			expectedResponse: testHTTPResponse{ | ||||
| 				code: 502, | ||||
| 				header: map[string][]string{ | ||||
| 					gapUpstream: {"withErrorHandler"}, | ||||
| 				}, | ||||
| 				code:    502, | ||||
| 				header:  map[string][]string{}, | ||||
| 				raw:     "error", | ||||
| 				request: testHTTPRequest{}, | ||||
| 			}, | ||||
| 			expectedUpstream: "withErrorHandler", | ||||
| 		}), | ||||
| 		Entry("with a signature", &httpUpstreamTableInput{ | ||||
| 			id:         "withSignature", | ||||
| @@ -207,7 +212,6 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||
| 				code: 200, | ||||
| 				header: map[string][]string{ | ||||
| 					contentType: {applicationJSON}, | ||||
| 					gapUpstream: {"withSignature"}, | ||||
| 				}, | ||||
| 				request: testHTTPRequest{ | ||||
| 					Method: "GET", | ||||
| @@ -221,6 +225,7 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||
| 					RequestURI: "http://example.localhost/withSignature", | ||||
| 				}, | ||||
| 			}, | ||||
| 			expectedUpstream: "withSignature", | ||||
| 		}), | ||||
| 		Entry("with existing headers", &httpUpstreamTableInput{ | ||||
| 			id:           "existingHeaders", | ||||
| @@ -236,7 +241,6 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||
| 			expectedResponse: testHTTPResponse{ | ||||
| 				code: 200, | ||||
| 				header: map[string][]string{ | ||||
| 					gapUpstream: {"existingHeaders"}, | ||||
| 					contentType: {applicationJSON}, | ||||
| 				}, | ||||
| 				request: testHTTPRequest{ | ||||
| @@ -251,11 +255,13 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||
| 					RequestURI: "http://example.localhost/existingHeaders", | ||||
| 				}, | ||||
| 			}, | ||||
| 			expectedUpstream: "existingHeaders", | ||||
| 		}), | ||||
| 	) | ||||
|  | ||||
| 	It("ServeHTTP, when not passing a host header", func() { | ||||
| 		req := httptest.NewRequest("", "http://example.localhost/foo", nil) | ||||
| 		req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{}) | ||||
| 		rw := httptest.NewRecorder() | ||||
|  | ||||
| 		flush := options.Duration(1 * time.Second) | ||||
| @@ -383,7 +389,8 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
|  | ||||
| 			handler := newHTTPUpstreamProxy(upstream, u, nil, nil) | ||||
| 			proxyServer = httptest.NewServer(handler) | ||||
|  | ||||
| 			proxyServer = httptest.NewServer(middleware.NewScope(false)(handler)) | ||||
| 		}) | ||||
|  | ||||
| 		AfterEach(func() { | ||||
| @@ -414,7 +421,6 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||
| 			response, err := http.Get(fmt.Sprintf("http://%s", proxyServer.Listener.Addr().String())) | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
| 			Expect(response.StatusCode).To(Equal(200)) | ||||
| 			Expect(response.Header.Get(gapUpstream)).To(Equal("websocketProxy")) | ||||
| 		}) | ||||
| 	}) | ||||
| }) | ||||
|   | ||||
| @@ -7,6 +7,7 @@ import ( | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
|  | ||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/ginkgo/extensions/table" | ||||
| @@ -64,17 +65,24 @@ var _ = Describe("Proxy Suite", func() { | ||||
| 	type proxyTableInput struct { | ||||
| 		target   string | ||||
| 		response testHTTPResponse | ||||
| 		upstream string | ||||
| 	} | ||||
|  | ||||
| 	DescribeTable("Proxy ServerHTTP", | ||||
| 	DescribeTable("Proxy ServeHTTP", | ||||
| 		func(in *proxyTableInput) { | ||||
| 			req := httptest.NewRequest("", in.target, nil) | ||||
| 			req := middlewareapi.AddRequestScope( | ||||
| 				httptest.NewRequest("", in.target, nil), | ||||
| 				&middlewareapi.RequestScope{}, | ||||
| 			) | ||||
| 			rw := httptest.NewRecorder() | ||||
| 			// Don't mock the remote Address | ||||
| 			req.RemoteAddr = "" | ||||
|  | ||||
| 			upstreamServer.ServeHTTP(rw, req) | ||||
|  | ||||
| 			scope := middlewareapi.GetRequestScope(req) | ||||
| 			Expect(scope.Upstream).To(Equal(in.upstream)) | ||||
|  | ||||
| 			Expect(rw.Code).To(Equal(in.response.code)) | ||||
|  | ||||
| 			// Delete extra headers that aren't relevant to tests | ||||
| @@ -99,7 +107,6 @@ var _ = Describe("Proxy Suite", func() { | ||||
| 			response: testHTTPResponse{ | ||||
| 				code: 200, | ||||
| 				header: map[string][]string{ | ||||
| 					gapUpstream: {"http-backend"}, | ||||
| 					contentType: {applicationJSON}, | ||||
| 				}, | ||||
| 				request: testHTTPRequest{ | ||||
| @@ -114,6 +121,7 @@ var _ = Describe("Proxy Suite", func() { | ||||
| 					RequestURI: "http://example.localhost/http/1234", | ||||
| 				}, | ||||
| 			}, | ||||
| 			upstream: "http-backend", | ||||
| 		}), | ||||
| 		Entry("with a request to the File backend", &proxyTableInput{ | ||||
| 			target: "http://example.localhost/files/foo", | ||||
| @@ -121,31 +129,29 @@ var _ = Describe("Proxy Suite", func() { | ||||
| 				code: 200, | ||||
| 				header: map[string][]string{ | ||||
| 					contentType: {textPlainUTF8}, | ||||
| 					gapUpstream: {"file-backend"}, | ||||
| 				}, | ||||
| 				raw: "foo", | ||||
| 			}, | ||||
| 			upstream: "file-backend", | ||||
| 		}), | ||||
| 		Entry("with a request to the Static backend", &proxyTableInput{ | ||||
| 			target: "http://example.localhost/static/bar", | ||||
| 			response: testHTTPResponse{ | ||||
| 				code: 200, | ||||
| 				header: map[string][]string{ | ||||
| 					gapUpstream: {"static-backend"}, | ||||
| 				}, | ||||
| 				raw: "Authenticated", | ||||
| 				code:   200, | ||||
| 				header: map[string][]string{}, | ||||
| 				raw:    "Authenticated", | ||||
| 			}, | ||||
| 			upstream: "static-backend", | ||||
| 		}), | ||||
| 		Entry("with a request to the bad HTTP backend", &proxyTableInput{ | ||||
| 			target: "http://example.localhost/bad-http/bad", | ||||
| 			response: testHTTPResponse{ | ||||
| 				code: 502, | ||||
| 				header: map[string][]string{ | ||||
| 					gapUpstream: {"bad-http-backend"}, | ||||
| 				}, | ||||
| 				code:   502, | ||||
| 				header: map[string][]string{}, | ||||
| 				// This tests the error handler | ||||
| 				raw: "Proxy Error", | ||||
| 			}, | ||||
| 			upstream: "bad-http-backend", | ||||
| 		}), | ||||
| 		Entry("with a request to the to an unregistered path", &proxyTableInput{ | ||||
| 			target: "http://example.localhost/unregistered", | ||||
| @@ -161,12 +167,11 @@ var _ = Describe("Proxy Suite", func() { | ||||
| 		Entry("with a request to the to backend registered to a single path", &proxyTableInput{ | ||||
| 			target: "http://example.localhost/single-path", | ||||
| 			response: testHTTPResponse{ | ||||
| 				code: 200, | ||||
| 				header: map[string][]string{ | ||||
| 					gapUpstream: {"single-path-backend"}, | ||||
| 				}, | ||||
| 				raw: "Authenticated", | ||||
| 				code:   200, | ||||
| 				header: map[string][]string{}, | ||||
| 				raw:    "Authenticated", | ||||
| 			}, | ||||
| 			upstream: "single-path-backend", | ||||
| 		}), | ||||
| 		Entry("with a request to the to a subpath of a backend registered to a single path", &proxyTableInput{ | ||||
| 			target: "http://example.localhost/single-path/unregistered", | ||||
|   | ||||
| @@ -3,6 +3,9 @@ package upstream | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
|  | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||
| ) | ||||
|  | ||||
| const defaultStaticResponseCode = 200 | ||||
| @@ -24,9 +27,16 @@ type staticResponseHandler struct { | ||||
|  | ||||
| // ServeHTTP serves a static response. | ||||
| func (s *staticResponseHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | ||||
| 	rw.Header().Set("GAP-Upstream-Address", s.upstream) | ||||
| 	scope := middleware.GetRequestScope(req) | ||||
| 	// If scope is nil, this will panic. | ||||
| 	// A scope should always be injected before this handler is called. | ||||
| 	scope.Upstream = s.upstream | ||||
|  | ||||
| 	rw.WriteHeader(s.code) | ||||
| 	fmt.Fprintf(rw, "Authenticated") | ||||
| 	_, err := fmt.Fprintf(rw, "Authenticated") | ||||
| 	if err != nil { | ||||
| 		logger.Errorf("Error writing static response: %v", err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // derefStaticCode returns the derefenced value, or the default if the value is nil | ||||
|   | ||||
| @@ -6,6 +6,7 @@ import ( | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
|  | ||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/ginkgo/extensions/table" | ||||
| 	. "github.com/onsi/gomega" | ||||
| @@ -40,10 +41,14 @@ var _ = Describe("Static Response Suite", func() { | ||||
| 			handler := newStaticResponseHandler(id, code) | ||||
|  | ||||
| 			req := httptest.NewRequest("", in.requestPath, nil) | ||||
| 			req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{}) | ||||
|  | ||||
| 			rw := httptest.NewRecorder() | ||||
| 			handler.ServeHTTP(rw, req) | ||||
|  | ||||
| 			Expect(rw.Header().Get("GAP-Upstream-Address")).To(Equal(id)) | ||||
| 			scope := middlewareapi.GetRequestScope(req) | ||||
| 			Expect(scope.Upstream).To(Equal(id)) | ||||
|  | ||||
| 			Expect(rw.Code).To(Equal(in.expectedCode)) | ||||
| 			Expect(rw.Body.String()).To(Equal(in.expectedBody)) | ||||
| 		}, | ||||
|   | ||||
| @@ -59,7 +59,6 @@ const ( | ||||
| 	acceptEncoding  = "Accept-Encoding" | ||||
| 	applicationJSON = "application/json" | ||||
| 	textPlainUTF8   = "text/plain; charset=utf-8" | ||||
| 	gapUpstream     = "Gap-Upstream-Address" | ||||
| 	gapAuth         = "Gap-Auth" | ||||
| 	gapSignature    = "Gap-Signature" | ||||
| ) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user