diff --git a/CHANGELOG.md b/CHANGELOG.md index e2da49ce..bca36842 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ ## Changes since v6.0.0 +- [#660](https://github.com/oauth2-proxy/oauth2-proxy/pull/660) Use builder pattern to simplify requests to external endpoints (@JoelSpeed) - [#591](https://github.com/oauth2-proxy/oauth2-proxy/pull/591) Introduce upstream package with new reverse proxy implementation (@JoelSpeed) - [#576](https://github.com/oauth2-proxy/oauth2-proxy/pull/576) Separate Cookie validation out of main options validation (@JoelSpeed) - [#656](https://github.com/oauth2-proxy/oauth2-proxy/pull/656) Split long session cookies more precisely (@NickMeves) diff --git a/pkg/requests/builder.go b/pkg/requests/builder.go new file mode 100644 index 00000000..95d88101 --- /dev/null +++ b/pkg/requests/builder.go @@ -0,0 +1,118 @@ +package requests + +import ( + "context" + "fmt" + "io" + "io/ioutil" + "net/http" +) + +// Builder allows users to construct a request and then execute the +// request via Do(). +// Do returns a Result which allows the user to get the body, +// unmarshal the body into an interface, or into a simplejson.Json. +type Builder interface { + WithContext(context.Context) Builder + WithBody(io.Reader) Builder + WithMethod(string) Builder + WithHeaders(http.Header) Builder + SetHeader(key, value string) Builder + Do() Result +} + +type builder struct { + context context.Context + method string + endpoint string + body io.Reader + header http.Header + result *result +} + +// New provides a new Builder for the given endpoint. +func New(endpoint string) Builder { + return &builder{ + endpoint: endpoint, + method: "GET", + } +} + +// WithContext adds a context to the request. +// If no context is provided, context.Background() is used instead. +func (r *builder) WithContext(ctx context.Context) Builder { + r.context = ctx + return r +} + +// WithBody adds a body to the request. +func (r *builder) WithBody(body io.Reader) Builder { + r.body = body + return r +} + +// WithMethod sets the request method. Defaults to "GET". +func (r *builder) WithMethod(method string) Builder { + r.method = method + return r +} + +// WithHeaders replaces the request header map with the given header map. +func (r *builder) WithHeaders(header http.Header) Builder { + r.header = header + return r +} + +// SetHeader sets a single header to the given value. +// May be used to add multiple headers. +func (r *builder) SetHeader(key, value string) Builder { + if r.header == nil { + r.header = make(http.Header) + } + r.header.Set(key, value) + return r +} + +// Do performs the request and returns the response in its raw form. +// If the request has already been performed, returns the previous result. +// This will not allow you to repeat a request. +func (r *builder) Do() Result { + if r.result != nil { + // Request has already been done + return r.result + } + + // Must provide a non-nil context to NewRequestWithContext + if r.context == nil { + r.context = context.Background() + } + + return r.do() +} + +// do creates the request, executes it with the default client and extracts the +// the body into the response +func (r *builder) do() Result { + req, err := http.NewRequestWithContext(r.context, r.method, r.endpoint, r.body) + if err != nil { + r.result = &result{err: fmt.Errorf("error creating request: %v", err)} + return r.result + } + req.Header = r.header + + resp, err := http.DefaultClient.Do(req) + if err != nil { + r.result = &result{err: fmt.Errorf("error performing request: %v", err)} + return r.result + } + + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + r.result = &result{err: fmt.Errorf("error reading response body: %v", err)} + return r.result + } + + r.result = &result{response: resp, body: body} + return r.result +} diff --git a/pkg/requests/builder_test.go b/pkg/requests/builder_test.go new file mode 100644 index 00000000..0c0f0d03 --- /dev/null +++ b/pkg/requests/builder_test.go @@ -0,0 +1,376 @@ +package requests + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + + "github.com/bitly/go-simplejson" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Builder suite", func() { + var b Builder + getBuilder := func() Builder { return b } + + baseHeaders := http.Header{ + "Accept-Encoding": []string{"gzip"}, + "User-Agent": []string{"Go-http-client/1.1"}, + } + + BeforeEach(func() { + // Most tests will request the server address + b = New(serverAddr + "/json/path") + }) + + Context("with a basic request", func() { + assertSuccessfulRequest(getBuilder, testHTTPRequest{ + Method: "GET", + Header: baseHeaders, + Body: []byte{}, + RequestURI: "/json/path", + }) + }) + + Context("with a context", func() { + var ctx context.Context + var cancel context.CancelFunc + + BeforeEach(func() { + ctx, cancel = context.WithCancel(context.Background()) + b = b.WithContext(ctx) + }) + + AfterEach(func() { + cancel() + }) + + assertSuccessfulRequest(getBuilder, testHTTPRequest{ + Method: "GET", + Header: baseHeaders, + Body: []byte{}, + RequestURI: "/json/path", + }) + + Context("if the context is cancelled", func() { + BeforeEach(func() { + cancel() + }) + + assertRequestError(getBuilder, "context canceled") + }) + }) + + Context("with a body", func() { + const body = "{\"some\": \"body\"}" + header := baseHeaders.Clone() + header.Set("Content-Length", fmt.Sprintf("%d", len(body))) + + BeforeEach(func() { + buf := bytes.NewBuffer([]byte(body)) + b = b.WithBody(buf) + }) + + assertSuccessfulRequest(getBuilder, testHTTPRequest{ + Method: "GET", + Header: header, + Body: []byte(body), + RequestURI: "/json/path", + }) + }) + + Context("with a method", func() { + Context("POST with a body", func() { + const body = "{\"some\": \"body\"}" + header := baseHeaders.Clone() + header.Set("Content-Length", fmt.Sprintf("%d", len(body))) + + BeforeEach(func() { + buf := bytes.NewBuffer([]byte(body)) + b = b.WithMethod("POST").WithBody(buf) + }) + + assertSuccessfulRequest(getBuilder, testHTTPRequest{ + Method: "POST", + Header: header, + Body: []byte(body), + RequestURI: "/json/path", + }) + }) + + Context("POST without a body", func() { + header := baseHeaders.Clone() + header.Set("Content-Length", "0") + + BeforeEach(func() { + b = b.WithMethod("POST") + }) + + assertSuccessfulRequest(getBuilder, testHTTPRequest{ + Method: "POST", + Header: header, + Body: []byte{}, + RequestURI: "/json/path", + }) + }) + + Context("OPTIONS", func() { + BeforeEach(func() { + b = b.WithMethod("OPTIONS") + }) + + assertSuccessfulRequest(getBuilder, testHTTPRequest{ + Method: "OPTIONS", + Header: baseHeaders, + Body: []byte{}, + RequestURI: "/json/path", + }) + }) + + Context("INVALID-\\t-METHOD", func() { + BeforeEach(func() { + b = b.WithMethod("INVALID-\t-METHOD") + }) + + assertRequestError(getBuilder, "error creating request: net/http: invalid method \"INVALID-\\t-METHOD\"") + }) + }) + + Context("with headers", func() { + Context("setting a header", func() { + header := baseHeaders.Clone() + header.Set("header", "value") + + BeforeEach(func() { + b = b.SetHeader("header", "value") + }) + + assertSuccessfulRequest(getBuilder, testHTTPRequest{ + Method: "GET", + Header: header, + Body: []byte{}, + RequestURI: "/json/path", + }) + + Context("then replacing the headers", func() { + replacementHeaders := http.Header{ + "Accept-Encoding": []string{"*"}, + "User-Agent": []string{"test-agent"}, + "Foo": []string{"bar, baz"}, + } + + BeforeEach(func() { + b = b.WithHeaders(replacementHeaders) + }) + + assertSuccessfulRequest(getBuilder, testHTTPRequest{ + Method: "GET", + Header: replacementHeaders, + Body: []byte{}, + RequestURI: "/json/path", + }) + }) + }) + + Context("replacing the header", func() { + replacementHeaders := http.Header{ + "Accept-Encoding": []string{"*"}, + "User-Agent": []string{"test-agent"}, + "Foo": []string{"bar, baz"}, + } + + BeforeEach(func() { + b = b.WithHeaders(replacementHeaders) + }) + + assertSuccessfulRequest(getBuilder, testHTTPRequest{ + Method: "GET", + Header: replacementHeaders, + Body: []byte{}, + RequestURI: "/json/path", + }) + + Context("then setting a header", func() { + header := replacementHeaders.Clone() + header.Set("User-Agent", "different-agent") + + BeforeEach(func() { + b = b.SetHeader("User-Agent", "different-agent") + }) + + assertSuccessfulRequest(getBuilder, testHTTPRequest{ + Method: "GET", + Header: header, + Body: []byte{}, + RequestURI: "/json/path", + }) + }) + }) + }) + + Context("if the request has been completed and then modified", func() { + BeforeEach(func() { + result := b.Do() + Expect(result.Error()).ToNot(HaveOccurred()) + + b.WithMethod("POST") + }) + + Context("should not redo the request", func() { + assertSuccessfulRequest(getBuilder, testHTTPRequest{ + Method: "GET", + Header: baseHeaders, + Body: []byte{}, + RequestURI: "/json/path", + }) + }) + }) + + Context("when the requested page is not found", func() { + BeforeEach(func() { + b = New(serverAddr + "/not-found") + }) + + assertJSONError(getBuilder, "404 page not found") + }) + + Context("when the requested page is not valid JSON", func() { + BeforeEach(func() { + b = New(serverAddr + "/string/path") + }) + + assertJSONError(getBuilder, "invalid character 'O' looking for beginning of value") + }) +}) + +func assertSuccessfulRequest(builder func() Builder, expectedRequest testHTTPRequest) { + Context("Do", func() { + var result Result + + BeforeEach(func() { + result = builder().Do() + Expect(result.Error()).ToNot(HaveOccurred()) + }) + + It("returns a successful status", func() { + Expect(result.StatusCode()).To(Equal(http.StatusOK)) + }) + + It("made the expected request", func() { + actualRequest := testHTTPRequest{} + Expect(json.Unmarshal(result.Body(), &actualRequest)).To(Succeed()) + + Expect(actualRequest).To(Equal(expectedRequest)) + }) + }) + + Context("UnmarshalInto", func() { + var actualRequest testHTTPRequest + + BeforeEach(func() { + Expect(builder().Do().UnmarshalInto(&actualRequest)).To(Succeed()) + }) + + It("made the expected request", func() { + Expect(actualRequest).To(Equal(expectedRequest)) + }) + }) + + Context("UnmarshalJSON", func() { + var response *simplejson.Json + + BeforeEach(func() { + var err error + response, err = builder().Do().UnmarshalJSON() + Expect(err).ToNot(HaveOccurred()) + }) + + It("made the expected reqest", func() { + header := http.Header{} + for key, value := range response.Get("Header").MustMap() { + vs, ok := value.([]interface{}) + Expect(ok).To(BeTrue()) + svs := []string{} + for _, v := range vs { + sv, ok := v.(string) + Expect(ok).To(BeTrue()) + svs = append(svs, sv) + } + header[key] = svs + } + + // Other json unmarhsallers base64 decode byte slices automatically + body, err := base64.StdEncoding.DecodeString(response.Get("Body").MustString()) + Expect(err).ToNot(HaveOccurred()) + + actualRequest := testHTTPRequest{ + Method: response.Get("Method").MustString(), + Header: header, + Body: body, + RequestURI: response.Get("RequestURI").MustString(), + } + + Expect(actualRequest).To(Equal(expectedRequest)) + }) + }) +} + +func assertRequestError(builder func() Builder, errorMessage string) { + Context("Do", func() { + It("returns an error", func() { + result := builder().Do() + Expect(result.Error()).To(MatchError(ContainSubstring(errorMessage))) + }) + }) + + Context("UnmarshalInto", func() { + It("returns an error", func() { + var actualRequest testHTTPRequest + err := builder().Do().UnmarshalInto(&actualRequest) + Expect(err).To(MatchError(ContainSubstring(errorMessage))) + + // Should be empty + Expect(actualRequest).To(Equal(testHTTPRequest{})) + }) + }) + + Context("UnmarshalJSON", func() { + It("returns an error", func() { + resp, err := builder().Do().UnmarshalJSON() + Expect(err).To(MatchError(ContainSubstring(errorMessage))) + Expect(resp).To(BeNil()) + }) + }) +} + +func assertJSONError(builder func() Builder, errorMessage string) { + Context("Do", func() { + It("does not return an error", func() { + result := builder().Do() + Expect(result.Error()).To(BeNil()) + }) + }) + + Context("UnmarshalInto", func() { + It("returns an error", func() { + var actualRequest testHTTPRequest + err := builder().Do().UnmarshalInto(&actualRequest) + Expect(err).To(MatchError(ContainSubstring(errorMessage))) + + // Should be empty + Expect(actualRequest).To(Equal(testHTTPRequest{})) + }) + }) + + Context("UnmarshalJSON", func() { + It("returns an error", func() { + resp, err := builder().Do().UnmarshalJSON() + Expect(err).To(MatchError(ContainSubstring(errorMessage))) + Expect(resp).To(BeNil()) + }) + }) +} diff --git a/pkg/requests/requests.go b/pkg/requests/requests.go deleted file mode 100644 index 64cacaa9..00000000 --- a/pkg/requests/requests.go +++ /dev/null @@ -1,74 +0,0 @@ -package requests - -import ( - "context" - "encoding/json" - "fmt" - "io/ioutil" - "net/http" - - "github.com/bitly/go-simplejson" - "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" -) - -// Request parses the request body into a simplejson.Json object -func Request(req *http.Request) (*simplejson.Json, error) { - resp, err := http.DefaultClient.Do(req) - if err != nil { - logger.Printf("%s %s %s", req.Method, req.URL, err) - return nil, err - } - body, err := ioutil.ReadAll(resp.Body) - if body != nil { - defer resp.Body.Close() - } - - logger.Printf("%d %s %s %s", resp.StatusCode, req.Method, req.URL, body) - - if err != nil { - return nil, fmt.Errorf("problem reading http request body: %w", err) - } - - if resp.StatusCode != 200 { - return nil, fmt.Errorf("got %d %s", resp.StatusCode, body) - } - - data, err := simplejson.NewJson(body) - if err != nil { - return nil, fmt.Errorf("error unmarshalling json: %w", err) - } - return data, nil -} - -// RequestJSON parses the request body into the given interface -func RequestJSON(req *http.Request, v interface{}) error { - resp, err := http.DefaultClient.Do(req) - if err != nil { - logger.Printf("%s %s %s", req.Method, req.URL, err) - return err - } - body, err := ioutil.ReadAll(resp.Body) - if body != nil { - defer resp.Body.Close() - } - - logger.Printf("%d %s %s %s", resp.StatusCode, req.Method, req.URL, body) - if err != nil { - return fmt.Errorf("error reading body from http response: %w", err) - } - if resp.StatusCode != 200 { - return fmt.Errorf("got %d %s", resp.StatusCode, body) - } - return json.Unmarshal(body, v) -} - -// RequestUnparsedResponse performs a GET and returns the raw response object -func RequestUnparsedResponse(ctx context.Context, url string, header http.Header) (resp *http.Response, err error) { - req, err := http.NewRequestWithContext(ctx, "GET", url, nil) - if err != nil { - return nil, fmt.Errorf("error performing get request: %w", err) - } - req.Header = header - - return http.DefaultClient.Do(req) -} diff --git a/pkg/requests/requests_suite_test.go b/pkg/requests/requests_suite_test.go new file mode 100644 index 00000000..83da733a --- /dev/null +++ b/pkg/requests/requests_suite_test.go @@ -0,0 +1,96 @@ +package requests + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "log" + "net/http" + "net/http/httptest" + "testing" + + "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var ( + server *httptest.Server + serverAddr string +) + +func TestRequetsSuite(t *testing.T) { + logger.SetOutput(GinkgoWriter) + log.SetOutput(GinkgoWriter) + + RegisterFailHandler(Fail) + RunSpecs(t, "Requests Suite") +} + +var _ = BeforeSuite(func() { + // Set up a webserver that reflects requests + mux := http.NewServeMux() + mux.Handle("/json/", &testHTTPUpstream{}) + mux.HandleFunc("/string/", func(rw http.ResponseWriter, _ *http.Request) { + rw.Write([]byte("OK")) + }) + server = httptest.NewServer(mux) + serverAddr = fmt.Sprintf("http://%s", server.Listener.Addr().String()) +}) + +var _ = AfterSuite(func() { + server.Close() +}) + +// testHTTPRequest is a struct used to capture the state of a request made to +// the test server +type testHTTPRequest struct { + Method string + Header http.Header + Body []byte + RequestURI string +} + +type testHTTPUpstream struct{} + +func (t *testHTTPUpstream) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + request, err := toTestHTTPRequest(req) + if err != nil { + t.writeError(rw, err) + return + } + + data, err := json.Marshal(request) + if err != nil { + t.writeError(rw, err) + return + } + + rw.Header().Set("Content-Type", "application/json") + rw.Write(data) +} + +func (t *testHTTPUpstream) writeError(rw http.ResponseWriter, err error) { + rw.WriteHeader(500) + if err != nil { + rw.Write([]byte(err.Error())) + } +} + +func toTestHTTPRequest(req *http.Request) (testHTTPRequest, error) { + requestBody := []byte{} + if req.Body != http.NoBody { + var err error + requestBody, err = ioutil.ReadAll(req.Body) + if err != nil { + return testHTTPRequest{}, err + } + } + + return testHTTPRequest{ + Method: req.Method, + Header: req.Header, + Body: requestBody, + RequestURI: req.RequestURI, + }, nil +} diff --git a/pkg/requests/requests_test.go b/pkg/requests/requests_test.go deleted file mode 100644 index 0c3e4152..00000000 --- a/pkg/requests/requests_test.go +++ /dev/null @@ -1,136 +0,0 @@ -package requests - -import ( - "context" - "io/ioutil" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/bitly/go-simplejson" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func testBackend(t *testing.T, responseCode int, payload string) *httptest.Server { - return httptest.NewServer(http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(responseCode) - _, err := w.Write([]byte(payload)) - require.NoError(t, err) - })) -} - -func TestRequest(t *testing.T) { - backend := testBackend(t, 200, "{\"foo\": \"bar\"}") - defer backend.Close() - - req, _ := http.NewRequest("GET", backend.URL, nil) - response, err := Request(req) - assert.Equal(t, nil, err) - result, err := response.Get("foo").String() - assert.Equal(t, nil, err) - assert.Equal(t, "bar", result) -} - -func TestRequestFailure(t *testing.T) { - // Create a backend to generate a test URL, then close it to cause a - // connection error. - backend := testBackend(t, 200, "{\"foo\": \"bar\"}") - backend.Close() - - req, err := http.NewRequest("GET", backend.URL, nil) - assert.Equal(t, nil, err) - resp, err := Request(req) - assert.Equal(t, (*simplejson.Json)(nil), resp) - assert.NotEqual(t, nil, err) - if !strings.Contains(err.Error(), "refused") { - t.Error("expected error when a connection fails: ", err) - } -} - -func TestHttpErrorCode(t *testing.T) { - backend := testBackend(t, 404, "{\"foo\": \"bar\"}") - defer backend.Close() - - req, err := http.NewRequest("GET", backend.URL, nil) - assert.Equal(t, nil, err) - resp, err := Request(req) - assert.Equal(t, (*simplejson.Json)(nil), resp) - assert.NotEqual(t, nil, err) -} - -func TestJsonParsingError(t *testing.T) { - backend := testBackend(t, 200, "not well-formed JSON") - defer backend.Close() - - req, err := http.NewRequest("GET", backend.URL, nil) - assert.Equal(t, nil, err) - resp, err := Request(req) - assert.Equal(t, (*simplejson.Json)(nil), resp) - assert.NotEqual(t, nil, err) -} - -// Parsing a URL practically never fails, so we won't cover that test case. -func TestRequestUnparsedResponseUsingAccessTokenParameter(t *testing.T) { - backend := httptest.NewServer(http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - token := r.FormValue("access_token") - if r.URL.Path == "/" && token == "my_token" { - w.WriteHeader(200) - _, err := w.Write([]byte("some payload")) - require.NoError(t, err) - } else { - w.WriteHeader(403) - } - })) - defer backend.Close() - - response, err := RequestUnparsedResponse( - context.Background(), backend.URL+"?access_token=my_token", nil) - assert.Equal(t, nil, err) - defer response.Body.Close() - - assert.Equal(t, 200, response.StatusCode) - body, err := ioutil.ReadAll(response.Body) - assert.Equal(t, nil, err) - assert.Equal(t, "some payload", string(body)) -} - -func TestRequestUnparsedResponseUsingAccessTokenParameterFailedResponse(t *testing.T) { - backend := testBackend(t, 200, "some payload") - // Close the backend now to force a request failure. - backend.Close() - - response, err := RequestUnparsedResponse( - context.Background(), backend.URL+"?access_token=my_token", nil) - assert.NotEqual(t, nil, err) - assert.Equal(t, (*http.Response)(nil), response) -} - -func TestRequestUnparsedResponseUsingHeaders(t *testing.T) { - backend := httptest.NewServer(http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/" && r.Header["Auth"][0] == "my_token" { - w.WriteHeader(200) - _, err := w.Write([]byte("some payload")) - require.NoError(t, err) - } else { - w.WriteHeader(403) - } - })) - defer backend.Close() - - headers := make(http.Header) - headers.Set("Auth", "my_token") - response, err := RequestUnparsedResponse(context.Background(), backend.URL, headers) - assert.Equal(t, nil, err) - defer response.Body.Close() - - assert.Equal(t, 200, response.StatusCode) - body, err := ioutil.ReadAll(response.Body) - assert.Equal(t, nil, err) - - assert.Equal(t, "some payload", string(body)) -} diff --git a/pkg/requests/result.go b/pkg/requests/result.go new file mode 100644 index 00000000..2574aad5 --- /dev/null +++ b/pkg/requests/result.go @@ -0,0 +1,98 @@ +package requests + +import ( + "encoding/json" + "fmt" + "net/http" + + "github.com/bitly/go-simplejson" +) + +// Result is the result of a request created by a Builder +type Result interface { + Error() error + StatusCode() int + Headers() http.Header + Body() []byte + UnmarshalInto(interface{}) error + UnmarshalJSON() (*simplejson.Json, error) +} + +type result struct { + err error + response *http.Response + body []byte +} + +// Error returns an error from the result if present +func (r *result) Error() error { + return r.err +} + +// StatusCode returns the response's status code +func (r *result) StatusCode() int { + if r.response != nil { + return r.response.StatusCode + } + return 0 +} + +// Headers returns the response's headers +func (r *result) Headers() http.Header { + if r.response != nil { + return r.response.Header + } + return nil +} + +// Body returns the response's body +func (r *result) Body() []byte { + return r.body +} + +// UnmarshalInto attempts to unmarshal the response into the the given interface. +// The response body is assumed to be JSON. +// The response must have a 200 status otherwise an error will be returned. +func (r *result) UnmarshalInto(into interface{}) error { + body, err := r.getBodyForUnmarshal() + if err != nil { + return err + } + + if err := json.Unmarshal(body, into); err != nil { + return fmt.Errorf("error unmarshalling body: %v", err) + } + + return nil +} + +// UnmarshalJSON performs the request and attempts to unmarshal the response into a +// simplejson.Json. The response body is assume to be JSON. +// The response must have a 200 status otherwise an error will be returned. +func (r *result) UnmarshalJSON() (*simplejson.Json, error) { + body, err := r.getBodyForUnmarshal() + if err != nil { + return nil, err + } + + data, err := simplejson.NewJson(body) + if err != nil { + return nil, fmt.Errorf("error reading json: %v", err) + } + return data, nil +} + +// getBodyForUnmarshal returns the body if there wasn't an error and the status +// code was 200. +func (r *result) getBodyForUnmarshal() ([]byte, error) { + if r.Error() != nil { + return nil, r.Error() + } + + // Only unmarshal body if the response was successful + if r.StatusCode() != http.StatusOK { + return nil, fmt.Errorf("unexpected status \"%d\": %s", r.StatusCode(), r.Body()) + } + + return r.Body(), nil +} diff --git a/pkg/requests/result_test.go b/pkg/requests/result_test.go new file mode 100644 index 00000000..2d5c95ca --- /dev/null +++ b/pkg/requests/result_test.go @@ -0,0 +1,326 @@ +package requests + +import ( + "errors" + "net/http" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +var _ = Describe("Result suite", func() { + Context("with a result", func() { + type resultTableInput struct { + result Result + expectedError error + expectedStatusCode int + expectedHeaders http.Header + expectedBody []byte + } + + DescribeTable("accessors should return expected results", + func(in resultTableInput) { + if in.expectedError != nil { + Expect(in.result.Error()).To(MatchError(in.expectedError)) + } else { + Expect(in.result.Error()).To(BeNil()) + } + + Expect(in.result.StatusCode()).To(Equal(in.expectedStatusCode)) + Expect(in.result.Headers()).To(Equal(in.expectedHeaders)) + Expect(in.result.Body()).To(Equal(in.expectedBody)) + }, + Entry("with an empty result", resultTableInput{ + result: &result{}, + expectedError: nil, + expectedStatusCode: 0, + expectedHeaders: nil, + expectedBody: nil, + }), + Entry("with an error", resultTableInput{ + result: &result{ + err: errors.New("error"), + }, + expectedError: errors.New("error"), + expectedStatusCode: 0, + expectedHeaders: nil, + expectedBody: nil, + }), + Entry("with a response with no headers", resultTableInput{ + result: &result{ + response: &http.Response{ + StatusCode: http.StatusTeapot, + }, + }, + expectedError: nil, + expectedStatusCode: http.StatusTeapot, + expectedHeaders: nil, + expectedBody: nil, + }), + Entry("with a response with no status code", resultTableInput{ + result: &result{ + response: &http.Response{ + Header: http.Header{ + "foo": []string{"bar"}, + }, + }, + }, + expectedError: nil, + expectedStatusCode: 0, + expectedHeaders: http.Header{ + "foo": []string{"bar"}, + }, + expectedBody: nil, + }), + Entry("with a response with a body", resultTableInput{ + result: &result{ + body: []byte("some body"), + }, + expectedError: nil, + expectedStatusCode: 0, + expectedHeaders: nil, + expectedBody: []byte("some body"), + }), + Entry("with all fields", resultTableInput{ + result: &result{ + err: errors.New("some error"), + response: &http.Response{ + StatusCode: http.StatusFound, + Header: http.Header{ + "header": []string{"value"}, + }, + }, + body: []byte("a body"), + }, + expectedError: errors.New("some error"), + expectedStatusCode: http.StatusFound, + expectedHeaders: http.Header{ + "header": []string{"value"}, + }, + expectedBody: []byte("a body"), + }), + ) + }) + + Context("UnmarshalInto", func() { + type testStruct struct { + A string `json:"a"` + B int `json:"b"` + } + + type unmarshalIntoTableInput struct { + result Result + expectedErr error + expectedOutput *testStruct + } + + DescribeTable("with a result", + func(in unmarshalIntoTableInput) { + input := &testStruct{} + err := in.result.UnmarshalInto(input) + if in.expectedErr != nil { + Expect(err).To(MatchError(in.expectedErr)) + } else { + Expect(err).ToNot(HaveOccurred()) + } + Expect(input).To(Equal(in.expectedOutput)) + }, + Entry("with an error", unmarshalIntoTableInput{ + result: &result{ + err: errors.New("got an error"), + response: &http.Response{ + StatusCode: http.StatusOK, + }, + body: []byte("{\"a\": \"foo\"}"), + }, + expectedErr: errors.New("got an error"), + expectedOutput: &testStruct{}, + }), + Entry("with a 409 status code", unmarshalIntoTableInput{ + result: &result{ + err: nil, + response: &http.Response{ + StatusCode: http.StatusConflict, + }, + body: []byte("{\"a\": \"foo\"}"), + }, + expectedErr: errors.New("unexpected status \"409\": {\"a\": \"foo\"}"), + expectedOutput: &testStruct{}, + }), + Entry("when the response has a valid json response", unmarshalIntoTableInput{ + result: &result{ + err: nil, + response: &http.Response{ + StatusCode: http.StatusOK, + }, + body: []byte("{\"a\": \"foo\", \"b\": 1}"), + }, + expectedErr: nil, + expectedOutput: &testStruct{A: "foo", B: 1}, + }), + Entry("when the response body is empty", unmarshalIntoTableInput{ + result: &result{ + err: nil, + response: &http.Response{ + StatusCode: http.StatusOK, + }, + body: []byte(""), + }, + expectedErr: errors.New("error unmarshalling body: unexpected end of JSON input"), + expectedOutput: &testStruct{}, + }), + Entry("when the response body is not json", unmarshalIntoTableInput{ + result: &result{ + err: nil, + response: &http.Response{ + StatusCode: http.StatusOK, + }, + body: []byte("not json"), + }, + expectedErr: errors.New("error unmarshalling body: invalid character 'o' in literal null (expecting 'u')"), + expectedOutput: &testStruct{}, + }), + ) + }) + + Context("UnmarshalJSON", func() { + type testStruct struct { + A string `json:"a"` + B int `json:"b"` + } + + type unmarshalJSONTableInput struct { + result Result + expectedErr error + expectedOutput *testStruct + } + + DescribeTable("with a result", + func(in unmarshalJSONTableInput) { + j, err := in.result.UnmarshalJSON() + if in.expectedErr != nil { + Expect(err).To(MatchError(in.expectedErr)) + Expect(j).To(BeNil()) + return + } + + // No error so j should not be nil + Expect(err).ToNot(HaveOccurred()) + + input := &testStruct{ + A: j.Get("a").MustString(), + B: j.Get("b").MustInt(), + } + Expect(input).To(Equal(in.expectedOutput)) + }, + Entry("with an error", unmarshalJSONTableInput{ + result: &result{ + err: errors.New("got an error"), + response: &http.Response{ + StatusCode: http.StatusOK, + }, + body: []byte("{\"a\": \"foo\"}"), + }, + expectedErr: errors.New("got an error"), + expectedOutput: &testStruct{}, + }), + Entry("with a 409 status code", unmarshalJSONTableInput{ + result: &result{ + err: nil, + response: &http.Response{ + StatusCode: http.StatusConflict, + }, + body: []byte("{\"a\": \"foo\"}"), + }, + expectedErr: errors.New("unexpected status \"409\": {\"a\": \"foo\"}"), + expectedOutput: &testStruct{}, + }), + Entry("when the response has a valid json response", unmarshalJSONTableInput{ + result: &result{ + err: nil, + response: &http.Response{ + StatusCode: http.StatusOK, + }, + body: []byte("{\"a\": \"foo\", \"b\": 1}"), + }, + expectedErr: nil, + expectedOutput: &testStruct{A: "foo", B: 1}, + }), + Entry("when the response body is empty", unmarshalJSONTableInput{ + result: &result{ + err: nil, + response: &http.Response{ + StatusCode: http.StatusOK, + }, + body: []byte(""), + }, + expectedErr: errors.New("error reading json: EOF"), + expectedOutput: &testStruct{}, + }), + Entry("when the response body is not json", unmarshalJSONTableInput{ + result: &result{ + err: nil, + response: &http.Response{ + StatusCode: http.StatusOK, + }, + body: []byte("not json"), + }, + expectedErr: errors.New("error reading json: invalid character 'o' in literal null (expecting 'u')"), + expectedOutput: &testStruct{}, + }), + ) + }) + + Context("getBodyForUnmarshal", func() { + type getBodyForUnmarshalTableInput struct { + result *result + expectedErr error + expectedBody []byte + } + + DescribeTable("when getting the body", func(in getBodyForUnmarshalTableInput) { + body, err := in.result.getBodyForUnmarshal() + if in.expectedErr != nil { + Expect(err).To(MatchError(in.expectedErr)) + } else { + Expect(err).ToNot(HaveOccurred()) + } + Expect(body).To(Equal(in.expectedBody)) + }, + Entry("when the result has an error", getBodyForUnmarshalTableInput{ + result: &result{ + err: errors.New("got an error"), + response: &http.Response{ + StatusCode: http.StatusOK, + }, + body: []byte("body"), + }, + expectedErr: errors.New("got an error"), + expectedBody: nil, + }), + Entry("when the response has a 409 status code", getBodyForUnmarshalTableInput{ + result: &result{ + err: nil, + response: &http.Response{ + StatusCode: http.StatusConflict, + }, + body: []byte("body"), + }, + expectedErr: errors.New("unexpected status \"409\": body"), + expectedBody: nil, + }), + Entry("when the response has a 200 status code", getBodyForUnmarshalTableInput{ + result: &result{ + err: nil, + response: &http.Response{ + StatusCode: http.StatusOK, + }, + body: []byte("body"), + }, + expectedErr: nil, + expectedBody: []byte("body"), + }), + ) + }) +}) diff --git a/pkg/validation/options.go b/pkg/validation/options.go index 50717b60..301c5e90 100644 --- a/pkg/validation/options.go +++ b/pkg/validation/options.go @@ -83,34 +83,34 @@ func Validate(o *options.Options) error { logger.Printf("Performing OIDC Discovery...") - if req, err := http.NewRequest("GET", strings.TrimSuffix(o.OIDCIssuerURL, "/")+"/.well-known/openid-configuration", nil); err == nil { - if body, err := requests.Request(req); err == nil { - - // Prefer manually configured URLs. It's a bit unclear - // why you'd be doing discovery and also providing the URLs - // explicitly though... - if o.LoginURL == "" { - o.LoginURL = body.Get("authorization_endpoint").MustString() - } - - if o.RedeemURL == "" { - o.RedeemURL = body.Get("token_endpoint").MustString() - } - - if o.OIDCJwksURL == "" { - o.OIDCJwksURL = body.Get("jwks_uri").MustString() - } - - if o.ProfileURL == "" { - o.ProfileURL = body.Get("userinfo_endpoint").MustString() - } - - o.SkipOIDCDiscovery = true - } else { - logger.Printf("error: failed to discover OIDC configuration: %v", err) - } + requestURL := strings.TrimSuffix(o.OIDCIssuerURL, "/") + "/.well-known/openid-configuration" + body, err := requests.New(requestURL). + WithContext(ctx). + Do(). + UnmarshalJSON() + if err != nil { + logger.Printf("error: failed to discover OIDC configuration: %v", err) } else { - logger.Printf("error: failed parsing OIDC discovery URL: %v", err) + // Prefer manually configured URLs. It's a bit unclear + // why you'd be doing discovery and also providing the URLs + // explicitly though... + if o.LoginURL == "" { + o.LoginURL = body.Get("authorization_endpoint").MustString() + } + + if o.RedeemURL == "" { + o.RedeemURL = body.Get("token_endpoint").MustString() + } + + if o.OIDCJwksURL == "" { + o.OIDCJwksURL = body.Get("jwks_uri").MustString() + } + + if o.ProfileURL == "" { + o.ProfileURL = body.Get("userinfo_endpoint").MustString() + } + + o.SkipOIDCDiscovery = true } } @@ -385,10 +385,10 @@ func newVerifierFromJwtIssuer(jwtIssuer jwtIssuer) (*oidc.IDTokenVerifier, error if err != nil { // Try as JWKS URI jwksURI := strings.TrimSuffix(jwtIssuer.issuerURI, "/") + "/.well-known/jwks.json" - _, err := http.NewRequest("GET", jwksURI, nil) - if err != nil { + if err := requests.New(jwksURI).Do().Error(); err != nil { return nil, err } + verifier = oidc.NewVerifier(jwtIssuer.issuerURI, oidc.NewRemoteKeySet(context.Background(), jwksURI), config) } else { verifier = provider.Verifier(config) diff --git a/providers/azure.go b/providers/azure.go index aea1b0e5..b38c1cc7 100644 --- a/providers/azure.go +++ b/providers/azure.go @@ -3,10 +3,8 @@ package providers import ( "bytes" "context" - "encoding/json" "errors" "fmt" - "io/ioutil" "net/http" "net/url" "time" @@ -91,39 +89,22 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s params.Add("resource", p.ProtectedResource.String()) } - var req *http.Request - req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) - if err != nil { - return - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - var resp *http.Response - resp, err = http.DefaultClient.Do(req) - if err != nil { - return nil, err - } - var body []byte - body, err = ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return - } - - if resp.StatusCode != 200 { - err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) - return - } - var jsonResponse struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` ExpiresOn int64 `json:"expires_on,string"` IDToken string `json:"id_token"` } - err = json.Unmarshal(body, &jsonResponse) + + err = requests.New(p.RedeemURL.String()). + WithContext(ctx). + WithMethod("POST"). + WithBody(bytes.NewBufferString(params.Encode())). + SetHeader("Content-Type", "application/x-www-form-urlencoded"). + Do(). + UnmarshalInto(&jsonResponse) if err != nil { - return + return nil, err } created := time.Now() @@ -169,26 +150,22 @@ func (p *AzureProvider) GetEmailAddress(ctx context.Context, s *sessions.Session if s.AccessToken == "" { return "", errors.New("missing access token") } - req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String(), nil) - if err != nil { - return "", err - } - req.Header = getAzureHeader(s.AccessToken) - - json, err := requests.Request(req) + json, err := requests.New(p.ProfileURL.String()). + WithContext(ctx). + WithHeaders(getAzureHeader(s.AccessToken)). + Do(). + UnmarshalJSON() if err != nil { return "", err } email, err = getEmailFromJSON(json) - if err == nil && email != "" { return email, err } email, err = json.Get("userPrincipalName").String() - if err != nil { logger.Printf("failed making request %s", err) return "", err diff --git a/providers/bitbucket.go b/providers/bitbucket.go index 2bb876cb..ffc52c79 100644 --- a/providers/bitbucket.go +++ b/providers/bitbucket.go @@ -2,7 +2,6 @@ package providers import ( "context" - "net/http" "net/url" "strings" @@ -85,15 +84,14 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses FullName string `json:"full_name"` } } - req, err := http.NewRequestWithContext(ctx, "GET", - p.ValidateURL.String()+"?access_token="+s.AccessToken, nil) + + requestURL := p.ValidateURL.String() + "?access_token=" + s.AccessToken + err := requests.New(requestURL). + WithContext(ctx). + Do(). + UnmarshalInto(&emails) if err != nil { - logger.Printf("failed building request %s", err) - return "", err - } - err = requests.RequestJSON(req, &emails) - if err != nil { - logger.Printf("failed making request %s", err) + logger.Printf("failed making request: %v", err) return "", err } @@ -101,15 +99,15 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses teamURL := &url.URL{} *teamURL = *p.ValidateURL teamURL.Path = "/2.0/teams" - req, err = http.NewRequestWithContext(ctx, "GET", - teamURL.String()+"?role=member&access_token="+s.AccessToken, nil) + + requestURL := teamURL.String() + "?role=member&access_token=" + s.AccessToken + + err := requests.New(requestURL). + WithContext(ctx). + Do(). + UnmarshalInto(&teams) if err != nil { - logger.Printf("failed building request %s", err) - return "", err - } - err = requests.RequestJSON(req, &teams) - if err != nil { - logger.Printf("failed requesting teams membership %s", err) + logger.Printf("failed requesting teams membership: %v", err) return "", err } var found = false @@ -129,20 +127,20 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses repositoriesURL := &url.URL{} *repositoriesURL = *p.ValidateURL repositoriesURL.Path = "/2.0/repositories/" + strings.Split(p.Repository, "/")[0] - req, err = http.NewRequestWithContext(ctx, "GET", - repositoriesURL.String()+"?role=contributor"+ - "&q=full_name="+url.QueryEscape("\""+p.Repository+"\"")+ - "&access_token="+s.AccessToken, - nil) + + requestURL := repositoriesURL.String() + "?role=contributor" + + "&q=full_name=" + url.QueryEscape("\""+p.Repository+"\"") + + "&access_token=" + s.AccessToken + + err := requests.New(requestURL). + WithContext(ctx). + Do(). + UnmarshalInto(&repositories) if err != nil { - logger.Printf("failed building request %s", err) - return "", err - } - err = requests.RequestJSON(req, &repositories) - if err != nil { - logger.Printf("failed checking repository access %s", err) + logger.Printf("failed checking repository access: %v", err) return "", err } + var found = false for _, repository := range repositories.Values { if p.Repository == repository.FullName { diff --git a/providers/digitalocean.go b/providers/digitalocean.go index 25d37af9..27ac60d0 100644 --- a/providers/digitalocean.go +++ b/providers/digitalocean.go @@ -60,13 +60,12 @@ func (p *DigitalOceanProvider) GetEmailAddress(ctx context.Context, s *sessions. if s.AccessToken == "" { return "", errors.New("missing access token") } - req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String(), nil) - if err != nil { - return "", err - } - req.Header = getDigitalOceanHeader(s.AccessToken) - json, err := requests.Request(req) + json, err := requests.New(p.ProfileURL.String()). + WithContext(ctx). + WithHeaders(getDigitalOceanHeader(s.AccessToken)). + Do(). + UnmarshalJSON() if err != nil { return "", err } diff --git a/providers/facebook.go b/providers/facebook.go index 0f9cc624..d3d123f2 100644 --- a/providers/facebook.go +++ b/providers/facebook.go @@ -62,20 +62,22 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess if s.AccessToken == "" { return "", errors.New("missing access token") } - req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String()+"?fields=name,email", nil) - if err != nil { - return "", err - } - req.Header = getFacebookHeader(s.AccessToken) type result struct { Email string } var r result - err = requests.RequestJSON(req, &r) + + requestURL := p.ProfileURL.String() + "?fields=name,email" + err := requests.New(requestURL). + WithContext(ctx). + WithHeaders(getFacebookHeader(s.AccessToken)). + Do(). + UnmarshalInto(&r) if err != nil { return "", err } + if r.Email == "" { return "", errors.New("no email") } diff --git a/providers/github.go b/providers/github.go index 6d3f8b02..014ae3cb 100644 --- a/providers/github.go +++ b/providers/github.go @@ -2,10 +2,8 @@ package providers import ( "context" - "encoding/json" "errors" "fmt" - "io/ioutil" "net/http" "net/url" "path" @@ -15,6 +13,7 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" + "github.com/oauth2-proxy/oauth2-proxy/pkg/requests" ) // GitHubProvider represents an GitHub based Identity Provider @@ -111,27 +110,17 @@ func (p *GitHubProvider) hasOrg(ctx context.Context, accessToken string) (bool, Path: path.Join(p.ValidateURL.Path, "/user/orgs"), RawQuery: params.Encode(), } - req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) - req.Header = getGitHubHeader(accessToken) - resp, err := http.DefaultClient.Do(req) - if err != nil { - return false, err - } - - body, err := ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return false, err - } - if resp.StatusCode != 200 { - return false, fmt.Errorf( - "got %d from %q %s", resp.StatusCode, endpoint.String(), body) - } var op orgsPage - if err := json.Unmarshal(body, &op); err != nil { + err := requests.New(endpoint.String()). + WithContext(ctx). + WithHeaders(getGitHubHeader(accessToken)). + Do(). + UnmarshalInto(&op) + if err != nil { return false, err } + if len(op) == 0 { break } @@ -187,11 +176,15 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string) RawQuery: params.Encode(), } - req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) - req.Header = getGitHubHeader(accessToken) - resp, err := http.DefaultClient.Do(req) - if err != nil { - return false, err + // bodyclose cannot detect that the body is being closed later in requests.Into, + // so have to skip the linting for the next line. + // nolint:bodyclose + result := requests.New(endpoint.String()). + WithContext(ctx). + WithHeaders(getGitHubHeader(accessToken)). + Do() + if result.Error() != nil { + return false, result.Error() } if last == 0 { @@ -207,7 +200,7 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string) // link header at last page (doesn't exist last info) // <https://api.github.com/user/teams?page=3&per_page=10>; rel="prev", <https://api.github.com/user/teams?page=1&per_page=10>; rel="first" - link := resp.Header.Get("Link") + link := result.Headers().Get("Link") rep1 := regexp.MustCompile(`(?s).*\<https://api.github.com/user/teams\?page=(.)&per_page=[0-9]+\>; rel="last".*`) i, converr := strconv.Atoi(rep1.ReplaceAllString(link, "$1")) @@ -217,21 +210,9 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string) } } - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - resp.Body.Close() - return false, err - } - resp.Body.Close() - - if resp.StatusCode != 200 { - return false, fmt.Errorf( - "got %d from %q %s", resp.StatusCode, endpoint.String(), body) - } - var tp teamsPage - if err := json.Unmarshal(body, &tp); err != nil { - return false, fmt.Errorf("%s unmarshaling %s", err, body) + if err := result.UnmarshalInto(&tp); err != nil { + return false, err } if len(tp) == 0 { break @@ -297,25 +278,13 @@ func (p *GitHubProvider) hasRepo(ctx context.Context, accessToken string) (bool, Path: path.Join(p.ValidateURL.Path, "/repo/", p.Repo), } - req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) - req.Header = getGitHubHeader(accessToken) - resp, err := http.DefaultClient.Do(req) - if err != nil { - return false, err - } - - body, err := ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return false, err - } - if resp.StatusCode != 200 { - return false, fmt.Errorf( - "got %d from %q %s", resp.StatusCode, endpoint.String(), body) - } - var repo repository - if err := json.Unmarshal(body, &repo); err != nil { + err := requests.New(endpoint.String()). + WithContext(ctx). + WithHeaders(getGitHubHeader(accessToken)). + Do(). + UnmarshalInto(&repo) + if err != nil { return false, err } @@ -337,26 +306,15 @@ func (p *GitHubProvider) hasUser(ctx context.Context, accessToken string) (bool, Host: p.ValidateURL.Host, Path: path.Join(p.ValidateURL.Path, "/user"), } - req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) - req.Header = getGitHubHeader(accessToken) - resp, err := http.DefaultClient.Do(req) - if err != nil { - return false, err - } - defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) + err := requests.New(endpoint.String()). + WithContext(ctx). + WithHeaders(getGitHubHeader(accessToken)). + Do(). + UnmarshalInto(&user) if err != nil { return false, err } - if resp.StatusCode != 200 { - return false, fmt.Errorf("got %d from %q %s", - resp.StatusCode, stripToken(endpoint.String()), body) - } - - if err := json.Unmarshal(body, &user); err != nil { - return false, err - } if p.isVerifiedUser(user.Login) { return true, nil @@ -372,24 +330,20 @@ func (p *GitHubProvider) isCollaborator(ctx context.Context, username, accessTok Host: p.ValidateURL.Host, Path: path.Join(p.ValidateURL.Path, "/repos/", p.Repo, "/collaborators/", username), } - req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) - req.Header = getGitHubHeader(accessToken) - resp, err := http.DefaultClient.Do(req) - if err != nil { - return false, err - } - body, err := ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return false, err + result := requests.New(endpoint.String()). + WithContext(ctx). + WithHeaders(getGitHubHeader(accessToken)). + Do() + if result.Error() != nil { + return false, result.Error() } - if resp.StatusCode != 204 { + if result.StatusCode() != 204 { return false, fmt.Errorf("got %d from %q %s", - resp.StatusCode, endpoint.String(), body) + result.StatusCode(), endpoint.String(), result.Body()) } - logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) + logger.Printf("got %d from %q %s", result.StatusCode(), endpoint.String(), result.Body()) return true, nil } @@ -440,28 +394,14 @@ func (p *GitHubProvider) GetEmailAddress(ctx context.Context, s *sessions.Sessio Host: p.ValidateURL.Host, Path: path.Join(p.ValidateURL.Path, "/user/emails"), } - req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) - req.Header = getGitHubHeader(s.AccessToken) - resp, err := http.DefaultClient.Do(req) + err := requests.New(endpoint.String()). + WithContext(ctx). + WithHeaders(getGitHubHeader(s.AccessToken)). + Do(). + UnmarshalInto(&emails) if err != nil { return "", err } - body, err := ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return "", err - } - - if resp.StatusCode != 200 { - return "", fmt.Errorf("got %d from %q %s", - resp.StatusCode, endpoint.String(), body) - } - - logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) - - if err := json.Unmarshal(body, &emails); err != nil { - return "", fmt.Errorf("%s unmarshaling %s", err, body) - } returnEmail := "" for _, email := range emails { @@ -489,34 +429,15 @@ func (p *GitHubProvider) GetUserName(ctx context.Context, s *sessions.SessionSta Path: path.Join(p.ValidateURL.Path, "/user"), } - req, err := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) - if err != nil { - return "", fmt.Errorf("could not create new GET request: %v", err) - } - - req.Header = getGitHubHeader(s.AccessToken) - resp, err := http.DefaultClient.Do(req) + err := requests.New(endpoint.String()). + WithContext(ctx). + WithHeaders(getGitHubHeader(s.AccessToken)). + Do(). + UnmarshalInto(&user) if err != nil { return "", err } - body, err := ioutil.ReadAll(resp.Body) - defer resp.Body.Close() - if err != nil { - return "", err - } - - if resp.StatusCode != 200 { - return "", fmt.Errorf("got %d from %q %s", - resp.StatusCode, endpoint.String(), body) - } - - logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) - - if err := json.Unmarshal(body, &user); err != nil { - return "", fmt.Errorf("%s unmarshaling %s", err, body) - } - // Now that we have the username we can check collaborator status if !p.isVerifiedUser(user.Login) && p.Org == "" && p.Repo != "" && p.Token != "" { if ok, err := p.isCollaborator(ctx, user.Login, p.Token); err != nil || !ok { diff --git a/providers/gitlab.go b/providers/gitlab.go index 8d959781..8c1e1534 100644 --- a/providers/gitlab.go +++ b/providers/gitlab.go @@ -2,15 +2,13 @@ package providers import ( "context" - "encoding/json" "fmt" - "io/ioutil" - "net/http" "strings" "time" oidc "github.com/coreos/go-oidc" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" + "github.com/oauth2-proxy/oauth2-proxy/pkg/requests" "golang.org/x/oauth2" ) @@ -131,31 +129,14 @@ func (p *GitLabProvider) getUserInfo(ctx context.Context, s *sessions.SessionSta userInfoURL := *p.LoginURL userInfoURL.Path = "/oauth/userinfo" - req, err := http.NewRequestWithContext(ctx, "GET", userInfoURL.String(), nil) - if err != nil { - return nil, fmt.Errorf("failed to create user info request: %v", err) - } - req.Header.Set("Authorization", "Bearer "+s.AccessToken) - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to perform user info request: %v", err) - } - var body []byte - body, err = ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return nil, fmt.Errorf("failed to read user info response: %v", err) - } - - if resp.StatusCode != 200 { - return nil, fmt.Errorf("got %d during user info request: %s", resp.StatusCode, body) - } - var userInfo gitlabUserInfo - err = json.Unmarshal(body, &userInfo) + err := requests.New(userInfoURL.String()). + WithContext(ctx). + SetHeader("Authorization", "Bearer "+s.AccessToken). + Do(). + UnmarshalInto(&userInfo) if err != nil { - return nil, fmt.Errorf("failed to parse user info: %v", err) + return nil, fmt.Errorf("error getting user info: %v", err) } return &userInfo, nil diff --git a/providers/google.go b/providers/google.go index 5aeb6e2d..af2eebf3 100644 --- a/providers/google.go +++ b/providers/google.go @@ -9,13 +9,13 @@ import ( "fmt" "io" "io/ioutil" - "net/http" "net/url" "strings" "time" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" + "github.com/oauth2-proxy/oauth2-proxy/pkg/requests" "golang.org/x/oauth2/google" admin "google.golang.org/api/admin/directory/v1" "google.golang.org/api/googleapi" @@ -116,28 +116,6 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) ( params.Add("client_secret", clientSecret) params.Add("code", code) params.Add("grant_type", "authorization_code") - var req *http.Request - req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) - if err != nil { - return - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return - } - var body []byte - body, err = ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return - } - - if resp.StatusCode != 200 { - err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) - return - } var jsonResponse struct { AccessToken string `json:"access_token"` @@ -145,10 +123,18 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) ( ExpiresIn int64 `json:"expires_in"` IDToken string `json:"id_token"` } - err = json.Unmarshal(body, &jsonResponse) + + err = requests.New(p.RedeemURL.String()). + WithContext(ctx). + WithMethod("POST"). + WithBody(bytes.NewBufferString(params.Encode())). + SetHeader("Content-Type", "application/x-www-form-urlencoded"). + Do(). + UnmarshalInto(&jsonResponse) if err != nil { - return + return nil, err } + c, err := claimsFromIDToken(jsonResponse.IDToken) if err != nil { return @@ -283,38 +269,24 @@ func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken st params.Add("client_secret", clientSecret) params.Add("refresh_token", refreshToken) params.Add("grant_type", "refresh_token") - var req *http.Request - req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) - if err != nil { - return - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return - } - var body []byte - body, err = ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return - } - - if resp.StatusCode != 200 { - err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) - return - } var data struct { AccessToken string `json:"access_token"` ExpiresIn int64 `json:"expires_in"` IDToken string `json:"id_token"` } - err = json.Unmarshal(body, &data) + + err = requests.New(p.RedeemURL.String()). + WithContext(ctx). + WithMethod("POST"). + WithBody(bytes.NewBufferString(params.Encode())). + SetHeader("Content-Type", "application/x-www-form-urlencoded"). + Do(). + UnmarshalInto(&data) if err != nil { - return + return "", "", 0, err } + token = data.AccessToken idToken = data.IDToken expires = time.Duration(data.ExpiresIn) * time.Second diff --git a/providers/internal_util.go b/providers/internal_util.go index f9bdc304..42948408 100644 --- a/providers/internal_util.go +++ b/providers/internal_util.go @@ -2,7 +2,6 @@ package providers import ( "context" - "io/ioutil" "net/http" "net/url" @@ -56,20 +55,22 @@ func validateToken(ctx context.Context, p Provider, accessToken string, header h params := url.Values{"access_token": {accessToken}} endpoint = endpoint + "?" + params.Encode() } - resp, err := requests.RequestUnparsedResponse(ctx, endpoint, header) - if err != nil { + + result := requests.New(endpoint). + WithContext(ctx). + WithHeaders(header). + Do() + if result.Error() != nil { logger.Printf("GET %s", stripToken(endpoint)) - logger.Printf("token validation request failed: %s", err) + logger.Printf("token validation request failed: %s", result.Error()) return false } - body, _ := ioutil.ReadAll(resp.Body) - resp.Body.Close() - logger.Printf("%d GET %s %s", resp.StatusCode, stripToken(endpoint), body) + logger.Printf("%d GET %s %s", result.StatusCode(), stripToken(endpoint), result.Body()) - if resp.StatusCode == 200 { + if result.StatusCode() == 200 { return true } - logger.Printf("token validation request failed: status %d - %s", resp.StatusCode, body) + logger.Printf("token validation request failed: status %d - %s", result.StatusCode(), result.Body()) return false } diff --git a/providers/keycloak.go b/providers/keycloak.go index 414c4973..77efc0c7 100644 --- a/providers/keycloak.go +++ b/providers/keycloak.go @@ -2,7 +2,6 @@ package providers import ( "context" - "net/http" "net/url" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" @@ -51,14 +50,11 @@ func (p *KeycloakProvider) SetGroup(group string) { } func (p *KeycloakProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { - - req, err := http.NewRequestWithContext(ctx, "GET", p.ValidateURL.String(), nil) - req.Header.Set("Authorization", "Bearer "+s.AccessToken) - if err != nil { - logger.Printf("failed building request %s", err) - return "", err - } - json, err := requests.Request(req) + json, err := requests.New(p.ValidateURL.String()). + WithContext(ctx). + SetHeader("Authorization", "Bearer "+s.AccessToken). + Do(). + UnmarshalJSON() if err != nil { logger.Printf("failed making request %s", err) return "", err diff --git a/providers/linkedin.go b/providers/linkedin.go index 6cc24239..7328dbbb 100644 --- a/providers/linkedin.go +++ b/providers/linkedin.go @@ -58,13 +58,13 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess if s.AccessToken == "" { return "", errors.New("missing access token") } - req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String()+"?format=json", nil) - if err != nil { - return "", err - } - req.Header = getLinkedInHeader(s.AccessToken) - json, err := requests.Request(req) + requestURL := p.ProfileURL.String() + "?format=json" + json, err := requests.New(requestURL). + WithContext(ctx). + WithHeaders(getLinkedInHeader(s.AccessToken)). + Do(). + UnmarshalJSON() if err != nil { return "", err } diff --git a/providers/logingov.go b/providers/logingov.go index 46027172..79eb1827 100644 --- a/providers/logingov.go +++ b/providers/logingov.go @@ -15,6 +15,7 @@ import ( "github.com/dgrijalva/jwt-go" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" + "github.com/oauth2-proxy/oauth2-proxy/pkg/requests" "gopkg.in/square/go-jose.v2" ) @@ -128,51 +129,34 @@ func checkNonce(idToken string, p *LoginGovProvider) (err error) { return } -func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint string) (email string, err error) { - // query the user info endpoint for user attributes - var req *http.Request - req, err = http.NewRequestWithContext(ctx, "GET", userInfoEndpoint, nil) - if err != nil { - return - } - req.Header.Set("Authorization", "Bearer "+accessToken) - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return - } - var body []byte - body, err = ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return - } - - if resp.StatusCode != 200 { - err = fmt.Errorf("got %d from %q %s", resp.StatusCode, userInfoEndpoint, body) - return - } - +func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint string) (string, error) { // parse the user attributes from the data we got and make sure that // the email address has been validated. var emailData struct { Email string `json:"email"` EmailVerified bool `json:"email_verified"` } - err = json.Unmarshal(body, &emailData) + + // query the user info endpoint for user attributes + err := requests.New(userInfoEndpoint). + WithContext(ctx). + SetHeader("Authorization", "Bearer "+accessToken). + Do(). + UnmarshalInto(&emailData) if err != nil { - return + return "", err } - if emailData.Email == "" { - err = fmt.Errorf("missing email") - return + + email := emailData.Email + if email == "" { + return "", fmt.Errorf("missing email") } - email = emailData.Email + if !emailData.EmailVerified { - err = fmt.Errorf("email %s not listed as verified", email) - return + return "", fmt.Errorf("email %s not listed as verified", email) } - return + + return email, nil } // Redeem exchanges the OAuth2 authentication token for an ID token @@ -201,30 +185,6 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) params.Add("code", code) params.Add("grant_type", "authorization_code") - var req *http.Request - req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) - if err != nil { - return - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - var resp *http.Response - resp, err = http.DefaultClient.Do(req) - if err != nil { - return nil, err - } - var body []byte - body, err = ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return - } - - if resp.StatusCode != 200 { - err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) - return - } - // Get the token from the body that we got from the token endpoint. var jsonResponse struct { AccessToken string `json:"access_token"` @@ -232,9 +192,15 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) TokenType string `json:"token_type"` ExpiresIn int64 `json:"expires_in"` } - err = json.Unmarshal(body, &jsonResponse) + err = requests.New(p.RedeemURL.String()). + WithContext(ctx). + WithMethod("POST"). + WithBody(bytes.NewBufferString(params.Encode())). + SetHeader("Content-Type", "application/x-www-form-urlencoded"). + Do(). + UnmarshalInto(&jsonResponse) if err != nil { - return + return nil, err } // check nonce here diff --git a/providers/nextcloud.go b/providers/nextcloud.go index d51b7183..b70fd07c 100644 --- a/providers/nextcloud.go +++ b/providers/nextcloud.go @@ -6,7 +6,6 @@ import ( "net/http" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" - "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/pkg/requests" ) @@ -31,18 +30,15 @@ func getNextcloudHeader(accessToken string) http.Header { // GetEmailAddress returns the Account email address func (p *NextcloudProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { - req, err := http.NewRequestWithContext(ctx, "GET", - p.ValidateURL.String(), nil) + json, err := requests.New(p.ValidateURL.String()). + WithContext(ctx). + WithHeaders(getNextcloudHeader(s.AccessToken)). + Do(). + UnmarshalJSON() if err != nil { - logger.Printf("failed building request %s", err) - return "", err - } - req.Header = getNextcloudHeader(s.AccessToken) - json, err := requests.Request(req) - if err != nil { - logger.Printf("failed making request %s", err) - return "", err + return "", fmt.Errorf("error making request: %v", err) } + email, err := json.Get("ocs").Get("data").Get("email").String() return email, err } diff --git a/providers/oidc.go b/providers/oidc.go index bc98dbb8..e456db76 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -256,13 +256,11 @@ func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc. // If the userinfo endpoint profileURL is defined, then there is a chance the userinfo // contents at the profileURL contains the email. // Make a query to the userinfo endpoint, and attempt to locate the email from there. - req, err := http.NewRequestWithContext(ctx, "GET", profileURL, nil) - if err != nil { - return nil, err - } - req.Header = getOIDCHeader(accessToken) - - respJSON, err := requests.Request(req) + respJSON, err := requests.New(profileURL). + WithContext(ctx). + WithHeaders(getOIDCHeader(accessToken)). + Do(). + UnmarshalJSON() if err != nil { return nil, err } diff --git a/providers/provider_default.go b/providers/provider_default.go index 14cec9fe..598b91e8 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -3,17 +3,15 @@ package providers import ( "bytes" "context" - "encoding/json" "errors" "fmt" - "io/ioutil" - "net/http" "net/url" "time" "github.com/coreos/go-oidc" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" + "github.com/oauth2-proxy/oauth2-proxy/pkg/requests" ) var _ Provider = (*ProviderData)(nil) @@ -39,35 +37,21 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s params.Add("resource", p.ProtectedResource.String()) } - var req *http.Request - req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) - if err != nil { - return - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - var resp *http.Response - resp, err = http.DefaultClient.Do(req) - if err != nil { - return nil, err - } - var body []byte - body, err = ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return - } - - if resp.StatusCode != 200 { - err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) - return + result := requests.New(p.RedeemURL.String()). + WithContext(ctx). + WithMethod("POST"). + WithBody(bytes.NewBufferString(params.Encode())). + SetHeader("Content-Type", "application/x-www-form-urlencoded"). + Do() + if result.Error() != nil { + return nil, result.Error() } // blindly try json and x-www-form-urlencoded var jsonResponse struct { AccessToken string `json:"access_token"` } - err = json.Unmarshal(body, &jsonResponse) + err = result.UnmarshalInto(&jsonResponse) if err == nil { s = &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, @@ -76,7 +60,7 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s } var v url.Values - v, err = url.ParseQuery(string(body)) + v, err = url.ParseQuery(string(result.Body())) if err != nil { return } @@ -84,7 +68,7 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s created := time.Now() s = &sessions.SessionState{AccessToken: a, CreatedAt: &created} } else { - err = fmt.Errorf("no access token found %s", body) + err = fmt.Errorf("no access token found %s", result.Body()) } return }