From 21ef86b5942cba6d3d708394b3d485bde4685886 Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Sat, 4 Jul 2020 08:09:35 +0100 Subject: [PATCH] Add tests for the request builder --- pkg/requests/builder_test.go | 384 ++++++++++++++++++++++++++++ pkg/requests/requests_suite_test.go | 96 +++++++ 2 files changed, 480 insertions(+) create mode 100644 pkg/requests/builder_test.go create mode 100644 pkg/requests/requests_suite_test.go diff --git a/pkg/requests/builder_test.go b/pkg/requests/builder_test.go new file mode 100644 index 00000000..b859fa23 --- /dev/null +++ b/pkg/requests/builder_test.go @@ -0,0 +1,384 @@ +package requests + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + + "github.com/bitly/go-simplejson" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Builder suite", func() { + var b Builder + getBuilder := func() Builder { return b } + + baseHeaders := http.Header{ + "Accept-Encoding": []string{"gzip"}, + "User-Agent": []string{"Go-http-client/1.1"}, + } + + BeforeEach(func() { + // Most tests will request the server address + b = New(serverAddr + "/json/path") + }) + + Context("with a basic request", func() { + assertSuccessfulRequest(getBuilder, testHTTPRequest{ + Method: "GET", + Header: baseHeaders, + Body: []byte{}, + RequestURI: "/json/path", + }) + }) + + Context("with a context", func() { + var ctx context.Context + var cancel context.CancelFunc + + BeforeEach(func() { + ctx, cancel = context.WithCancel(context.Background()) + b = b.WithContext(ctx) + }) + + AfterEach(func() { + cancel() + }) + + assertSuccessfulRequest(getBuilder, testHTTPRequest{ + Method: "GET", + Header: baseHeaders, + Body: []byte{}, + RequestURI: "/json/path", + }) + + Context("if the context is cancelled", func() { + BeforeEach(func() { + cancel() + }) + + assertRequestError(getBuilder, "context canceled") + }) + }) + + Context("with a body", func() { + const body = "{\"some\": \"body\"}" + header := baseHeaders.Clone() + header.Set("Content-Length", fmt.Sprintf("%d", len(body))) + + BeforeEach(func() { + buf := bytes.NewBuffer([]byte(body)) + b = b.WithBody(buf) + }) + + assertSuccessfulRequest(getBuilder, testHTTPRequest{ + Method: "GET", + Header: header, + Body: []byte(body), + RequestURI: "/json/path", + }) + }) + + Context("with a method", func() { + Context("POST with a body", func() { + const body = "{\"some\": \"body\"}" + header := baseHeaders.Clone() + header.Set("Content-Length", fmt.Sprintf("%d", len(body))) + + BeforeEach(func() { + buf := bytes.NewBuffer([]byte(body)) + b = b.WithMethod("POST").WithBody(buf) + }) + + assertSuccessfulRequest(getBuilder, testHTTPRequest{ + Method: "POST", + Header: header, + Body: []byte(body), + RequestURI: "/json/path", + }) + }) + + Context("POST without a body", func() { + header := baseHeaders.Clone() + header.Set("Content-Length", "0") + + BeforeEach(func() { + b = b.WithMethod("POST") + }) + + assertSuccessfulRequest(getBuilder, testHTTPRequest{ + Method: "POST", + Header: header, + Body: []byte{}, + RequestURI: "/json/path", + }) + }) + + Context("OPTIONS", func() { + BeforeEach(func() { + b = b.WithMethod("OPTIONS") + }) + + assertSuccessfulRequest(getBuilder, testHTTPRequest{ + Method: "OPTIONS", + Header: baseHeaders, + Body: []byte{}, + RequestURI: "/json/path", + }) + }) + + Context("INVALID-\\t-METHOD", func() { + BeforeEach(func() { + b = b.WithMethod("INVALID-\t-METHOD") + }) + + assertRequestError(getBuilder, "error creating request: net/http: invalid method \"INVALID-\\t-METHOD\"") + }) + }) + + Context("with headers", func() { + Context("setting a header", func() { + header := baseHeaders.Clone() + header.Set("header", "value") + + BeforeEach(func() { + b = b.SetHeader("header", "value") + }) + + assertSuccessfulRequest(getBuilder, testHTTPRequest{ + Method: "GET", + Header: header, + Body: []byte{}, + RequestURI: "/json/path", + }) + + Context("then replacing the headers", func() { + replacementHeaders := http.Header{ + "Accept-Encoding": []string{"*"}, + "User-Agent": []string{"test-agent"}, + "Foo": []string{"bar, baz"}, + } + + BeforeEach(func() { + b = b.WithHeaders(replacementHeaders) + }) + + assertSuccessfulRequest(getBuilder, testHTTPRequest{ + Method: "GET", + Header: replacementHeaders, + Body: []byte{}, + RequestURI: "/json/path", + }) + }) + }) + + Context("replacing the header", func() { + replacementHeaders := http.Header{ + "Accept-Encoding": []string{"*"}, + "User-Agent": []string{"test-agent"}, + "Foo": []string{"bar, baz"}, + } + + BeforeEach(func() { + b = b.WithHeaders(replacementHeaders) + }) + + assertSuccessfulRequest(getBuilder, testHTTPRequest{ + Method: "GET", + Header: replacementHeaders, + Body: []byte{}, + RequestURI: "/json/path", + }) + + Context("then setting a header", func() { + header := replacementHeaders.Clone() + header.Set("User-Agent", "different-agent") + + BeforeEach(func() { + b = b.SetHeader("User-Agent", "different-agent") + }) + + assertSuccessfulRequest(getBuilder, testHTTPRequest{ + Method: "GET", + Header: header, + Body: []byte{}, + RequestURI: "/json/path", + }) + }) + }) + }) + + Context("if the request has been completed and then modified", func() { + BeforeEach(func() { + _, err := b.Do() + Expect(err).ToNot(HaveOccurred()) + + b.WithMethod("POST") + }) + + Context("should not redo the request", func() { + assertSuccessfulRequest(getBuilder, testHTTPRequest{ + Method: "GET", + Header: baseHeaders, + Body: []byte{}, + RequestURI: "/json/path", + }) + }) + }) + + Context("when the requested page is not found", func() { + BeforeEach(func() { + b = New(serverAddr + "/not-found") + }) + + assertJSONError(getBuilder, "404 page not found") + }) + + Context("when the requested page is not valid JSON", func() { + BeforeEach(func() { + b = New(serverAddr + "/string/path") + }) + + assertJSONError(getBuilder, "invalid character 'O' looking for beginning of value") + }) +}) + +func assertSuccessfulRequest(builder func() Builder, expectedRequest testHTTPRequest) { + Context("Do", func() { + var resp *http.Response + + BeforeEach(func() { + var err error + resp, err = builder().Do() + Expect(err).ToNot(HaveOccurred()) + }) + + It("returns a successful status", func() { + Expect(resp.StatusCode).To(Equal(http.StatusOK)) + }) + + It("made the expected request", func() { + body, err := ioutil.ReadAll(resp.Body) + Expect(err).ToNot(HaveOccurred()) + resp.Body.Close() + + actualRequest := testHTTPRequest{} + Expect(json.Unmarshal(body, &actualRequest)).To(Succeed()) + + Expect(actualRequest).To(Equal(expectedRequest)) + }) + }) + + Context("UnmarshalInto", func() { + var actualRequest testHTTPRequest + + BeforeEach(func() { + Expect(builder().UnmarshalInto(&actualRequest)).To(Succeed()) + }) + + It("made the expected request", func() { + Expect(actualRequest).To(Equal(expectedRequest)) + }) + }) + + Context("UnmarshalJSON", func() { + var response *simplejson.Json + + BeforeEach(func() { + var err error + response, err = builder().UnmarshalJSON() + Expect(err).ToNot(HaveOccurred()) + }) + + It("made the expected reqest", func() { + header := http.Header{} + for key, value := range response.Get("Header").MustMap() { + vs, ok := value.([]interface{}) + Expect(ok).To(BeTrue()) + svs := []string{} + for _, v := range vs { + sv, ok := v.(string) + Expect(ok).To(BeTrue()) + svs = append(svs, sv) + } + header[key] = svs + } + + // Other json unmarhsallers base64 decode byte slices automatically + body, err := base64.StdEncoding.DecodeString(response.Get("Body").MustString()) + Expect(err).ToNot(HaveOccurred()) + + actualRequest := testHTTPRequest{ + Method: response.Get("Method").MustString(), + Header: header, + Body: body, + RequestURI: response.Get("RequestURI").MustString(), + } + + Expect(actualRequest).To(Equal(expectedRequest)) + }) + }) +} + +func assertRequestError(builder func() Builder, errorMessage string) { + Context("Do", func() { + It("returns an error", func() { + resp, err := builder().Do() + Expect(err).To(MatchError(ContainSubstring(errorMessage))) + Expect(resp).To(BeNil()) + }) + }) + + Context("UnmarshalInto", func() { + It("returns an error", func() { + var actualRequest testHTTPRequest + err := builder().UnmarshalInto(&actualRequest) + Expect(err).To(MatchError(ContainSubstring(errorMessage))) + + // Should be empty + Expect(actualRequest).To(Equal(testHTTPRequest{})) + }) + }) + + Context("UnmarshalJSON", func() { + It("returns an error", func() { + resp, err := builder().UnmarshalJSON() + Expect(err).To(MatchError(ContainSubstring(errorMessage))) + Expect(resp).To(BeNil()) + }) + }) +} + +func assertJSONError(builder func() Builder, errorMessage string) { + Context("Do", func() { + It("does not return an error", func() { + resp, err := builder().Do() + Expect(err).To(BeNil()) + Expect(resp).ToNot(BeNil()) + }) + }) + + Context("UnmarshalInto", func() { + It("returns an error", func() { + var actualRequest testHTTPRequest + err := builder().UnmarshalInto(&actualRequest) + Expect(err).To(MatchError(ContainSubstring(errorMessage))) + + // Should be empty + Expect(actualRequest).To(Equal(testHTTPRequest{})) + }) + }) + + Context("UnmarshalJSON", func() { + It("returns an error", func() { + resp, err := builder().UnmarshalJSON() + Expect(err).To(MatchError(ContainSubstring(errorMessage))) + Expect(resp).To(BeNil()) + }) + }) +} diff --git a/pkg/requests/requests_suite_test.go b/pkg/requests/requests_suite_test.go new file mode 100644 index 00000000..83da733a --- /dev/null +++ b/pkg/requests/requests_suite_test.go @@ -0,0 +1,96 @@ +package requests + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "log" + "net/http" + "net/http/httptest" + "testing" + + "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var ( + server *httptest.Server + serverAddr string +) + +func TestRequetsSuite(t *testing.T) { + logger.SetOutput(GinkgoWriter) + log.SetOutput(GinkgoWriter) + + RegisterFailHandler(Fail) + RunSpecs(t, "Requests Suite") +} + +var _ = BeforeSuite(func() { + // Set up a webserver that reflects requests + mux := http.NewServeMux() + mux.Handle("/json/", &testHTTPUpstream{}) + mux.HandleFunc("/string/", func(rw http.ResponseWriter, _ *http.Request) { + rw.Write([]byte("OK")) + }) + server = httptest.NewServer(mux) + serverAddr = fmt.Sprintf("http://%s", server.Listener.Addr().String()) +}) + +var _ = AfterSuite(func() { + server.Close() +}) + +// testHTTPRequest is a struct used to capture the state of a request made to +// the test server +type testHTTPRequest struct { + Method string + Header http.Header + Body []byte + RequestURI string +} + +type testHTTPUpstream struct{} + +func (t *testHTTPUpstream) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + request, err := toTestHTTPRequest(req) + if err != nil { + t.writeError(rw, err) + return + } + + data, err := json.Marshal(request) + if err != nil { + t.writeError(rw, err) + return + } + + rw.Header().Set("Content-Type", "application/json") + rw.Write(data) +} + +func (t *testHTTPUpstream) writeError(rw http.ResponseWriter, err error) { + rw.WriteHeader(500) + if err != nil { + rw.Write([]byte(err.Error())) + } +} + +func toTestHTTPRequest(req *http.Request) (testHTTPRequest, error) { + requestBody := []byte{} + if req.Body != http.NoBody { + var err error + requestBody, err = ioutil.ReadAll(req.Body) + if err != nil { + return testHTTPRequest{}, err + } + } + + return testHTTPRequest{ + Method: req.Method, + Header: req.Header, + Body: requestBody, + RequestURI: req.RequestURI, + }, nil +}