From d2d62bb45293fc5690355e80d55b841c041427bf Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Mon, 22 Feb 2021 15:55:08 +0000 Subject: [PATCH] Replace standard serve mux with gorilla mux --- go.mod | 2 +- pkg/upstream/proxy.go | 22 +++++++++++++++++----- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index df007c82..c5da2c7d 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/ghodss/yaml v1.0.1-0.20190212211648-25d852aebe32 github.com/go-redis/redis/v8 v8.2.3 github.com/google/uuid v1.2.0 - github.com/gorilla/mux v1.8.0 // indirect + github.com/gorilla/mux v1.8.0 github.com/justinas/alice v1.2.0 github.com/mbland/hmacauth v0.0.0-20170912233209-44256dfd4bfa github.com/mitchellh/mapstructure v1.1.2 diff --git a/pkg/upstream/proxy.go b/pkg/upstream/proxy.go index 2b0ab70e..e4b22bff 100644 --- a/pkg/upstream/proxy.go +++ b/pkg/upstream/proxy.go @@ -4,7 +4,9 @@ import ( "fmt" "net/http" "net/url" + "strings" + "github.com/gorilla/mux" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/app/pagewriter" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" @@ -18,7 +20,7 @@ type ProxyErrorHandler func(http.ResponseWriter, *http.Request, error) // multiple upstreams. func NewProxy(upstreams options.Upstreams, sigData *options.SignatureData, writer pagewriter.Writer) (http.Handler, error) { m := &multiUpstreamProxy{ - serveMux: http.NewServeMux(), + serveMux: mux.NewRouter(), } for _, upstream := range upstreams { @@ -46,7 +48,7 @@ func NewProxy(upstreams options.Upstreams, sigData *options.SignatureData, write // multiUpstreamProxy will serve requests directed to multiple upstream servers // registered in the serverMux. type multiUpstreamProxy struct { - serveMux *http.ServeMux + serveMux *mux.Router } // ServerHTTP handles HTTP requests. @@ -57,17 +59,27 @@ func (m *multiUpstreamProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request // registerStaticResponseHandler registers a static response handler with at the given path. func (m *multiUpstreamProxy) registerStaticResponseHandler(upstream options.Upstream) { logger.Printf("mapping path %q => static response %d", upstream.Path, derefStaticCode(upstream.StaticCode)) - m.serveMux.Handle(upstream.Path, newStaticResponseHandler(upstream.ID, upstream.StaticCode)) + m.registerSimpleHandler(upstream.Path, newStaticResponseHandler(upstream.ID, upstream.StaticCode)) } // registerFileServer registers a new fileServer based on the configuration given. func (m *multiUpstreamProxy) registerFileServer(upstream options.Upstream, u *url.URL) { logger.Printf("mapping path %q => file system %q", upstream.Path, u.Path) - m.serveMux.Handle(upstream.Path, newFileServer(upstream.ID, upstream.Path, u.Path)) + m.registerSimpleHandler(upstream.Path, newFileServer(upstream.ID, upstream.Path, u.Path)) } // registerHTTPUpstreamProxy registers a new httpUpstreamProxy based on the configuration given. func (m *multiUpstreamProxy) registerHTTPUpstreamProxy(upstream options.Upstream, u *url.URL, sigData *options.SignatureData, writer pagewriter.Writer) { logger.Printf("mapping path %q => upstream %q", upstream.Path, upstream.URI) - m.serveMux.Handle(upstream.Path, newHTTPUpstreamProxy(upstream, u, sigData, writer.ProxyErrorHandler)) + m.registerSimpleHandler(upstream.Path, newHTTPUpstreamProxy(upstream, u, sigData, writer.ProxyErrorHandler)) +} + +// registerSimpleHandler maintains the behaviour of the go standard serveMux +// by ensuring any path with a trailing `/` matches all paths under that prefix. +func (m *multiUpstreamProxy) registerSimpleHandler(path string, handler http.Handler) { + if strings.HasSuffix(path, "/") { + m.serveMux.PathPrefix(path).Handler(handler) + } else { + m.serveMux.Path(path).Handler(handler) + } }