From 5b95ed30332ca8d6d9ced4863ed2ab0e3612eb8d Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Wed, 27 May 2020 15:13:57 +0100 Subject: [PATCH] Add tests for upstream package --- pkg/upstream/file_test.go | 58 ++++ pkg/upstream/http.go | 8 +- pkg/upstream/http_test.go | 417 ++++++++++++++++++++++++++++ pkg/upstream/proxy_test.go | 182 ++++++++++++ pkg/upstream/static_test.go | 81 ++++++ pkg/upstream/upstream_suite_test.go | 180 ++++++++++++ 6 files changed, 925 insertions(+), 1 deletion(-) create mode 100644 pkg/upstream/file_test.go create mode 100644 pkg/upstream/http_test.go create mode 100644 pkg/upstream/proxy_test.go create mode 100644 pkg/upstream/static_test.go create mode 100644 pkg/upstream/upstream_suite_test.go diff --git a/pkg/upstream/file_test.go b/pkg/upstream/file_test.go new file mode 100644 index 00000000..2da1f078 --- /dev/null +++ b/pkg/upstream/file_test.go @@ -0,0 +1,58 @@ +package upstream + +import ( + "crypto/rand" + "io" + "net/http" + "net/http/httptest" + "os" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +var _ = Describe("File Server Suite", func() { + var dir string + var handler http.Handler + var id string + + const ( + foo = "foo" + bar = "bar" + baz = "baz" + pageNotFound = "404 page not found\n" + ) + + BeforeEach(func() { + // Generate a random id before each test to check the GAP-Upstream-Address + // is being set correctly + idBytes := make([]byte, 16) + _, err := io.ReadFull(rand.Reader, idBytes) + Expect(err).ToNot(HaveOccurred()) + id = string(idBytes) + + handler = newFileServer(id, "/files", filesDir) + }) + + AfterEach(func() { + Expect(os.RemoveAll(dir)).To(Succeed()) + }) + + DescribeTable("fileServer ServeHTTP", + func(requestPath string, expectedResponseCode int, expectedBody string) { + req := httptest.NewRequest("", requestPath, nil) + rw := httptest.NewRecorder() + handler.ServeHTTP(rw, req) + + Expect(rw.Header().Get("GAP-Upstream-Address")).To(Equal(id)) + Expect(rw.Code).To(Equal(expectedResponseCode)) + Expect(rw.Body.String()).To(Equal(expectedBody)) + }, + Entry("for file foo", "/files/foo", 200, foo), + Entry("for file bar", "/files/bar", 200, bar), + Entry("for file foo/baz", "/files/subdir/baz", 200, baz), + Entry("for a non-existent file inside the path", "/files/baz", 404, pageNotFound), + Entry("for a non-existent file oustide the path", "/baz", 404, pageNotFound), + ) +}) diff --git a/pkg/upstream/http.go b/pkg/upstream/http.go index 0c209103..fa7b2a8a 100644 --- a/pkg/upstream/http.go +++ b/pkg/upstream/http.go @@ -6,6 +6,7 @@ import ( "net/http/httputil" "net/url" "strings" + "time" "github.com/mbland/hmacauth" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" @@ -96,7 +97,12 @@ func newReverseProxy(target *url.URL, upstream options.Upstream, errorHandler Pr proxy := httputil.NewSingleHostReverseProxy(target) // Configure options on the SingleHostReverseProxy - proxy.FlushInterval = *upstream.FlushInterval + if upstream.FlushInterval != nil { + proxy.FlushInterval = *upstream.FlushInterval + } else { + proxy.FlushInterval = 1 * time.Second + } + if upstream.InsecureSkipTLSVerify { proxy.Transport = &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, diff --git a/pkg/upstream/http_test.go b/pkg/upstream/http_test.go new file mode 100644 index 00000000..8c601880 --- /dev/null +++ b/pkg/upstream/http_test.go @@ -0,0 +1,417 @@ +package upstream + +import ( + "bytes" + "crypto" + "crypto/tls" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/http/httputil" + "net/url" + "strings" + "time" + + "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" + "golang.org/x/net/websocket" +) + +var _ = Describe("HTTP Upstream Suite", func() { + + const flushInterval5s = 5 * time.Second + const flushInterval1s = 1 * time.Second + + type httpUpstreamTableInput struct { + id string + serverAddr *string + target string + method string + body []byte + signatureData *options.SignatureData + existingHeaders map[string]string + expectedResponse testHTTPResponse + errorHandler ProxyErrorHandler + } + + DescribeTable("HTTP Upstream ServeHTTP", + func(in *httpUpstreamTableInput) { + buf := bytes.NewBuffer(in.body) + req := httptest.NewRequest(in.method, in.target, buf) + // Don't mock the remote Address + req.RemoteAddr = "" + + for key, value := range in.existingHeaders { + req.Header.Add(key, value) + } + + rw := httptest.NewRecorder() + + flush := 1 * time.Second + upstream := options.Upstream{ + ID: in.id, + PassHostHeader: true, + ProxyWebSockets: false, + InsecureSkipTLSVerify: false, + FlushInterval: &flush, + } + + Expect(in.serverAddr).ToNot(BeNil()) + u, err := url.Parse(*in.serverAddr) + Expect(err).ToNot(HaveOccurred()) + + handler := newHTTPUpstreamProxy(upstream, u, in.signatureData, in.errorHandler) + handler.ServeHTTP(rw, req) + + Expect(rw.Code).To(Equal(in.expectedResponse.code)) + + // Delete extra headers that aren't relevant to tests + testSanitizeResponseHeader(rw.Header()) + Expect(rw.Header()).To(Equal(in.expectedResponse.header)) + + body := rw.Body.Bytes() + if in.expectedResponse.raw != "" || rw.Code != http.StatusOK { + Expect(string(body)).To(Equal(in.expectedResponse.raw)) + return + } + + // Compare the reflected request to the upstream + request := testHTTPRequest{} + Expect(json.Unmarshal(body, &request)).To(Succeed()) + testSanitizeRequestHeader(request.Header) + Expect(request).To(Equal(in.expectedResponse.request)) + }, + Entry("request a path on the server", &httpUpstreamTableInput{ + id: "default", + serverAddr: &serverAddr, + target: "http://example.localhost/foo", + method: "GET", + body: []byte{}, + errorHandler: nil, + expectedResponse: testHTTPResponse{ + code: 200, + header: map[string][]string{ + gapUpstream: {"default"}, + contentType: {applicationJSON}, + }, + request: testHTTPRequest{ + Method: "GET", + URL: "http://example.localhost/foo", + Header: map[string][]string{}, + Body: []byte{}, + Host: "example.localhost", + RequestURI: "http://example.localhost/foo", + }, + }, + }), + Entry("request a path with encoded slashes", &httpUpstreamTableInput{ + id: "encodedSlashes", + serverAddr: &serverAddr, + target: "http://example.localhost/foo%2fbar/?baz=1", + method: "GET", + body: []byte{}, + errorHandler: nil, + expectedResponse: testHTTPResponse{ + code: 200, + header: map[string][]string{ + gapUpstream: {"encodedSlashes"}, + contentType: {applicationJSON}, + }, + request: testHTTPRequest{ + Method: "GET", + URL: "http://example.localhost/foo%2fbar/?baz=1", + Header: map[string][]string{}, + Body: []byte{}, + Host: "example.localhost", + RequestURI: "http://example.localhost/foo%2fbar/?baz=1", + }, + }, + }), + Entry("when the request has a body", &httpUpstreamTableInput{ + id: "requestWithBody", + serverAddr: &serverAddr, + target: "http://example.localhost/withBody", + method: "POST", + body: []byte("body"), + errorHandler: nil, + expectedResponse: testHTTPResponse{ + code: 200, + header: map[string][]string{ + gapUpstream: {"requestWithBody"}, + contentType: {applicationJSON}, + }, + request: testHTTPRequest{ + Method: "POST", + URL: "http://example.localhost/withBody", + Header: map[string][]string{ + contentLength: {"4"}, + }, + Body: []byte("body"), + Host: "example.localhost", + RequestURI: "http://example.localhost/withBody", + }, + }, + }), + Entry("when the upstream is unavailable", &httpUpstreamTableInput{ + id: "unavailableUpstream", + serverAddr: &invalidServer, + target: "http://example.localhost/unavailableUpstream", + method: "GET", + body: []byte{}, + errorHandler: nil, + expectedResponse: testHTTPResponse{ + code: 502, + header: map[string][]string{ + gapUpstream: {"unavailableUpstream"}, + }, + request: testHTTPRequest{}, + }, + }), + Entry("when the upstream is unavailable and an error handler is set", &httpUpstreamTableInput{ + id: "withErrorHandler", + serverAddr: &invalidServer, + target: "http://example.localhost/withErrorHandler", + method: "GET", + body: []byte{}, + errorHandler: func(rw http.ResponseWriter, _ *http.Request, _ error) { + rw.WriteHeader(502) + rw.Write([]byte("error")) + }, + expectedResponse: testHTTPResponse{ + code: 502, + header: map[string][]string{ + gapUpstream: {"withErrorHandler"}, + }, + raw: "error", + request: testHTTPRequest{}, + }, + }), + Entry("with a signature", &httpUpstreamTableInput{ + id: "withSignature", + serverAddr: &serverAddr, + target: "http://example.localhost/withSignature", + method: "GET", + body: []byte{}, + signatureData: &options.SignatureData{ + Hash: crypto.SHA256, + Key: "key", + }, + errorHandler: nil, + expectedResponse: testHTTPResponse{ + code: 200, + header: map[string][]string{ + contentType: {applicationJSON}, + gapUpstream: {"withSignature"}, + }, + request: testHTTPRequest{ + Method: "GET", + URL: "http://example.localhost/withSignature", + Header: map[string][]string{ + gapAuth: {""}, + gapSignature: {"sha256 osMWI8Rr0Zr5HgNq6wakrgJITVJQMmFN1fXCesrqrmM="}, + }, + Body: []byte{}, + Host: "example.localhost", + RequestURI: "http://example.localhost/withSignature", + }, + }, + }), + Entry("with existing headers", &httpUpstreamTableInput{ + id: "existingHeaders", + serverAddr: &serverAddr, + target: "http://example.localhost/existingHeaders", + method: "GET", + body: []byte{}, + errorHandler: nil, + existingHeaders: map[string]string{ + "Header1": "value1", + "Header2": "value2", + }, + expectedResponse: testHTTPResponse{ + code: 200, + header: map[string][]string{ + gapUpstream: {"existingHeaders"}, + contentType: {applicationJSON}, + }, + request: testHTTPRequest{ + Method: "GET", + URL: "http://example.localhost/existingHeaders", + Header: map[string][]string{ + "Header1": {"value1"}, + "Header2": {"value2"}, + }, + Body: []byte{}, + Host: "example.localhost", + RequestURI: "http://example.localhost/existingHeaders", + }, + }, + }), + ) + + It("ServeHTTP, when not passing a host header", func() { + req := httptest.NewRequest("", "http://example.localhost/foo", nil) + rw := httptest.NewRecorder() + + flush := 1 * time.Second + upstream := options.Upstream{ + ID: "noPassHost", + PassHostHeader: false, + ProxyWebSockets: false, + InsecureSkipTLSVerify: false, + FlushInterval: &flush, + } + + u, err := url.Parse(serverAddr) + Expect(err).ToNot(HaveOccurred()) + + handler := newHTTPUpstreamProxy(upstream, u, nil, nil) + httpUpstream, ok := handler.(*httpUpstreamProxy) + Expect(ok).To(BeTrue()) + + // Override the handler to just run the director and not actually send the request + requestInterceptor := func(h http.Handler) http.Handler { + return http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) { + proxy, ok := h.(*httputil.ReverseProxy) + Expect(ok).To(BeTrue()) + proxy.Director(req) + }) + } + httpUpstream.handler = requestInterceptor(httpUpstream.handler) + + httpUpstream.ServeHTTP(rw, req) + Expect(req.Host).To(Equal(strings.TrimPrefix(serverAddr, "http://"))) + }) + + type newUpstreamTableInput struct { + proxyWebSockets bool + flushInterval time.Duration + skipVerify bool + sigData *options.SignatureData + errorHandler func(http.ResponseWriter, *http.Request, error) + } + + DescribeTable("newHTTPUpstreamProxy", + func(in *newUpstreamTableInput) { + u, err := url.Parse("http://upstream:1234") + Expect(err).ToNot(HaveOccurred()) + + upstream := options.Upstream{ + ID: "foo123", + FlushInterval: &in.flushInterval, + InsecureSkipTLSVerify: in.skipVerify, + ProxyWebSockets: in.proxyWebSockets, + } + + handler := newHTTPUpstreamProxy(upstream, u, in.sigData, in.errorHandler) + upstreamProxy, ok := handler.(*httpUpstreamProxy) + Expect(ok).To(BeTrue()) + + Expect(upstreamProxy.auth != nil).To(Equal(in.sigData != nil)) + Expect(upstreamProxy.wsHandler != nil).To(Equal(in.proxyWebSockets)) + Expect(upstreamProxy.upstream).To(Equal(upstream.ID)) + Expect(upstreamProxy.handler).ToNot(BeNil()) + + proxy, ok := upstreamProxy.handler.(*httputil.ReverseProxy) + Expect(ok).To(BeTrue()) + Expect(proxy.FlushInterval).To(Equal(in.flushInterval)) + Expect(proxy.ErrorHandler != nil).To(Equal(in.errorHandler != nil)) + if in.skipVerify { + Expect(proxy.Transport).To(Equal(&http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + })) + } + }, + Entry("with proxy websockets", &newUpstreamTableInput{ + proxyWebSockets: true, + flushInterval: flushInterval1s, + skipVerify: false, + sigData: nil, + errorHandler: nil, + }), + Entry("with a non standard flush interval", &newUpstreamTableInput{ + proxyWebSockets: false, + flushInterval: flushInterval5s, + skipVerify: false, + sigData: nil, + errorHandler: nil, + }), + Entry("with a InsecureSkipTLSVerify", &newUpstreamTableInput{ + proxyWebSockets: false, + flushInterval: flushInterval1s, + skipVerify: true, + sigData: nil, + errorHandler: nil, + }), + Entry("with a SignatureData", &newUpstreamTableInput{ + proxyWebSockets: false, + flushInterval: flushInterval1s, + skipVerify: false, + sigData: &options.SignatureData{Hash: crypto.SHA256, Key: "secret"}, + errorHandler: nil, + }), + Entry("with an error handler", &newUpstreamTableInput{ + proxyWebSockets: false, + flushInterval: flushInterval1s, + skipVerify: false, + sigData: nil, + errorHandler: func(rw http.ResponseWriter, req *http.Request, arg3 error) { + rw.WriteHeader(502) + }, + }), + ) + + Context("with a websocket proxy", func() { + var proxyServer *httptest.Server + + BeforeEach(func() { + flush := 1 * time.Second + upstream := options.Upstream{ + ID: "websocketProxy", + PassHostHeader: true, + ProxyWebSockets: true, + InsecureSkipTLSVerify: false, + FlushInterval: &flush, + } + + u, err := url.Parse(serverAddr) + Expect(err).ToNot(HaveOccurred()) + + handler := newHTTPUpstreamProxy(upstream, u, nil, nil) + proxyServer = httptest.NewServer(handler) + }) + + AfterEach(func() { + proxyServer.Close() + }) + + It("will proxy websockets", func() { + origin := "http://example.localhost" + message := "Hello, world!" + + proxyURL, err := url.Parse(fmt.Sprintf("http://%s", proxyServer.Listener.Addr().String())) + Expect(err).ToNot(HaveOccurred()) + + wsAddr := fmt.Sprintf("ws://%s/", proxyURL.Host) + ws, err := websocket.Dial(wsAddr, "", origin) + Expect(err).ToNot(HaveOccurred()) + + Expect(websocket.Message.Send(ws, []byte(message))).To(Succeed()) + var response testWebSocketResponse + Expect(websocket.JSON.Receive(ws, &response)).To(Succeed()) + Expect(response).To(Equal(testWebSocketResponse{ + Message: message, + Origin: origin, + })) + }) + + It("will proxy HTTP requests", func() { + response, err := http.Get(fmt.Sprintf("http://%s", proxyServer.Listener.Addr().String())) + Expect(err).ToNot(HaveOccurred()) + Expect(response.StatusCode).To(Equal(200)) + Expect(response.Header.Get(gapUpstream)).To(Equal("websocketProxy")) + }) + }) +}) diff --git a/pkg/upstream/proxy_test.go b/pkg/upstream/proxy_test.go new file mode 100644 index 00000000..945fb665 --- /dev/null +++ b/pkg/upstream/proxy_test.go @@ -0,0 +1,182 @@ +package upstream + +import ( + "crypto" + "encoding/json" + "fmt" + "html/template" + "net/http" + "net/http/httptest" + + "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +var _ = Describe("Proxy Suite", func() { + var upstreamServer http.Handler + + BeforeEach(func() { + sigData := &options.SignatureData{Hash: crypto.SHA256, Key: "secret"} + + tmpl, err := template.New("").Parse("{{ .Title }}\n{{ .Message }}\n{{ .ProxyPrefix }}") + Expect(err).ToNot(HaveOccurred()) + errorHandler := NewProxyErrorHandler(tmpl, "prefix") + + ok := http.StatusOK + + upstreams := options.Upstreams{ + { + ID: "http-backend", + Path: "/http/", + URI: serverAddr, + }, + { + ID: "file-backend", + Path: "/files/", + URI: fmt.Sprintf("file:///%s", filesDir), + }, + { + ID: "static-backend", + Path: "/static/", + Static: true, + StaticCode: &ok, + }, + { + ID: "bad-http-backend", + Path: "/bad-http/", + URI: "http://::1", + }, + { + ID: "single-path-backend", + Path: "/single-path", + Static: true, + StaticCode: &ok, + }, + } + + upstreamServer, err = NewProxy(upstreams, sigData, errorHandler) + Expect(err).ToNot(HaveOccurred()) + }) + + type proxyTableInput struct { + target string + response testHTTPResponse + } + + DescribeTable("Proxy ServerHTTP", + func(in *proxyTableInput) { + req := httptest.NewRequest("", in.target, nil) + rw := httptest.NewRecorder() + // Don't mock the remote Address + req.RemoteAddr = "" + + upstreamServer.ServeHTTP(rw, req) + + Expect(rw.Code).To(Equal(in.response.code)) + + // Delete extra headers that aren't relevant to tests + testSanitizeResponseHeader(rw.Header()) + Expect(rw.Header()).To(Equal(in.response.header)) + + body := rw.Body.Bytes() + // If the raw body is set, check that, else check the Request object + if in.response.raw != "" { + Expect(string(body)).To(Equal(in.response.raw)) + return + } + + // Compare the reflected request to the upstream + request := testHTTPRequest{} + Expect(json.Unmarshal(body, &request)).To(Succeed()) + testSanitizeRequestHeader(request.Header) + Expect(request).To(Equal(in.response.request)) + }, + Entry("with a request to the HTTP service", &proxyTableInput{ + target: "http://example.localhost/http/1234", + response: testHTTPResponse{ + code: 200, + header: map[string][]string{ + gapUpstream: {"http-backend"}, + contentType: {applicationJSON}, + }, + request: testHTTPRequest{ + Method: "GET", + URL: "http://example.localhost/http/1234", + Header: map[string][]string{ + "Gap-Auth": {""}, + "Gap-Signature": {"sha256 ofB1u6+FhEUbFLc3/uGbJVkl7GaN4egFqVvyO3+2I1w="}, + }, + Body: []byte{}, + Host: "example.localhost", + RequestURI: "http://example.localhost/http/1234", + }, + }, + }), + Entry("with a request to the File backend", &proxyTableInput{ + target: "http://example.localhost/files/foo", + response: testHTTPResponse{ + code: 200, + header: map[string][]string{ + contentType: {textPlainUTF8}, + gapUpstream: {"file-backend"}, + }, + raw: "foo", + }, + }), + Entry("with a request to the Static backend", &proxyTableInput{ + target: "http://example.localhost/static/bar", + response: testHTTPResponse{ + code: 200, + header: map[string][]string{ + gapUpstream: {"static-backend"}, + }, + raw: "Authenticated", + }, + }), + Entry("with a request to the bad HTTP backend", &proxyTableInput{ + target: "http://example.localhost/bad-http/bad", + response: testHTTPResponse{ + code: 502, + header: map[string][]string{ + gapUpstream: {"bad-http-backend"}, + }, + // This tests the error handler + raw: "Bad Gateway\nError proxying to upstream server\nprefix", + }, + }), + Entry("with a request to the to an unregistered path", &proxyTableInput{ + target: "http://example.localhost/unregistered", + response: testHTTPResponse{ + code: 404, + header: map[string][]string{ + "X-Content-Type-Options": {"nosniff"}, + contentType: {textPlainUTF8}, + }, + raw: "404 page not found\n", + }, + }), + Entry("with a request to the to backend registered to a single path", &proxyTableInput{ + target: "http://example.localhost/single-path", + response: testHTTPResponse{ + code: 200, + header: map[string][]string{ + gapUpstream: {"single-path-backend"}, + }, + raw: "Authenticated", + }, + }), + Entry("with a request to the to a subpath of a backend registered to a single path", &proxyTableInput{ + target: "http://example.localhost/single-path/unregistered", + response: testHTTPResponse{ + code: 404, + header: map[string][]string{ + "X-Content-Type-Options": {"nosniff"}, + contentType: {textPlainUTF8}, + }, + raw: "404 page not found\n", + }, + }), + ) +}) diff --git a/pkg/upstream/static_test.go b/pkg/upstream/static_test.go new file mode 100644 index 00000000..1b7309f7 --- /dev/null +++ b/pkg/upstream/static_test.go @@ -0,0 +1,81 @@ +package upstream + +import ( + "crypto/rand" + "io" + "net/http" + "net/http/httptest" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +var _ = Describe("Static Response Suite", func() { + const authenticated = "Authenticated" + var id string + + BeforeEach(func() { + // Generate a random id before each test to check the GAP-Upstream-Address + // is being set correctly + idBytes := make([]byte, 16) + _, err := io.ReadFull(rand.Reader, idBytes) + Expect(err).ToNot(HaveOccurred()) + id = string(idBytes) + }) + + type serveHTTPTableInput struct { + requestPath string + staticCode int + expectedBody string + expectedCode int + } + + DescribeTable("staticResponse ServeHTTP", + func(in *serveHTTPTableInput) { + var code *int + if in.staticCode != 0 { + code = &in.staticCode + } + handler := newStaticResponseHandler(id, code) + + req := httptest.NewRequest("", in.requestPath, nil) + rw := httptest.NewRecorder() + handler.ServeHTTP(rw, req) + + Expect(rw.Header().Get("GAP-Upstream-Address")).To(Equal(id)) + Expect(rw.Code).To(Equal(in.expectedCode)) + Expect(rw.Body.String()).To(Equal(in.expectedBody)) + }, + Entry("with no given code", &serveHTTPTableInput{ + requestPath: "/", + staticCode: 0, // Placeholder for nil + expectedBody: authenticated, + expectedCode: http.StatusOK, + }), + Entry("with status OK", &serveHTTPTableInput{ + requestPath: "/abc", + staticCode: http.StatusOK, + expectedBody: authenticated, + expectedCode: http.StatusOK, + }), + Entry("with status NoContent", &serveHTTPTableInput{ + requestPath: "/def", + staticCode: http.StatusNoContent, + expectedBody: authenticated, + expectedCode: http.StatusNoContent, + }), + Entry("with status NotFound", &serveHTTPTableInput{ + requestPath: "/ghi", + staticCode: http.StatusNotFound, + expectedBody: authenticated, + expectedCode: http.StatusNotFound, + }), + Entry("with status Teapot", &serveHTTPTableInput{ + requestPath: "/jkl", + staticCode: http.StatusTeapot, + expectedBody: authenticated, + expectedCode: http.StatusTeapot, + }), + ) +}) diff --git a/pkg/upstream/upstream_suite_test.go b/pkg/upstream/upstream_suite_test.go new file mode 100644 index 00000000..7d8c2ba4 --- /dev/null +++ b/pkg/upstream/upstream_suite_test.go @@ -0,0 +1,180 @@ +package upstream + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "log" + "net/http" + "net/http/httptest" + "os" + "path" + "testing" + + "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + "golang.org/x/net/websocket" +) + +var ( + filesDir string + server *httptest.Server + serverAddr string + invalidServer = "http://::1" +) + +func TestUpstreamSuite(t *testing.T) { + logger.SetOutput(GinkgoWriter) + log.SetOutput(GinkgoWriter) + + RegisterFailHandler(Fail) + RunSpecs(t, "Upstream Suite") +} + +var _ = BeforeSuite(func() { + // Set up files for serving via file servers + dir, err := ioutil.TempDir("", "oauth2-proxy-upstream-suite") + Expect(err).ToNot(HaveOccurred()) + Expect(ioutil.WriteFile(path.Join(dir, "foo"), []byte("foo"), 0644)).To(Succeed()) + Expect(ioutil.WriteFile(path.Join(dir, "bar"), []byte("bar"), 0644)).To(Succeed()) + Expect(os.Mkdir(path.Join(dir, "subdir"), os.ModePerm)).To(Succeed()) + Expect(ioutil.WriteFile(path.Join(dir, "subdir", "baz"), []byte("baz"), 0644)).To(Succeed()) + filesDir = dir + + // Set up a webserver that reflects requests + server = httptest.NewServer(&testHTTPUpstream{}) + serverAddr = fmt.Sprintf("http://%s", server.Listener.Addr().String()) +}) + +var _ = AfterSuite(func() { + server.Close() + Expect(os.RemoveAll(filesDir)).To(Succeed()) +}) + +const ( + contentType = "Content-Type" + contentLength = "Content-Length" + acceptEncoding = "Accept-Encoding" + applicationJSON = "application/json" + textPlainUTF8 = "text/plain; charset=utf-8" + gapUpstream = "Gap-Upstream-Address" + gapAuth = "Gap-Auth" + gapSignature = "Gap-Signature" +) + +// testHTTPResponse is a struct used for checking responses in table tests +type testHTTPResponse struct { + code int + header http.Header + raw string + request testHTTPRequest +} + +// testHTTPRequest is a struct used to capture the state of a request made to +// an upstream during a test +type testHTTPRequest struct { + Method string + URL string + Header http.Header + Body []byte + Host string + RequestURI string +} + +type testWebSocketResponse struct { + Message string + Origin string +} + +type testHTTPUpstream struct{} + +func (t *testHTTPUpstream) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + if req.Header.Get("Upgrade") == "websocket" { + t.websocketHandler().ServeHTTP(rw, req) + } else { + t.serveHTTP(rw, req) + } +} + +func (t *testHTTPUpstream) serveHTTP(rw http.ResponseWriter, req *http.Request) { + request, err := toTestHTTPRequest(req) + if err != nil { + t.writeError(rw, err) + return + } + + data, err := json.Marshal(request) + if err != nil { + t.writeError(rw, err) + return + } + + rw.Header().Set("Content-Type", "application/json") + rw.Write(data) +} + +func (t *testHTTPUpstream) websocketHandler() http.Handler { + return websocket.Handler(func(ws *websocket.Conn) { + defer ws.Close() + var data []byte + err := websocket.Message.Receive(ws, &data) + if err != nil { + websocket.Message.Send(ws, []byte(err.Error())) + return + } + + wsResponse := testWebSocketResponse{ + Message: string(data), + Origin: ws.Request().Header.Get("Origin"), + } + err = websocket.JSON.Send(ws, wsResponse) + if err != nil { + websocket.Message.Send(ws, []byte(err.Error())) + return + } + }) +} + +func (t *testHTTPUpstream) writeError(rw http.ResponseWriter, err error) { + rw.WriteHeader(500) + if err != nil { + rw.Write([]byte(err.Error())) + } +} + +func toTestHTTPRequest(req *http.Request) (testHTTPRequest, error) { + requestBody := []byte{} + if req.Body != http.NoBody { + var err error + requestBody, err = ioutil.ReadAll(req.Body) + if err != nil { + return testHTTPRequest{}, err + } + } + + return testHTTPRequest{ + Method: req.Method, + URL: req.URL.String(), + Header: req.Header, + Body: requestBody, + Host: req.Host, + RequestURI: req.RequestURI, + }, nil +} + +// String headers added to the response that we do not want to test +func testSanitizeResponseHeader(h http.Header) { + // From HTTP responses + h.Del("Date") + h.Del(contentLength) + + // From File responses + h.Del("Accept-Ranges") + h.Del("Last-Modified") +} + +// Strip the accept header that is added by the HTTP Transport +func testSanitizeRequestHeader(h http.Header) { + h.Del(acceptEncoding) +}