From ca416a2ebbacad164dcea3243f4e3d6e24ecf3f8 Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Sun, 14 Jun 2020 16:42:05 +0100 Subject: [PATCH 1/3] Add HealthCheck middleware --- go.mod | 1 + go.sum | 3 + pkg/middleware/healthcheck.go | 48 ++++++++++ pkg/middleware/healthcheck_test.go | 112 ++++++++++++++++++++++++ pkg/middleware/middleware_suite_test.go | 16 ++++ 5 files changed, 180 insertions(+) create mode 100644 pkg/middleware/healthcheck.go create mode 100644 pkg/middleware/healthcheck_test.go create mode 100644 pkg/middleware/middleware_suite_test.go diff --git a/go.mod b/go.mod index 62f41227..00d57333 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/fsnotify/fsnotify v1.4.9 github.com/go-redis/redis/v7 v7.2.0 + github.com/justinas/alice v1.2.0 github.com/kr/pretty v0.2.0 // indirect github.com/mbland/hmacauth v0.0.0-20170912233209-44256dfd4bfa github.com/mitchellh/mapstructure v1.1.2 diff --git a/go.sum b/go.sum index 098d3828..dd62bce2 100644 --- a/go.sum +++ b/go.sum @@ -102,6 +102,8 @@ github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1 github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= +github.com/justinas/alice v1.2.0 h1:+MHSA/vccVCF4Uq37S42jwlkvI2Xzl7zTPCN5BnZNVo= +github.com/justinas/alice v1.2.0/go.mod h1:fN5HRH/reO/zrUflLfTN43t3vXvKzvZIENsNEe7i7qA= github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= @@ -129,6 +131,7 @@ github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+W github.com/onsi/ginkgo v1.10.1/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.12.0 h1:Iw5WCbBcaAAd0fpRb1c9r5YCylv4XDoCSigm1zLevwU= github.com/onsi/ginkgo v1.12.0/go.mod h1:oUhWkIvk5aDxtKvDDuw8gItl8pKl42LzjC9KZE0HfGg= +github.com/onsi/ginkgo v1.12.3 h1:+RYp9QczoWz9zfUyLP/5SLXQVhfr6gZOoKGfQqHuLZQ= github.com/onsi/gomega v1.7.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.9.0 h1:R1uwffexN6Pr340GtYRIdZmAiN4J+iw6WG4wog1DUXg= diff --git a/pkg/middleware/healthcheck.go b/pkg/middleware/healthcheck.go new file mode 100644 index 00000000..ea1b533a --- /dev/null +++ b/pkg/middleware/healthcheck.go @@ -0,0 +1,48 @@ +package middleware + +import ( + "fmt" + "net/http" + + "github.com/justinas/alice" +) + +func NewHealthCheck(paths, userAgents []string) alice.Constructor { + return func(next http.Handler) http.Handler { + return healthCheck(paths, userAgents, next) + } +} + +func healthCheck(paths, userAgents []string, next http.Handler) http.Handler { + // Use a map as a set to check health check paths + pathSet := make(map[string]struct{}) + for _, path := range paths { + pathSet[path] = struct{}{} + } + + // Use a map as a set to check health check paths + userAgentSet := make(map[string]struct{}) + for _, userAgent := range userAgents { + userAgentSet[userAgent] = struct{}{} + } + + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if isHealthCheckRequest(pathSet, userAgentSet, req) { + rw.WriteHeader(http.StatusOK) + fmt.Fprintf(rw, "OK") + return + } + + next.ServeHTTP(rw, req) + }) +} + +func isHealthCheckRequest(paths, userAgents map[string]struct{}, req *http.Request) bool { + if _, ok := paths[req.URL.EscapedPath()]; ok { + return true + } + if _, ok := userAgents[req.Header.Get("User-Agent")]; ok { + return true + } + return false +} diff --git a/pkg/middleware/healthcheck_test.go b/pkg/middleware/healthcheck_test.go new file mode 100644 index 00000000..8db4d57b --- /dev/null +++ b/pkg/middleware/healthcheck_test.go @@ -0,0 +1,112 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +var _ = Describe("HealthCheck suite", func() { + type requestTableInput struct { + healthCheckPaths []string + healthCheckUserAgents []string + requestString string + headers map[string]string + expectedStatus int + expectedBody string + } + + DescribeTable("when serving a request", + func(in *requestTableInput) { + req := httptest.NewRequest("", in.requestString, nil) + for k, v := range in.headers { + req.Header.Add(k, v) + } + + rw := httptest.NewRecorder() + + handler := NewHealthCheck(in.healthCheckPaths, in.healthCheckUserAgents)(http.NotFoundHandler()) + handler.ServeHTTP(rw, req) + + Expect(rw.Code).To(Equal(in.expectedStatus)) + Expect(rw.Body.String()).To(Equal(in.expectedBody)) + }, + Entry("when requesting the healthcheck path", &requestTableInput{ + healthCheckPaths: []string{"/ping"}, + healthCheckUserAgents: []string{"hc/1.0"}, + requestString: "http://example.com/ping", + headers: map[string]string{}, + expectedStatus: 200, + expectedBody: "OK", + }), + Entry("when requesting a different path", &requestTableInput{ + healthCheckPaths: []string{"/ping"}, + healthCheckUserAgents: []string{"hc/1.0"}, + requestString: "http://example.com/different", + headers: map[string]string{}, + expectedStatus: 404, + expectedBody: "404 page not found\n", + }), + Entry("with a request from the health check user agent", &requestTableInput{ + healthCheckPaths: []string{"/ping"}, + healthCheckUserAgents: []string{"hc/1.0"}, + requestString: "http://example.com/abc", + headers: map[string]string{ + "User-Agent": "hc/1.0", + }, + expectedStatus: 200, + expectedBody: "OK", + }), + Entry("with a request from a different user agent", &requestTableInput{ + healthCheckPaths: []string{"/ping"}, + healthCheckUserAgents: []string{"hc/1.0"}, + requestString: "http://example.com/abc", + headers: map[string]string{ + "User-Agent": "different", + }, + expectedStatus: 404, + expectedBody: "404 page not found\n", + }), + Entry("with multiple paths, request one of the healthcheck paths", &requestTableInput{ + healthCheckPaths: []string{"/ping", "/liveness_check", "/readiness_check"}, + healthCheckUserAgents: []string{"hc/1.0"}, + requestString: "http://example.com/readiness_check", + headers: map[string]string{}, + expectedStatus: 200, + expectedBody: "OK", + }), + Entry("with multiple paths, request none of the healthcheck paths", &requestTableInput{ + healthCheckPaths: []string{"/ping", "/liveness_check", "/readiness_check"}, + healthCheckUserAgents: []string{"hc/1.0"}, + requestString: "http://example.com/readiness", + headers: map[string]string{ + "User-Agent": "user", + }, + expectedStatus: 404, + expectedBody: "404 page not found\n", + }), + Entry("with multiple user agents, request from a health check user agent", &requestTableInput{ + healthCheckPaths: []string{"/ping"}, + healthCheckUserAgents: []string{"hc/1.0", "GoogleHC/1.0"}, + requestString: "http://example.com/abc", + headers: map[string]string{ + "User-Agent": "GoogleHC/1.0", + }, + expectedStatus: 200, + expectedBody: "OK", + }), + Entry("with multiple user agents, request from none of the health check user agents", &requestTableInput{ + healthCheckPaths: []string{"/ping"}, + healthCheckUserAgents: []string{"hc/1.0", "GoogleHC/1.0"}, + requestString: "http://example.com/abc", + headers: map[string]string{ + "User-Agent": "user", + }, + expectedStatus: 404, + expectedBody: "404 page not found\n", + }), + ) +}) diff --git a/pkg/middleware/middleware_suite_test.go b/pkg/middleware/middleware_suite_test.go new file mode 100644 index 00000000..9972ce8f --- /dev/null +++ b/pkg/middleware/middleware_suite_test.go @@ -0,0 +1,16 @@ +package middleware + +import ( + "testing" + + "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestMiddlewareSuite(t *testing.T) { + logger.SetOutput(GinkgoWriter) + + RegisterFailHandler(Fail) + RunSpecs(t, "Middleware") +} From 9bbd6adce97057d2107d60b2a9494974bbc8a601 Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Sun, 14 Jun 2020 20:58:44 +0100 Subject: [PATCH 2/3] Integrate HealthCheck middleware --- http.go | 46 +++--------------- http_test.go | 99 --------------------------------------- logging_handler.go | 5 +- logging_handler_test.go | 89 +++++------------------------------ main.go | 30 +++++++++--- oauthproxy.go | 30 ------------ pkg/validation/logging.go | 8 +--- pkg/validation/options.go | 2 +- 8 files changed, 47 insertions(+), 262 deletions(-) diff --git a/http.go b/http.go index a2287cb7..a8694187 100644 --- a/http.go +++ b/http.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/justinas/alice" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" ) @@ -29,45 +30,6 @@ func (s *Server) ListenAndServe() { } } -// Used with gcpHealthcheck() -const userAgentHeader = "User-Agent" -const googleHealthCheckUserAgent = "GoogleHC/1.0" -const rootPath = "/" - -// gcpHealthcheck handles healthcheck queries from GCP. -func gcpHealthcheck(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Check for liveness and readiness: used for Google App Engine - if r.URL.EscapedPath() == "/liveness_check" { - w.WriteHeader(http.StatusOK) - w.Write([]byte("OK")) - return - } - if r.URL.EscapedPath() == "/readiness_check" { - w.WriteHeader(http.StatusOK) - w.Write([]byte("OK")) - return - } - - // Check for GKE ingress healthcheck: The ingress requires the root - // path of the target to return a 200 (OK) to indicate the service's good health. This can be quite a challenging demand - // depending on the application's path structure. This middleware filters out the requests from the health check by - // - // 1. checking that the request path is indeed the root path - // 2. ensuring that the User-Agent is "GoogleHC/1.0", the health checker - // 3. ensuring the request method is "GET" - if r.URL.Path == rootPath && - r.Header.Get(userAgentHeader) == googleHealthCheckUserAgent && - r.Method == http.MethodGet { - - w.WriteHeader(http.StatusOK) - return - } - - h.ServeHTTP(w, r) - }) -} - // ServeHTTP constructs a net.Listener and starts handling HTTP requests func (s *Server) ServeHTTP() { HTTPAddress := s.Opts.HTTPAddress @@ -168,6 +130,12 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { return tc, nil } +func newRedirectToHTTPS(opts *options.Options) alice.Constructor { + return func(next http.Handler) http.Handler { + return redirectToHTTPS(opts, next) + } +} + func redirectToHTTPS(opts *options.Options, h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { proto := r.Header.Get("X-Forwarded-Proto") diff --git a/http_test.go b/http_test.go index bd9c02fa..5bd58172 100644 --- a/http_test.go +++ b/http_test.go @@ -11,105 +11,6 @@ import ( "github.com/stretchr/testify/assert" ) -const localhost = "127.0.0.1" -const host = "test-server" - -func TestGCPHealthcheckLiveness(t *testing.T) { - handler := func(w http.ResponseWriter, req *http.Request) { - w.Write([]byte("test")) - } - - h := gcpHealthcheck(http.HandlerFunc(handler)) - rw := httptest.NewRecorder() - r, _ := http.NewRequest("GET", "/liveness_check", nil) - r.RemoteAddr = localhost - r.Host = host - h.ServeHTTP(rw, r) - - assert.Equal(t, 200, rw.Code) - assert.Equal(t, "OK", rw.Body.String()) -} - -func TestGCPHealthcheckReadiness(t *testing.T) { - handler := func(w http.ResponseWriter, req *http.Request) { - w.Write([]byte("test")) - } - - h := gcpHealthcheck(http.HandlerFunc(handler)) - rw := httptest.NewRecorder() - r, _ := http.NewRequest("GET", "/readiness_check", nil) - r.RemoteAddr = localhost - r.Host = host - h.ServeHTTP(rw, r) - - assert.Equal(t, 200, rw.Code) - assert.Equal(t, "OK", rw.Body.String()) -} - -func TestGCPHealthcheckNotHealthcheck(t *testing.T) { - handler := func(w http.ResponseWriter, req *http.Request) { - w.Write([]byte("test")) - } - - h := gcpHealthcheck(http.HandlerFunc(handler)) - rw := httptest.NewRecorder() - r, _ := http.NewRequest("GET", "/not_any_check", nil) - r.RemoteAddr = localhost - r.Host = host - h.ServeHTTP(rw, r) - - assert.Equal(t, "test", rw.Body.String()) -} - -func TestGCPHealthcheckIngress(t *testing.T) { - handler := func(w http.ResponseWriter, req *http.Request) { - w.Write([]byte("test")) - } - - h := gcpHealthcheck(http.HandlerFunc(handler)) - rw := httptest.NewRecorder() - r, _ := http.NewRequest("GET", "/", nil) - r.RemoteAddr = localhost - r.Host = host - r.Header.Set(userAgentHeader, googleHealthCheckUserAgent) - h.ServeHTTP(rw, r) - - assert.Equal(t, 200, rw.Code) - assert.Equal(t, "", rw.Body.String()) -} - -func TestGCPHealthcheckNotIngress(t *testing.T) { - handler := func(w http.ResponseWriter, req *http.Request) { - w.Write([]byte("test")) - } - - h := gcpHealthcheck(http.HandlerFunc(handler)) - rw := httptest.NewRecorder() - r, _ := http.NewRequest("GET", "/foo", nil) - r.RemoteAddr = localhost - r.Host = host - r.Header.Set(userAgentHeader, googleHealthCheckUserAgent) - h.ServeHTTP(rw, r) - - assert.Equal(t, "test", rw.Body.String()) -} - -func TestGCPHealthcheckNotIngressPut(t *testing.T) { - handler := func(w http.ResponseWriter, req *http.Request) { - w.Write([]byte("test")) - } - - h := gcpHealthcheck(http.HandlerFunc(handler)) - rw := httptest.NewRecorder() - r, _ := http.NewRequest("PUT", "/", nil) - r.RemoteAddr = localhost - r.Host = host - r.Header.Set(userAgentHeader, googleHealthCheckUserAgent) - h.ServeHTTP(rw, r) - - assert.Equal(t, "test", rw.Body.String()) -} - func TestRedirectToHTTPSTrue(t *testing.T) { opts := options.NewOptions() opts.ForceHTTPS = true diff --git a/logging_handler.go b/logging_handler.go index 414c3ee2..1c857413 100644 --- a/logging_handler.go +++ b/logging_handler.go @@ -21,7 +21,6 @@ type responseLogger struct { size int upstream string authInfo string - silent bool } // Header returns the ResponseWriter's Header @@ -105,7 +104,5 @@ func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { url := *req.URL responseLogger := &responseLogger{w: w} h.handler.ServeHTTP(responseLogger, req) - if !responseLogger.silent { - logger.PrintReq(responseLogger.authInfo, responseLogger.upstream, req, url, t, responseLogger.Status(), responseLogger.Size()) - } + logger.PrintReq(responseLogger.authInfo, responseLogger.upstream, req, url, t, responseLogger.Status(), responseLogger.Size()) } diff --git a/logging_handler_test.go b/logging_handler_test.go index 16819bcf..b522dd73 100644 --- a/logging_handler_test.go +++ b/logging_handler_test.go @@ -5,14 +5,11 @@ import ( "fmt" "net/http" "net/http/httptest" - "net/url" "strings" "testing" "time" - "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" - "github.com/oauth2-proxy/oauth2-proxy/pkg/validation" ) func TestLoggingHandler_ServeHTTP(t *testing.T) { @@ -22,21 +19,20 @@ func TestLoggingHandler_ServeHTTP(t *testing.T) { Format, ExpectedLogMessage, Path string - ExcludePaths []string - SilencePingLogging bool + ExcludePaths []string }{ - {logger.DefaultRequestLoggingFormat, fmt.Sprintf("127.0.0.1 - - [%s] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", logger.FormatTimestamp(ts)), "/foo/bar", []string{}, false}, - {logger.DefaultRequestLoggingFormat, fmt.Sprintf("127.0.0.1 - - [%s] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", logger.FormatTimestamp(ts)), "/foo/bar", []string{}, true}, - {logger.DefaultRequestLoggingFormat, fmt.Sprintf("127.0.0.1 - - [%s] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", logger.FormatTimestamp(ts)), "/foo/bar", []string{"/ping"}, false}, - {logger.DefaultRequestLoggingFormat, "", "/foo/bar", []string{"/foo/bar"}, false}, - {logger.DefaultRequestLoggingFormat, "", "/ping", []string{}, true}, - {logger.DefaultRequestLoggingFormat, "", "/ping", []string{"/ping"}, false}, - {logger.DefaultRequestLoggingFormat, "", "/ping", []string{"/ping"}, true}, - {logger.DefaultRequestLoggingFormat, "", "/ping", []string{"/foo/bar", "/ping"}, false}, - {"{{.RequestMethod}}", "GET\n", "/foo/bar", []string{}, true}, - {"{{.RequestMethod}}", "GET\n", "/foo/bar", []string{"/ping"}, false}, - {"{{.RequestMethod}}", "GET\n", "/ping", []string{}, false}, - {"{{.RequestMethod}}", "", "/ping", []string{"/ping"}, true}, + {logger.DefaultRequestLoggingFormat, fmt.Sprintf("127.0.0.1 - - [%s] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", logger.FormatTimestamp(ts)), "/foo/bar", []string{}}, + {logger.DefaultRequestLoggingFormat, fmt.Sprintf("127.0.0.1 - - [%s] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", logger.FormatTimestamp(ts)), "/foo/bar", []string{}}, + {logger.DefaultRequestLoggingFormat, fmt.Sprintf("127.0.0.1 - - [%s] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", logger.FormatTimestamp(ts)), "/foo/bar", []string{"/ping"}}, + {logger.DefaultRequestLoggingFormat, "", "/foo/bar", []string{"/foo/bar"}}, + {logger.DefaultRequestLoggingFormat, "", "/ping", []string{}}, + {logger.DefaultRequestLoggingFormat, "", "/ping", []string{"/ping"}}, + {logger.DefaultRequestLoggingFormat, "", "/ping", []string{"/ping"}}, + {logger.DefaultRequestLoggingFormat, "", "/ping", []string{"/foo/bar", "/ping"}}, + {"{{.RequestMethod}}", "GET\n", "/foo/bar", []string{}}, + {"{{.RequestMethod}}", "GET\n", "/foo/bar", []string{"/ping"}}, + {"{{.RequestMethod}}", "GET\n", "/ping", []string{}}, + {"{{.RequestMethod}}", "", "/ping", []string{"/ping"}}, } for _, test := range tests { @@ -52,9 +48,6 @@ func TestLoggingHandler_ServeHTTP(t *testing.T) { logger.SetOutput(buf) logger.SetReqTemplate(test.Format) - if test.SilencePingLogging { - test.ExcludePaths = append(test.ExcludePaths, "/ping") - } logger.SetExcludePaths(test.ExcludePaths) h := LoggingHandler(http.HandlerFunc(handler)) @@ -70,59 +63,3 @@ func TestLoggingHandler_ServeHTTP(t *testing.T) { } } } - -func TestLoggingHandler_PingUserAgent(t *testing.T) { - tests := []struct { - ExpectedLogMessage string - Path string - SilencePingLogging bool - WithUserAgent string - }{ - {"444\n", "/foo", true, "Blah"}, - {"444\n", "/foo", false, "Blah"}, - {"", "/ping", true, "Blah"}, - {"200\n", "/ping", false, "Blah"}, - {"", "/ping", true, "PingMe!"}, - {"", "/ping", false, "PingMe!"}, - {"", "/foo", true, "PingMe!"}, - {"", "/foo", false, "PingMe!"}, - } - - for idx, test := range tests { - t.Run(fmt.Sprintf("%d", idx), func(t *testing.T) { - opts := options.NewOptions() - opts.PingUserAgent = "PingMe!" - opts.SkipAuthRegex = []string{"/foo"} - opts.Upstreams = []string{"static://444/foo"} - opts.Logging.SilencePing = test.SilencePingLogging - if test.SilencePingLogging { - opts.Logging.ExcludePaths = []string{"/ping"} - } - opts.RawRedirectURL = "localhost" - validation.Validate(opts) - - p := NewOAuthProxy(opts, func(email string) bool { - return true - }) - p.provider = NewTestProvider(&url.URL{Host: "localhost"}, "") - - buf := bytes.NewBuffer(nil) - logger.SetOutput(buf) - logger.SetReqEnabled(true) - logger.SetReqTemplate("{{.StatusCode}}") - - r, _ := http.NewRequest("GET", test.Path, nil) - if test.WithUserAgent != "" { - r.Header.Set("User-Agent", test.WithUserAgent) - } - - h := LoggingHandler(p) - h.ServeHTTP(httptest.NewRecorder(), r) - - actual := buf.String() - if !strings.Contains(actual, test.ExpectedLogMessage) { - t.Errorf("Log message was\n%s\ninstead of matching \n%s", actual, test.ExpectedLogMessage) - } - }) - } -} diff --git a/main.go b/main.go index 16a34cd7..9b04e24d 100644 --- a/main.go +++ b/main.go @@ -3,7 +3,6 @@ package main import ( "fmt" "math/rand" - "net/http" "os" "os/signal" "runtime" @@ -11,8 +10,10 @@ import ( "syscall" "time" + "github.com/justinas/alice" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" + "github.com/oauth2-proxy/oauth2-proxy/pkg/middleware" "github.com/oauth2-proxy/oauth2-proxy/pkg/validation" ) @@ -71,14 +72,29 @@ func main() { rand.Seed(time.Now().UnixNano()) - var handler http.Handler - if opts.GCPHealthChecks { - handler = redirectToHTTPS(opts, gcpHealthcheck(LoggingHandler(oauthproxy))) - } else { - handler = redirectToHTTPS(opts, LoggingHandler(oauthproxy)) + chain := alice.New() + + if opts.ForceHTTPS { + chain = chain.Append(newRedirectToHTTPS(opts)) } + + healthCheckPaths := []string{opts.PingPath} + healthCheckUserAgents := []string{opts.PingUserAgent} + if opts.GCPHealthChecks { + healthCheckPaths = append(healthCheckPaths, "/liveness_check", "/readiness_check") + healthCheckUserAgents = append(healthCheckUserAgents, "GoogleHC/1.0") + } + + // To silence logging of health checks, register the health check handler before + // the logging handler + if opts.Logging.SilencePing { + chain = chain.Append(middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents), LoggingHandler) + } else { + chain = chain.Append(LoggingHandler, middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents)) + } + s := &Server{ - Handler: handler, + Handler: chain.Then(oauthproxy), Opts: opts, stop: make(chan struct{}, 1), } diff --git a/oauthproxy.go b/oauthproxy.go index 234451f4..b4119918 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -81,9 +81,6 @@ type OAuthProxy struct { Validator func(string) bool RobotsPath string - PingPath string - PingUserAgent string - SilencePings bool SignInPath string SignOutPath string OAuthStartPath string @@ -313,9 +310,6 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) *OAuthPro Validator: validator, RobotsPath: "/robots.txt", - PingPath: opts.PingPath, - PingUserAgent: opts.PingUserAgent, - SilencePings: opts.Logging.SilencePing, SignInPath: fmt.Sprintf("%s/sign_in", opts.ProxyPrefix), SignOutPath: fmt.Sprintf("%s/sign_out", opts.ProxyPrefix), OAuthStartPath: fmt.Sprintf("%s/start", opts.ProxyPrefix), @@ -468,17 +462,6 @@ func (p *OAuthProxy) RobotsTxt(rw http.ResponseWriter) { fmt.Fprintf(rw, "User-agent: *\nDisallow: /") } -// PingPage responds 200 OK to requests -func (p *OAuthProxy) PingPage(rw http.ResponseWriter) { - if p.SilencePings { - if rl, ok := rw.(*responseLogger); ok { - rl.silent = true - } - } - rw.WriteHeader(http.StatusOK) - fmt.Fprintf(rw, "OK") -} - // ErrorPage writes an error response func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, message string) { rw.WriteHeader(code) @@ -684,17 +667,6 @@ func prepareNoCache(w http.ResponseWriter) { } } -// IsPingRequest will check if the request appears to be performing a health check -// either via the path it's requesting or by a special User-Agent configuration. -func (p *OAuthProxy) IsPingRequest(req *http.Request) bool { - - if req.URL.EscapedPath() == p.PingPath { - return true - } - - return p.PingUserAgent != "" && req.Header.Get("User-Agent") == p.PingUserAgent -} - func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { if strings.HasPrefix(req.URL.Path, p.ProxyPrefix) { prepareNoCache(rw) @@ -703,8 +675,6 @@ func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { switch path := req.URL.Path; { case path == p.RobotsPath: p.RobotsTxt(rw) - case p.IsPingRequest(req): - p.PingPage(rw) case p.IsWhitelistedRequest(req): p.serveMux.ServeHTTP(rw, req) case path == p.SignInPath: diff --git a/pkg/validation/logging.go b/pkg/validation/logging.go index 1c8ab2a3..2c4754aa 100644 --- a/pkg/validation/logging.go +++ b/pkg/validation/logging.go @@ -9,7 +9,7 @@ import ( ) // configureLogger is responsible for configuring the logger based on the options given -func configureLogger(o options.Logging, pingPath string, msgs []string) []string { +func configureLogger(o options.Logging, msgs []string) []string { // Setup the log file if len(o.File.Filename) > 0 { // Validate that the file/dir can be written @@ -48,11 +48,7 @@ func configureLogger(o options.Logging, pingPath string, msgs []string) []string logger.SetAuthTemplate(o.AuthFormat) logger.SetReqTemplate(o.RequestFormat) - excludePaths := o.ExcludePaths - if o.SilencePing { - excludePaths = append(excludePaths, pingPath) - } - logger.SetExcludePaths(excludePaths) + logger.SetExcludePaths(o.ExcludePaths) if !o.LocalTime { logger.SetFlags(logger.Flags() | logger.LUTC) diff --git a/pkg/validation/options.go b/pkg/validation/options.go index b22882d0..d3ae4ece 100644 --- a/pkg/validation/options.go +++ b/pkg/validation/options.go @@ -264,7 +264,7 @@ func Validate(o *options.Options) error { msgs = parseSignatureKey(o, msgs) msgs = validateCookieName(o, msgs) - msgs = configureLogger(o.Logging, o.PingPath, msgs) + msgs = configureLogger(o.Logging, msgs) if o.ReverseProxy { parser, err := ip.GetRealClientIPParser(o.RealClientIPHeader) From ba3e40ab1c86382ec39bed889ebe455c3f3a9b13 Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Sun, 14 Jun 2020 21:06:14 +0100 Subject: [PATCH 3/3] Add changelog entry for healthcheck middleware --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c051533d..22e489a1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -55,6 +55,7 @@ ## Changes since v5.1.1 +- [#620](https://github.com/oauth2-proxy/oauth2-proxy/pull/620) Add HealthCheck middleware (@JoelSpeed) - [#604](https://github.com/oauth2-proxy/oauth2-proxy/pull/604) Add Keycloak local testing environment (@EvgeniGordeev) - [#539](https://github.com/oauth2-proxy/oauth2-proxy/pull/539) Refactor encryption ciphers and add AES-GCM support (@NickMeves) - [#601](https://github.com/oauth2-proxy/oauth2-proxy/pull/601) Ensure decrypted user/email are valid UTF8 (@JoelSpeed)