1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-03-21 21:47:11 +02:00

Integrate HealthCheck middleware

This commit is contained in:
Joel Speed 2020-06-14 20:58:44 +01:00
parent ca416a2ebb
commit 9bbd6adce9
No known key found for this signature in database
GPG Key ID: 6E80578D6751DEFB
8 changed files with 47 additions and 262 deletions

46
http.go
View File

@ -9,6 +9,7 @@ import (
"strings"
"time"
"github.com/justinas/alice"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
)
@ -29,45 +30,6 @@ func (s *Server) ListenAndServe() {
}
}
// Used with gcpHealthcheck()
const userAgentHeader = "User-Agent"
const googleHealthCheckUserAgent = "GoogleHC/1.0"
const rootPath = "/"
// gcpHealthcheck handles healthcheck queries from GCP.
func gcpHealthcheck(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check for liveness and readiness: used for Google App Engine
if r.URL.EscapedPath() == "/liveness_check" {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
return
}
if r.URL.EscapedPath() == "/readiness_check" {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
return
}
// Check for GKE ingress healthcheck: The ingress requires the root
// path of the target to return a 200 (OK) to indicate the service's good health. This can be quite a challenging demand
// depending on the application's path structure. This middleware filters out the requests from the health check by
//
// 1. checking that the request path is indeed the root path
// 2. ensuring that the User-Agent is "GoogleHC/1.0", the health checker
// 3. ensuring the request method is "GET"
if r.URL.Path == rootPath &&
r.Header.Get(userAgentHeader) == googleHealthCheckUserAgent &&
r.Method == http.MethodGet {
w.WriteHeader(http.StatusOK)
return
}
h.ServeHTTP(w, r)
})
}
// ServeHTTP constructs a net.Listener and starts handling HTTP requests
func (s *Server) ServeHTTP() {
HTTPAddress := s.Opts.HTTPAddress
@ -168,6 +130,12 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
return tc, nil
}
func newRedirectToHTTPS(opts *options.Options) alice.Constructor {
return func(next http.Handler) http.Handler {
return redirectToHTTPS(opts, next)
}
}
func redirectToHTTPS(opts *options.Options, h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
proto := r.Header.Get("X-Forwarded-Proto")

View File

@ -11,105 +11,6 @@ import (
"github.com/stretchr/testify/assert"
)
const localhost = "127.0.0.1"
const host = "test-server"
func TestGCPHealthcheckLiveness(t *testing.T) {
handler := func(w http.ResponseWriter, req *http.Request) {
w.Write([]byte("test"))
}
h := gcpHealthcheck(http.HandlerFunc(handler))
rw := httptest.NewRecorder()
r, _ := http.NewRequest("GET", "/liveness_check", nil)
r.RemoteAddr = localhost
r.Host = host
h.ServeHTTP(rw, r)
assert.Equal(t, 200, rw.Code)
assert.Equal(t, "OK", rw.Body.String())
}
func TestGCPHealthcheckReadiness(t *testing.T) {
handler := func(w http.ResponseWriter, req *http.Request) {
w.Write([]byte("test"))
}
h := gcpHealthcheck(http.HandlerFunc(handler))
rw := httptest.NewRecorder()
r, _ := http.NewRequest("GET", "/readiness_check", nil)
r.RemoteAddr = localhost
r.Host = host
h.ServeHTTP(rw, r)
assert.Equal(t, 200, rw.Code)
assert.Equal(t, "OK", rw.Body.String())
}
func TestGCPHealthcheckNotHealthcheck(t *testing.T) {
handler := func(w http.ResponseWriter, req *http.Request) {
w.Write([]byte("test"))
}
h := gcpHealthcheck(http.HandlerFunc(handler))
rw := httptest.NewRecorder()
r, _ := http.NewRequest("GET", "/not_any_check", nil)
r.RemoteAddr = localhost
r.Host = host
h.ServeHTTP(rw, r)
assert.Equal(t, "test", rw.Body.String())
}
func TestGCPHealthcheckIngress(t *testing.T) {
handler := func(w http.ResponseWriter, req *http.Request) {
w.Write([]byte("test"))
}
h := gcpHealthcheck(http.HandlerFunc(handler))
rw := httptest.NewRecorder()
r, _ := http.NewRequest("GET", "/", nil)
r.RemoteAddr = localhost
r.Host = host
r.Header.Set(userAgentHeader, googleHealthCheckUserAgent)
h.ServeHTTP(rw, r)
assert.Equal(t, 200, rw.Code)
assert.Equal(t, "", rw.Body.String())
}
func TestGCPHealthcheckNotIngress(t *testing.T) {
handler := func(w http.ResponseWriter, req *http.Request) {
w.Write([]byte("test"))
}
h := gcpHealthcheck(http.HandlerFunc(handler))
rw := httptest.NewRecorder()
r, _ := http.NewRequest("GET", "/foo", nil)
r.RemoteAddr = localhost
r.Host = host
r.Header.Set(userAgentHeader, googleHealthCheckUserAgent)
h.ServeHTTP(rw, r)
assert.Equal(t, "test", rw.Body.String())
}
func TestGCPHealthcheckNotIngressPut(t *testing.T) {
handler := func(w http.ResponseWriter, req *http.Request) {
w.Write([]byte("test"))
}
h := gcpHealthcheck(http.HandlerFunc(handler))
rw := httptest.NewRecorder()
r, _ := http.NewRequest("PUT", "/", nil)
r.RemoteAddr = localhost
r.Host = host
r.Header.Set(userAgentHeader, googleHealthCheckUserAgent)
h.ServeHTTP(rw, r)
assert.Equal(t, "test", rw.Body.String())
}
func TestRedirectToHTTPSTrue(t *testing.T) {
opts := options.NewOptions()
opts.ForceHTTPS = true

View File

@ -21,7 +21,6 @@ type responseLogger struct {
size int
upstream string
authInfo string
silent bool
}
// Header returns the ResponseWriter's Header
@ -105,7 +104,5 @@ func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
url := *req.URL
responseLogger := &responseLogger{w: w}
h.handler.ServeHTTP(responseLogger, req)
if !responseLogger.silent {
logger.PrintReq(responseLogger.authInfo, responseLogger.upstream, req, url, t, responseLogger.Status(), responseLogger.Size())
}
logger.PrintReq(responseLogger.authInfo, responseLogger.upstream, req, url, t, responseLogger.Status(), responseLogger.Size())
}

View File

@ -5,14 +5,11 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/pkg/validation"
)
func TestLoggingHandler_ServeHTTP(t *testing.T) {
@ -22,21 +19,20 @@ func TestLoggingHandler_ServeHTTP(t *testing.T) {
Format,
ExpectedLogMessage,
Path string
ExcludePaths []string
SilencePingLogging bool
ExcludePaths []string
}{
{logger.DefaultRequestLoggingFormat, fmt.Sprintf("127.0.0.1 - - [%s] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", logger.FormatTimestamp(ts)), "/foo/bar", []string{}, false},
{logger.DefaultRequestLoggingFormat, fmt.Sprintf("127.0.0.1 - - [%s] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", logger.FormatTimestamp(ts)), "/foo/bar", []string{}, true},
{logger.DefaultRequestLoggingFormat, fmt.Sprintf("127.0.0.1 - - [%s] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", logger.FormatTimestamp(ts)), "/foo/bar", []string{"/ping"}, false},
{logger.DefaultRequestLoggingFormat, "", "/foo/bar", []string{"/foo/bar"}, false},
{logger.DefaultRequestLoggingFormat, "", "/ping", []string{}, true},
{logger.DefaultRequestLoggingFormat, "", "/ping", []string{"/ping"}, false},
{logger.DefaultRequestLoggingFormat, "", "/ping", []string{"/ping"}, true},
{logger.DefaultRequestLoggingFormat, "", "/ping", []string{"/foo/bar", "/ping"}, false},
{"{{.RequestMethod}}", "GET\n", "/foo/bar", []string{}, true},
{"{{.RequestMethod}}", "GET\n", "/foo/bar", []string{"/ping"}, false},
{"{{.RequestMethod}}", "GET\n", "/ping", []string{}, false},
{"{{.RequestMethod}}", "", "/ping", []string{"/ping"}, true},
{logger.DefaultRequestLoggingFormat, fmt.Sprintf("127.0.0.1 - - [%s] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", logger.FormatTimestamp(ts)), "/foo/bar", []string{}},
{logger.DefaultRequestLoggingFormat, fmt.Sprintf("127.0.0.1 - - [%s] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", logger.FormatTimestamp(ts)), "/foo/bar", []string{}},
{logger.DefaultRequestLoggingFormat, fmt.Sprintf("127.0.0.1 - - [%s] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", logger.FormatTimestamp(ts)), "/foo/bar", []string{"/ping"}},
{logger.DefaultRequestLoggingFormat, "", "/foo/bar", []string{"/foo/bar"}},
{logger.DefaultRequestLoggingFormat, "", "/ping", []string{}},
{logger.DefaultRequestLoggingFormat, "", "/ping", []string{"/ping"}},
{logger.DefaultRequestLoggingFormat, "", "/ping", []string{"/ping"}},
{logger.DefaultRequestLoggingFormat, "", "/ping", []string{"/foo/bar", "/ping"}},
{"{{.RequestMethod}}", "GET\n", "/foo/bar", []string{}},
{"{{.RequestMethod}}", "GET\n", "/foo/bar", []string{"/ping"}},
{"{{.RequestMethod}}", "GET\n", "/ping", []string{}},
{"{{.RequestMethod}}", "", "/ping", []string{"/ping"}},
}
for _, test := range tests {
@ -52,9 +48,6 @@ func TestLoggingHandler_ServeHTTP(t *testing.T) {
logger.SetOutput(buf)
logger.SetReqTemplate(test.Format)
if test.SilencePingLogging {
test.ExcludePaths = append(test.ExcludePaths, "/ping")
}
logger.SetExcludePaths(test.ExcludePaths)
h := LoggingHandler(http.HandlerFunc(handler))
@ -70,59 +63,3 @@ func TestLoggingHandler_ServeHTTP(t *testing.T) {
}
}
}
func TestLoggingHandler_PingUserAgent(t *testing.T) {
tests := []struct {
ExpectedLogMessage string
Path string
SilencePingLogging bool
WithUserAgent string
}{
{"444\n", "/foo", true, "Blah"},
{"444\n", "/foo", false, "Blah"},
{"", "/ping", true, "Blah"},
{"200\n", "/ping", false, "Blah"},
{"", "/ping", true, "PingMe!"},
{"", "/ping", false, "PingMe!"},
{"", "/foo", true, "PingMe!"},
{"", "/foo", false, "PingMe!"},
}
for idx, test := range tests {
t.Run(fmt.Sprintf("%d", idx), func(t *testing.T) {
opts := options.NewOptions()
opts.PingUserAgent = "PingMe!"
opts.SkipAuthRegex = []string{"/foo"}
opts.Upstreams = []string{"static://444/foo"}
opts.Logging.SilencePing = test.SilencePingLogging
if test.SilencePingLogging {
opts.Logging.ExcludePaths = []string{"/ping"}
}
opts.RawRedirectURL = "localhost"
validation.Validate(opts)
p := NewOAuthProxy(opts, func(email string) bool {
return true
})
p.provider = NewTestProvider(&url.URL{Host: "localhost"}, "")
buf := bytes.NewBuffer(nil)
logger.SetOutput(buf)
logger.SetReqEnabled(true)
logger.SetReqTemplate("{{.StatusCode}}")
r, _ := http.NewRequest("GET", test.Path, nil)
if test.WithUserAgent != "" {
r.Header.Set("User-Agent", test.WithUserAgent)
}
h := LoggingHandler(p)
h.ServeHTTP(httptest.NewRecorder(), r)
actual := buf.String()
if !strings.Contains(actual, test.ExpectedLogMessage) {
t.Errorf("Log message was\n%s\ninstead of matching \n%s", actual, test.ExpectedLogMessage)
}
})
}
}

30
main.go
View File

@ -3,7 +3,6 @@ package main
import (
"fmt"
"math/rand"
"net/http"
"os"
"os/signal"
"runtime"
@ -11,8 +10,10 @@ import (
"syscall"
"time"
"github.com/justinas/alice"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/pkg/middleware"
"github.com/oauth2-proxy/oauth2-proxy/pkg/validation"
)
@ -71,14 +72,29 @@ func main() {
rand.Seed(time.Now().UnixNano())
var handler http.Handler
if opts.GCPHealthChecks {
handler = redirectToHTTPS(opts, gcpHealthcheck(LoggingHandler(oauthproxy)))
} else {
handler = redirectToHTTPS(opts, LoggingHandler(oauthproxy))
chain := alice.New()
if opts.ForceHTTPS {
chain = chain.Append(newRedirectToHTTPS(opts))
}
healthCheckPaths := []string{opts.PingPath}
healthCheckUserAgents := []string{opts.PingUserAgent}
if opts.GCPHealthChecks {
healthCheckPaths = append(healthCheckPaths, "/liveness_check", "/readiness_check")
healthCheckUserAgents = append(healthCheckUserAgents, "GoogleHC/1.0")
}
// To silence logging of health checks, register the health check handler before
// the logging handler
if opts.Logging.SilencePing {
chain = chain.Append(middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents), LoggingHandler)
} else {
chain = chain.Append(LoggingHandler, middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents))
}
s := &Server{
Handler: handler,
Handler: chain.Then(oauthproxy),
Opts: opts,
stop: make(chan struct{}, 1),
}

View File

@ -81,9 +81,6 @@ type OAuthProxy struct {
Validator func(string) bool
RobotsPath string
PingPath string
PingUserAgent string
SilencePings bool
SignInPath string
SignOutPath string
OAuthStartPath string
@ -313,9 +310,6 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) *OAuthPro
Validator: validator,
RobotsPath: "/robots.txt",
PingPath: opts.PingPath,
PingUserAgent: opts.PingUserAgent,
SilencePings: opts.Logging.SilencePing,
SignInPath: fmt.Sprintf("%s/sign_in", opts.ProxyPrefix),
SignOutPath: fmt.Sprintf("%s/sign_out", opts.ProxyPrefix),
OAuthStartPath: fmt.Sprintf("%s/start", opts.ProxyPrefix),
@ -468,17 +462,6 @@ func (p *OAuthProxy) RobotsTxt(rw http.ResponseWriter) {
fmt.Fprintf(rw, "User-agent: *\nDisallow: /")
}
// PingPage responds 200 OK to requests
func (p *OAuthProxy) PingPage(rw http.ResponseWriter) {
if p.SilencePings {
if rl, ok := rw.(*responseLogger); ok {
rl.silent = true
}
}
rw.WriteHeader(http.StatusOK)
fmt.Fprintf(rw, "OK")
}
// ErrorPage writes an error response
func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, message string) {
rw.WriteHeader(code)
@ -684,17 +667,6 @@ func prepareNoCache(w http.ResponseWriter) {
}
}
// IsPingRequest will check if the request appears to be performing a health check
// either via the path it's requesting or by a special User-Agent configuration.
func (p *OAuthProxy) IsPingRequest(req *http.Request) bool {
if req.URL.EscapedPath() == p.PingPath {
return true
}
return p.PingUserAgent != "" && req.Header.Get("User-Agent") == p.PingUserAgent
}
func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if strings.HasPrefix(req.URL.Path, p.ProxyPrefix) {
prepareNoCache(rw)
@ -703,8 +675,6 @@ func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
switch path := req.URL.Path; {
case path == p.RobotsPath:
p.RobotsTxt(rw)
case p.IsPingRequest(req):
p.PingPage(rw)
case p.IsWhitelistedRequest(req):
p.serveMux.ServeHTTP(rw, req)
case path == p.SignInPath:

View File

@ -9,7 +9,7 @@ import (
)
// configureLogger is responsible for configuring the logger based on the options given
func configureLogger(o options.Logging, pingPath string, msgs []string) []string {
func configureLogger(o options.Logging, msgs []string) []string {
// Setup the log file
if len(o.File.Filename) > 0 {
// Validate that the file/dir can be written
@ -48,11 +48,7 @@ func configureLogger(o options.Logging, pingPath string, msgs []string) []string
logger.SetAuthTemplate(o.AuthFormat)
logger.SetReqTemplate(o.RequestFormat)
excludePaths := o.ExcludePaths
if o.SilencePing {
excludePaths = append(excludePaths, pingPath)
}
logger.SetExcludePaths(excludePaths)
logger.SetExcludePaths(o.ExcludePaths)
if !o.LocalTime {
logger.SetFlags(logger.Flags() | logger.LUTC)

View File

@ -264,7 +264,7 @@ func Validate(o *options.Options) error {
msgs = parseSignatureKey(o, msgs)
msgs = validateCookieName(o, msgs)
msgs = configureLogger(o.Logging, o.PingPath, msgs)
msgs = configureLogger(o.Logging, msgs)
if o.ReverseProxy {
parser, err := ip.GetRealClientIPParser(o.RealClientIPHeader)