1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-08-06 22:42:56 +02:00

Add pagewriter to upstream proxy

This commit is contained in:
Joel Speed
2021-03-31 10:30:42 +01:00
committed by Joel Speed
parent 725ae543d5
commit befcdd9d04
6 changed files with 227 additions and 9 deletions

View File

@ -101,3 +101,73 @@ func NewWriter(opts Opts) (Writer, error) {
staticPageWriter: staticPages,
}, nil
}
// WriterFuncs is an implementation of the PageWriter interface based
// on override functions.
// If any of the funcs are not provided, a default implementation will be used.
// This is primarily for us in testing.
type WriterFuncs struct {
SignInPageFunc func(rw http.ResponseWriter, req *http.Request, redirectURL string)
ErrorPageFunc func(rw http.ResponseWriter, opts ErrorPageOpts)
ProxyErrorFunc func(rw http.ResponseWriter, req *http.Request, proxyErr error)
RobotsTxtfunc func(rw http.ResponseWriter, req *http.Request)
}
// WriteSignInPage implements the Writer interface.
// If the SignInPageFunc is provided, this will be used, else a default
// implementation will be used.
func (w *WriterFuncs) WriteSignInPage(rw http.ResponseWriter, req *http.Request, redirectURL string) {
if w.SignInPageFunc != nil {
w.SignInPageFunc(rw, req, redirectURL)
return
}
if _, err := rw.Write([]byte("Sign In")); err != nil {
rw.WriteHeader(http.StatusInternalServerError)
}
}
// WriteErrorPage implements the Writer interface.
// If the ErrorPageFunc is provided, this will be used, else a default
// implementation will be used.
func (w *WriterFuncs) WriteErrorPage(rw http.ResponseWriter, opts ErrorPageOpts) {
if w.ErrorPageFunc != nil {
w.ErrorPageFunc(rw, opts)
return
}
rw.WriteHeader(opts.Status)
errMsg := fmt.Sprintf("%d - %v", opts.Status, opts.AppError)
if _, err := rw.Write([]byte(errMsg)); err != nil {
rw.WriteHeader(http.StatusInternalServerError)
}
}
// ProxyErrorHandler implements the Writer interface.
// If the ProxyErrorFunc is provided, this will be used, else a default
// implementation will be used.
func (w *WriterFuncs) ProxyErrorHandler(rw http.ResponseWriter, req *http.Request, proxyErr error) {
if w.ProxyErrorFunc != nil {
w.ProxyErrorFunc(rw, req, proxyErr)
return
}
w.WriteErrorPage(rw, ErrorPageOpts{
Status: http.StatusBadGateway,
AppError: proxyErr.Error(),
})
}
// WriteRobotsTxt implements the Writer interface.
// If the RobotsTxtfunc is provided, this will be used, else a default
// implementation will be used.
func (w *WriterFuncs) WriteRobotsTxt(rw http.ResponseWriter, req *http.Request) {
if w.RobotsTxtfunc != nil {
w.RobotsTxtfunc(rw, req)
return
}
if _, err := rw.Write([]byte("Allow: *")); err != nil {
rw.WriteHeader(http.StatusInternalServerError)
}
}

View File

@ -1,6 +1,8 @@
package pagewriter
import (
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
@ -8,6 +10,7 @@ import (
"path/filepath"
. "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/extensions/table"
. "github.com/onsi/gomega"
)
@ -135,4 +138,144 @@ var _ = Describe("Writer", func() {
})
})
})
Context("WriterFuncs", func() {
type writerFuncsTableInput struct {
writer Writer
expectedStatus int
expectedBody string
}
DescribeTable("WriteSignInPage",
func(in writerFuncsTableInput) {
rw := httptest.NewRecorder()
req := httptest.NewRequest("", "/sign-in", nil)
redirectURL := "<redirectURL>"
in.writer.WriteSignInPage(rw, req, redirectURL)
Expect(rw.Result().StatusCode).To(Equal(in.expectedStatus))
body, err := ioutil.ReadAll(rw.Result().Body)
Expect(err).ToNot(HaveOccurred())
Expect(string(body)).To(Equal(in.expectedBody))
},
Entry("With no override", writerFuncsTableInput{
writer: &WriterFuncs{},
expectedStatus: 200,
expectedBody: "Sign In",
}),
Entry("With an override function", writerFuncsTableInput{
writer: &WriterFuncs{
SignInPageFunc: func(rw http.ResponseWriter, req *http.Request, redirectURL string) {
rw.WriteHeader(202)
rw.Write([]byte(fmt.Sprintf("%s %s", req.URL.Path, redirectURL)))
},
},
expectedStatus: 202,
expectedBody: "/sign-in <redirectURL>",
}),
)
DescribeTable("WriteErrorPage",
func(in writerFuncsTableInput) {
rw := httptest.NewRecorder()
in.writer.WriteErrorPage(rw, ErrorPageOpts{
Status: http.StatusInternalServerError,
RedirectURL: "<redirectURL>",
RequestID: "12345",
AppError: "application error",
})
Expect(rw.Result().StatusCode).To(Equal(in.expectedStatus))
body, err := ioutil.ReadAll(rw.Result().Body)
Expect(err).ToNot(HaveOccurred())
Expect(string(body)).To(Equal(in.expectedBody))
},
Entry("With no override", writerFuncsTableInput{
writer: &WriterFuncs{},
expectedStatus: 500,
expectedBody: "500 - application error",
}),
Entry("With an override function", writerFuncsTableInput{
writer: &WriterFuncs{
ErrorPageFunc: func(rw http.ResponseWriter, opts ErrorPageOpts) {
rw.WriteHeader(503)
rw.Write([]byte(fmt.Sprintf("%s %s", opts.RequestID, opts.RedirectURL)))
},
},
expectedStatus: 503,
expectedBody: "12345 <redirectURL>",
}),
)
DescribeTable("ProxyErrorHandler",
func(in writerFuncsTableInput) {
rw := httptest.NewRecorder()
req := httptest.NewRequest("", "/proxy", nil)
err := errors.New("proxy error")
in.writer.ProxyErrorHandler(rw, req, err)
Expect(rw.Result().StatusCode).To(Equal(in.expectedStatus))
body, err := ioutil.ReadAll(rw.Result().Body)
Expect(err).ToNot(HaveOccurred())
Expect(string(body)).To(Equal(in.expectedBody))
},
Entry("With no override", writerFuncsTableInput{
writer: &WriterFuncs{},
expectedStatus: 502,
expectedBody: "502 - proxy error",
}),
Entry("With an override function for the proxy handler", writerFuncsTableInput{
writer: &WriterFuncs{
ProxyErrorFunc: func(rw http.ResponseWriter, req *http.Request, proxyErr error) {
rw.WriteHeader(503)
rw.Write([]byte(fmt.Sprintf("%s %v", req.URL.Path, proxyErr)))
},
},
expectedStatus: 503,
expectedBody: "/proxy proxy error",
}),
Entry("With an override function for the error page", writerFuncsTableInput{
writer: &WriterFuncs{
ErrorPageFunc: func(rw http.ResponseWriter, opts ErrorPageOpts) {
rw.WriteHeader(500)
rw.Write([]byte("Internal Server Error"))
},
},
expectedStatus: 500,
expectedBody: "Internal Server Error",
}),
)
DescribeTable("WriteRobotsTxt",
func(in writerFuncsTableInput) {
rw := httptest.NewRecorder()
req := httptest.NewRequest("", "/robots.txt", nil)
in.writer.WriteRobotsTxt(rw, req)
Expect(rw.Result().StatusCode).To(Equal(in.expectedStatus))
body, err := ioutil.ReadAll(rw.Result().Body)
Expect(err).ToNot(HaveOccurred())
Expect(string(body)).To(Equal(in.expectedBody))
},
Entry("With no override", writerFuncsTableInput{
writer: &WriterFuncs{},
expectedStatus: 200,
expectedBody: "Allow: *",
}),
Entry("With an override function", writerFuncsTableInput{
writer: &WriterFuncs{
RobotsTxtfunc: func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(202)
rw.Write([]byte("Disallow: *"))
},
},
expectedStatus: 202,
expectedBody: "Disallow: *",
}),
)
})
})