From befcdd9d042fb745d8c22e6ca67e653b29492daf Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Wed, 31 Mar 2021 10:30:42 +0100 Subject: [PATCH] Add pagewriter to upstream proxy --- CHANGELOG.md | 1 + oauthproxy.go | 2 +- pkg/app/pagewriter/pagewriter.go | 70 +++++++++++++ pkg/app/pagewriter/pagewriter_test.go | 143 ++++++++++++++++++++++++++ pkg/upstream/proxy.go | 9 +- pkg/upstream/proxy_test.go | 11 +- 6 files changed, 227 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0d9e7e62..41bf97d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ ## Changes since v7.1.3 +- [#1142](https://github.com/oauth2-proxy/oauth2-proxy/pull/1142) Add pagewriter to upstream proxy (@JoelSpeed) - [#1181](https://github.com/oauth2-proxy/oauth2-proxy/pull/1181) Fix incorrect `cfg` name in show-debug-on-error flag (@iTaybb) # V7.1.3 diff --git a/oauthproxy.go b/oauthproxy.go index 581aad98..5902b9af 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -124,7 +124,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr return nil, fmt.Errorf("error initialising page writer: %v", err) } - upstreamProxy, err := upstream.NewProxy(opts.UpstreamServers, opts.GetSignatureData(), pageWriter.ProxyErrorHandler) + upstreamProxy, err := upstream.NewProxy(opts.UpstreamServers, opts.GetSignatureData(), pageWriter) if err != nil { return nil, fmt.Errorf("error initialising upstream proxy: %v", err) } diff --git a/pkg/app/pagewriter/pagewriter.go b/pkg/app/pagewriter/pagewriter.go index 5991e625..c72400f0 100644 --- a/pkg/app/pagewriter/pagewriter.go +++ b/pkg/app/pagewriter/pagewriter.go @@ -101,3 +101,73 @@ func NewWriter(opts Opts) (Writer, error) { staticPageWriter: staticPages, }, nil } + +// WriterFuncs is an implementation of the PageWriter interface based +// on override functions. +// If any of the funcs are not provided, a default implementation will be used. +// This is primarily for us in testing. +type WriterFuncs struct { + SignInPageFunc func(rw http.ResponseWriter, req *http.Request, redirectURL string) + ErrorPageFunc func(rw http.ResponseWriter, opts ErrorPageOpts) + ProxyErrorFunc func(rw http.ResponseWriter, req *http.Request, proxyErr error) + RobotsTxtfunc func(rw http.ResponseWriter, req *http.Request) +} + +// WriteSignInPage implements the Writer interface. +// If the SignInPageFunc is provided, this will be used, else a default +// implementation will be used. +func (w *WriterFuncs) WriteSignInPage(rw http.ResponseWriter, req *http.Request, redirectURL string) { + if w.SignInPageFunc != nil { + w.SignInPageFunc(rw, req, redirectURL) + return + } + + if _, err := rw.Write([]byte("Sign In")); err != nil { + rw.WriteHeader(http.StatusInternalServerError) + } +} + +// WriteErrorPage implements the Writer interface. +// If the ErrorPageFunc is provided, this will be used, else a default +// implementation will be used. +func (w *WriterFuncs) WriteErrorPage(rw http.ResponseWriter, opts ErrorPageOpts) { + if w.ErrorPageFunc != nil { + w.ErrorPageFunc(rw, opts) + return + } + + rw.WriteHeader(opts.Status) + errMsg := fmt.Sprintf("%d - %v", opts.Status, opts.AppError) + if _, err := rw.Write([]byte(errMsg)); err != nil { + rw.WriteHeader(http.StatusInternalServerError) + } +} + +// ProxyErrorHandler implements the Writer interface. +// If the ProxyErrorFunc is provided, this will be used, else a default +// implementation will be used. +func (w *WriterFuncs) ProxyErrorHandler(rw http.ResponseWriter, req *http.Request, proxyErr error) { + if w.ProxyErrorFunc != nil { + w.ProxyErrorFunc(rw, req, proxyErr) + return + } + + w.WriteErrorPage(rw, ErrorPageOpts{ + Status: http.StatusBadGateway, + AppError: proxyErr.Error(), + }) +} + +// WriteRobotsTxt implements the Writer interface. +// If the RobotsTxtfunc is provided, this will be used, else a default +// implementation will be used. +func (w *WriterFuncs) WriteRobotsTxt(rw http.ResponseWriter, req *http.Request) { + if w.RobotsTxtfunc != nil { + w.RobotsTxtfunc(rw, req) + return + } + + if _, err := rw.Write([]byte("Allow: *")); err != nil { + rw.WriteHeader(http.StatusInternalServerError) + } +} diff --git a/pkg/app/pagewriter/pagewriter_test.go b/pkg/app/pagewriter/pagewriter_test.go index eefd2437..2adedd19 100644 --- a/pkg/app/pagewriter/pagewriter_test.go +++ b/pkg/app/pagewriter/pagewriter_test.go @@ -1,6 +1,8 @@ package pagewriter import ( + "errors" + "fmt" "io/ioutil" "net/http" "net/http/httptest" @@ -8,6 +10,7 @@ import ( "path/filepath" . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" . "github.com/onsi/gomega" ) @@ -135,4 +138,144 @@ var _ = Describe("Writer", func() { }) }) }) + + Context("WriterFuncs", func() { + type writerFuncsTableInput struct { + writer Writer + expectedStatus int + expectedBody string + } + + DescribeTable("WriteSignInPage", + func(in writerFuncsTableInput) { + rw := httptest.NewRecorder() + req := httptest.NewRequest("", "/sign-in", nil) + redirectURL := "" + in.writer.WriteSignInPage(rw, req, redirectURL) + + Expect(rw.Result().StatusCode).To(Equal(in.expectedStatus)) + + body, err := ioutil.ReadAll(rw.Result().Body) + Expect(err).ToNot(HaveOccurred()) + Expect(string(body)).To(Equal(in.expectedBody)) + }, + Entry("With no override", writerFuncsTableInput{ + writer: &WriterFuncs{}, + expectedStatus: 200, + expectedBody: "Sign In", + }), + Entry("With an override function", writerFuncsTableInput{ + writer: &WriterFuncs{ + SignInPageFunc: func(rw http.ResponseWriter, req *http.Request, redirectURL string) { + rw.WriteHeader(202) + rw.Write([]byte(fmt.Sprintf("%s %s", req.URL.Path, redirectURL))) + }, + }, + expectedStatus: 202, + expectedBody: "/sign-in ", + }), + ) + + DescribeTable("WriteErrorPage", + func(in writerFuncsTableInput) { + rw := httptest.NewRecorder() + in.writer.WriteErrorPage(rw, ErrorPageOpts{ + Status: http.StatusInternalServerError, + RedirectURL: "", + RequestID: "12345", + AppError: "application error", + }) + + Expect(rw.Result().StatusCode).To(Equal(in.expectedStatus)) + + body, err := ioutil.ReadAll(rw.Result().Body) + Expect(err).ToNot(HaveOccurred()) + Expect(string(body)).To(Equal(in.expectedBody)) + }, + Entry("With no override", writerFuncsTableInput{ + writer: &WriterFuncs{}, + expectedStatus: 500, + expectedBody: "500 - application error", + }), + Entry("With an override function", writerFuncsTableInput{ + writer: &WriterFuncs{ + ErrorPageFunc: func(rw http.ResponseWriter, opts ErrorPageOpts) { + rw.WriteHeader(503) + rw.Write([]byte(fmt.Sprintf("%s %s", opts.RequestID, opts.RedirectURL))) + }, + }, + expectedStatus: 503, + expectedBody: "12345 ", + }), + ) + + DescribeTable("ProxyErrorHandler", + func(in writerFuncsTableInput) { + rw := httptest.NewRecorder() + req := httptest.NewRequest("", "/proxy", nil) + err := errors.New("proxy error") + in.writer.ProxyErrorHandler(rw, req, err) + + Expect(rw.Result().StatusCode).To(Equal(in.expectedStatus)) + + body, err := ioutil.ReadAll(rw.Result().Body) + Expect(err).ToNot(HaveOccurred()) + Expect(string(body)).To(Equal(in.expectedBody)) + }, + Entry("With no override", writerFuncsTableInput{ + writer: &WriterFuncs{}, + expectedStatus: 502, + expectedBody: "502 - proxy error", + }), + Entry("With an override function for the proxy handler", writerFuncsTableInput{ + writer: &WriterFuncs{ + ProxyErrorFunc: func(rw http.ResponseWriter, req *http.Request, proxyErr error) { + rw.WriteHeader(503) + rw.Write([]byte(fmt.Sprintf("%s %v", req.URL.Path, proxyErr))) + }, + }, + expectedStatus: 503, + expectedBody: "/proxy proxy error", + }), + Entry("With an override function for the error page", writerFuncsTableInput{ + writer: &WriterFuncs{ + ErrorPageFunc: func(rw http.ResponseWriter, opts ErrorPageOpts) { + rw.WriteHeader(500) + rw.Write([]byte("Internal Server Error")) + }, + }, + expectedStatus: 500, + expectedBody: "Internal Server Error", + }), + ) + + DescribeTable("WriteRobotsTxt", + func(in writerFuncsTableInput) { + rw := httptest.NewRecorder() + req := httptest.NewRequest("", "/robots.txt", nil) + in.writer.WriteRobotsTxt(rw, req) + + Expect(rw.Result().StatusCode).To(Equal(in.expectedStatus)) + + body, err := ioutil.ReadAll(rw.Result().Body) + Expect(err).ToNot(HaveOccurred()) + Expect(string(body)).To(Equal(in.expectedBody)) + }, + Entry("With no override", writerFuncsTableInput{ + writer: &WriterFuncs{}, + expectedStatus: 200, + expectedBody: "Allow: *", + }), + Entry("With an override function", writerFuncsTableInput{ + writer: &WriterFuncs{ + RobotsTxtfunc: func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(202) + rw.Write([]byte("Disallow: *")) + }, + }, + expectedStatus: 202, + expectedBody: "Disallow: *", + }), + ) + }) }) diff --git a/pkg/upstream/proxy.go b/pkg/upstream/proxy.go index f345158b..2b0ab70e 100644 --- a/pkg/upstream/proxy.go +++ b/pkg/upstream/proxy.go @@ -6,6 +6,7 @@ import ( "net/url" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/app/pagewriter" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" ) @@ -15,7 +16,7 @@ type ProxyErrorHandler func(http.ResponseWriter, *http.Request, error) // NewProxy creates a new multiUpstreamProxy that can serve requests directed to // multiple upstreams. -func NewProxy(upstreams options.Upstreams, sigData *options.SignatureData, errorHandler ProxyErrorHandler) (http.Handler, error) { +func NewProxy(upstreams options.Upstreams, sigData *options.SignatureData, writer pagewriter.Writer) (http.Handler, error) { m := &multiUpstreamProxy{ serveMux: http.NewServeMux(), } @@ -34,7 +35,7 @@ func NewProxy(upstreams options.Upstreams, sigData *options.SignatureData, error case fileScheme: m.registerFileServer(upstream, u) case httpScheme, httpsScheme: - m.registerHTTPUpstreamProxy(upstream, u, sigData, errorHandler) + m.registerHTTPUpstreamProxy(upstream, u, sigData, writer) default: return nil, fmt.Errorf("unknown scheme for upstream %q: %q", upstream.ID, u.Scheme) } @@ -66,7 +67,7 @@ func (m *multiUpstreamProxy) registerFileServer(upstream options.Upstream, u *ur } // registerHTTPUpstreamProxy registers a new httpUpstreamProxy based on the configuration given. -func (m *multiUpstreamProxy) registerHTTPUpstreamProxy(upstream options.Upstream, u *url.URL, sigData *options.SignatureData, errorHandler ProxyErrorHandler) { +func (m *multiUpstreamProxy) registerHTTPUpstreamProxy(upstream options.Upstream, u *url.URL, sigData *options.SignatureData, writer pagewriter.Writer) { logger.Printf("mapping path %q => upstream %q", upstream.Path, upstream.URI) - m.serveMux.Handle(upstream.Path, newHTTPUpstreamProxy(upstream, u, sigData, errorHandler)) + m.serveMux.Handle(upstream.Path, newHTTPUpstreamProxy(upstream, u, sigData, writer.ProxyErrorHandler)) } diff --git a/pkg/upstream/proxy_test.go b/pkg/upstream/proxy_test.go index 44483585..96b7a0dd 100644 --- a/pkg/upstream/proxy_test.go +++ b/pkg/upstream/proxy_test.go @@ -9,6 +9,7 @@ import ( middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/app/pagewriter" . "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo/extensions/table" . "github.com/onsi/gomega" @@ -20,9 +21,11 @@ var _ = Describe("Proxy Suite", func() { BeforeEach(func() { sigData := &options.SignatureData{Hash: crypto.SHA256, Key: "secret"} - errorHandler := func(rw http.ResponseWriter, _ *http.Request, _ error) { - rw.WriteHeader(502) - rw.Write([]byte("Proxy Error")) + writer := &pagewriter.WriterFuncs{ + ProxyErrorFunc: func(rw http.ResponseWriter, _ *http.Request, _ error) { + rw.WriteHeader(502) + rw.Write([]byte("Proxy Error")) + }, } ok := http.StatusOK @@ -58,7 +61,7 @@ var _ = Describe("Proxy Suite", func() { } var err error - upstreamServer, err = NewProxy(upstreams, sigData, errorHandler) + upstreamServer, err = NewProxy(upstreams, sigData, writer) Expect(err).ToNot(HaveOccurred()) })