From 0bc0feb4bb20c6d79b06cd8450f41b90bba8f987 Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Fri, 3 Jul 2020 18:38:26 +0100 Subject: [PATCH 1/8] Add request builder to simplify request handling --- pkg/requests/builder.go | 173 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 173 insertions(+) create mode 100644 pkg/requests/builder.go diff --git a/pkg/requests/builder.go b/pkg/requests/builder.go new file mode 100644 index 00000000..2e1a358e --- /dev/null +++ b/pkg/requests/builder.go @@ -0,0 +1,173 @@ +package requests + +import ( + "context" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "net/http" + + "github.com/bitly/go-simplejson" +) + +// Builder allows users to construct a request and then either get the requests +// response via Do(), parse the response into a simplejson.Json via JSON(), +// or to parse the json response into an object via UnmarshalInto(). +type Builder interface { + WithContext(context.Context) Builder + WithBody(io.Reader) Builder + WithMethod(string) Builder + WithHeaders(http.Header) Builder + SetHeader(key, value string) Builder + Do() (*http.Response, error) + UnmarshalInto(interface{}) error + UnmarshalJSON() (*simplejson.Json, error) +} + +type builder struct { + context context.Context + method string + endpoint string + body io.Reader + header http.Header + response *http.Response +} + +// New provides a new Builder for the given endpoint. +func New(endpoint string) Builder { + return &builder{ + endpoint: endpoint, + method: "GET", + } +} + +// WithContext adds a context to the request. +// If no context is provided, context.Background() is used instead. +func (r *builder) WithContext(ctx context.Context) Builder { + r.context = ctx + return r +} + +// WithBody adds a body to the request. +func (r *builder) WithBody(body io.Reader) Builder { + r.body = body + return r +} + +// WithMethod sets the request method. Defaults to "GET". +func (r *builder) WithMethod(method string) Builder { + r.method = method + return r +} + +// WithHeaders replaces the request header map with the given header map. +func (r *builder) WithHeaders(header http.Header) Builder { + r.header = header + return r +} + +// SetHeader sets a single header to the given value. +// May be used to add multiple headers. +func (r *builder) SetHeader(key, value string) Builder { + if r.header == nil { + r.header = make(http.Header) + } + r.header.Set(key, value) + return r +} + +// Do performs the request and returns the response in its raw form. +// If the request has already been performed, returns the previous result. +// This will not allow you to repeat a request. +func (r *builder) Do() (*http.Response, error) { + if r.response != nil { + // Request has already been done + return r.response, nil + } + + // Must provide a non-nil context to NewRequestWithContext + if r.context == nil { + r.context = context.Background() + } + + req, err := http.NewRequestWithContext(r.context, r.method, r.endpoint, r.body) + if err != nil { + return nil, fmt.Errorf("error creating request: %v", err) + } + req.Header = r.header + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("error performing request: %v", err) + } + + r.response = resp + return resp, nil +} + +// UnmarshalInto performs the request and attempts to unmarshal the response into the +// the given interface. The response body is assumed to be JSON. +// The response must have a 200 status otherwise an error will be returned. +func (r *builder) UnmarshalInto(into interface{}) error { + resp, err := r.Do() + if err != nil { + return err + } + + return UnmarshalInto(resp, into) +} + +// UnmarshalJSON performs the request and attempts to unmarshal the response into a +// simplejson.Json. The response body is assume to be JSON. +// The response must have a 200 status otherwise an error will be returned. +func (r *builder) UnmarshalJSON() (*simplejson.Json, error) { + resp, err := r.Do() + if err != nil { + return nil, err + } + + body, err := getResponseBody(resp) + if err != nil { + return nil, err + } + + data, err := simplejson.NewJson(body) + if err != nil { + return nil, fmt.Errorf("error reading json: %v", err) + } + return data, nil +} + +// UnmarshalInto attempts to unmarshal the response into the the given interface. +// The response body is assumed to be JSON. +// The response must have a 200 status otherwise an error will be returned. +func UnmarshalInto(resp *http.Response, into interface{}) error { + body, err := getResponseBody(resp) + if err != nil { + return err + } + + if err := json.Unmarshal(body, into); err != nil { + return fmt.Errorf("error unmarshalling body: %v", err) + } + + return nil +} + +// getResponseBody extracts the response body, but will only return the body +// if the response was successful. +func getResponseBody(resp *http.Response) ([]byte, error) { + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("error reading response body: %v", err) + } + + // Only unmarshal body if the response was successful + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status \"%d\": %s", resp.StatusCode, body) + } + + return body, nil +} From 21ef86b5942cba6d3d708394b3d485bde4685886 Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Sat, 4 Jul 2020 08:09:35 +0100 Subject: [PATCH 2/8] Add tests for the request builder --- pkg/requests/builder_test.go | 384 ++++++++++++++++++++++++++++ pkg/requests/requests_suite_test.go | 96 +++++++ 2 files changed, 480 insertions(+) create mode 100644 pkg/requests/builder_test.go create mode 100644 pkg/requests/requests_suite_test.go diff --git a/pkg/requests/builder_test.go b/pkg/requests/builder_test.go new file mode 100644 index 00000000..b859fa23 --- /dev/null +++ b/pkg/requests/builder_test.go @@ -0,0 +1,384 @@ +package requests + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + + "github.com/bitly/go-simplejson" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Builder suite", func() { + var b Builder + getBuilder := func() Builder { return b } + + baseHeaders := http.Header{ + "Accept-Encoding": []string{"gzip"}, + "User-Agent": []string{"Go-http-client/1.1"}, + } + + BeforeEach(func() { + // Most tests will request the server address + b = New(serverAddr + "/json/path") + }) + + Context("with a basic request", func() { + assertSuccessfulRequest(getBuilder, testHTTPRequest{ + Method: "GET", + Header: baseHeaders, + Body: []byte{}, + RequestURI: "/json/path", + }) + }) + + Context("with a context", func() { + var ctx context.Context + var cancel context.CancelFunc + + BeforeEach(func() { + ctx, cancel = context.WithCancel(context.Background()) + b = b.WithContext(ctx) + }) + + AfterEach(func() { + cancel() + }) + + assertSuccessfulRequest(getBuilder, testHTTPRequest{ + Method: "GET", + Header: baseHeaders, + Body: []byte{}, + RequestURI: "/json/path", + }) + + Context("if the context is cancelled", func() { + BeforeEach(func() { + cancel() + }) + + assertRequestError(getBuilder, "context canceled") + }) + }) + + Context("with a body", func() { + const body = "{\"some\": \"body\"}" + header := baseHeaders.Clone() + header.Set("Content-Length", fmt.Sprintf("%d", len(body))) + + BeforeEach(func() { + buf := bytes.NewBuffer([]byte(body)) + b = b.WithBody(buf) + }) + + assertSuccessfulRequest(getBuilder, testHTTPRequest{ + Method: "GET", + Header: header, + Body: []byte(body), + RequestURI: "/json/path", + }) + }) + + Context("with a method", func() { + Context("POST with a body", func() { + const body = "{\"some\": \"body\"}" + header := baseHeaders.Clone() + header.Set("Content-Length", fmt.Sprintf("%d", len(body))) + + BeforeEach(func() { + buf := bytes.NewBuffer([]byte(body)) + b = b.WithMethod("POST").WithBody(buf) + }) + + assertSuccessfulRequest(getBuilder, testHTTPRequest{ + Method: "POST", + Header: header, + Body: []byte(body), + RequestURI: "/json/path", + }) + }) + + Context("POST without a body", func() { + header := baseHeaders.Clone() + header.Set("Content-Length", "0") + + BeforeEach(func() { + b = b.WithMethod("POST") + }) + + assertSuccessfulRequest(getBuilder, testHTTPRequest{ + Method: "POST", + Header: header, + Body: []byte{}, + RequestURI: "/json/path", + }) + }) + + Context("OPTIONS", func() { + BeforeEach(func() { + b = b.WithMethod("OPTIONS") + }) + + assertSuccessfulRequest(getBuilder, testHTTPRequest{ + Method: "OPTIONS", + Header: baseHeaders, + Body: []byte{}, + RequestURI: "/json/path", + }) + }) + + Context("INVALID-\\t-METHOD", func() { + BeforeEach(func() { + b = b.WithMethod("INVALID-\t-METHOD") + }) + + assertRequestError(getBuilder, "error creating request: net/http: invalid method \"INVALID-\\t-METHOD\"") + }) + }) + + Context("with headers", func() { + Context("setting a header", func() { + header := baseHeaders.Clone() + header.Set("header", "value") + + BeforeEach(func() { + b = b.SetHeader("header", "value") + }) + + assertSuccessfulRequest(getBuilder, testHTTPRequest{ + Method: "GET", + Header: header, + Body: []byte{}, + RequestURI: "/json/path", + }) + + Context("then replacing the headers", func() { + replacementHeaders := http.Header{ + "Accept-Encoding": []string{"*"}, + "User-Agent": []string{"test-agent"}, + "Foo": []string{"bar, baz"}, + } + + BeforeEach(func() { + b = b.WithHeaders(replacementHeaders) + }) + + assertSuccessfulRequest(getBuilder, testHTTPRequest{ + Method: "GET", + Header: replacementHeaders, + Body: []byte{}, + RequestURI: "/json/path", + }) + }) + }) + + Context("replacing the header", func() { + replacementHeaders := http.Header{ + "Accept-Encoding": []string{"*"}, + "User-Agent": []string{"test-agent"}, + "Foo": []string{"bar, baz"}, + } + + BeforeEach(func() { + b = b.WithHeaders(replacementHeaders) + }) + + assertSuccessfulRequest(getBuilder, testHTTPRequest{ + Method: "GET", + Header: replacementHeaders, + Body: []byte{}, + RequestURI: "/json/path", + }) + + Context("then setting a header", func() { + header := replacementHeaders.Clone() + header.Set("User-Agent", "different-agent") + + BeforeEach(func() { + b = b.SetHeader("User-Agent", "different-agent") + }) + + assertSuccessfulRequest(getBuilder, testHTTPRequest{ + Method: "GET", + Header: header, + Body: []byte{}, + RequestURI: "/json/path", + }) + }) + }) + }) + + Context("if the request has been completed and then modified", func() { + BeforeEach(func() { + _, err := b.Do() + Expect(err).ToNot(HaveOccurred()) + + b.WithMethod("POST") + }) + + Context("should not redo the request", func() { + assertSuccessfulRequest(getBuilder, testHTTPRequest{ + Method: "GET", + Header: baseHeaders, + Body: []byte{}, + RequestURI: "/json/path", + }) + }) + }) + + Context("when the requested page is not found", func() { + BeforeEach(func() { + b = New(serverAddr + "/not-found") + }) + + assertJSONError(getBuilder, "404 page not found") + }) + + Context("when the requested page is not valid JSON", func() { + BeforeEach(func() { + b = New(serverAddr + "/string/path") + }) + + assertJSONError(getBuilder, "invalid character 'O' looking for beginning of value") + }) +}) + +func assertSuccessfulRequest(builder func() Builder, expectedRequest testHTTPRequest) { + Context("Do", func() { + var resp *http.Response + + BeforeEach(func() { + var err error + resp, err = builder().Do() + Expect(err).ToNot(HaveOccurred()) + }) + + It("returns a successful status", func() { + Expect(resp.StatusCode).To(Equal(http.StatusOK)) + }) + + It("made the expected request", func() { + body, err := ioutil.ReadAll(resp.Body) + Expect(err).ToNot(HaveOccurred()) + resp.Body.Close() + + actualRequest := testHTTPRequest{} + Expect(json.Unmarshal(body, &actualRequest)).To(Succeed()) + + Expect(actualRequest).To(Equal(expectedRequest)) + }) + }) + + Context("UnmarshalInto", func() { + var actualRequest testHTTPRequest + + BeforeEach(func() { + Expect(builder().UnmarshalInto(&actualRequest)).To(Succeed()) + }) + + It("made the expected request", func() { + Expect(actualRequest).To(Equal(expectedRequest)) + }) + }) + + Context("UnmarshalJSON", func() { + var response *simplejson.Json + + BeforeEach(func() { + var err error + response, err = builder().UnmarshalJSON() + Expect(err).ToNot(HaveOccurred()) + }) + + It("made the expected reqest", func() { + header := http.Header{} + for key, value := range response.Get("Header").MustMap() { + vs, ok := value.([]interface{}) + Expect(ok).To(BeTrue()) + svs := []string{} + for _, v := range vs { + sv, ok := v.(string) + Expect(ok).To(BeTrue()) + svs = append(svs, sv) + } + header[key] = svs + } + + // Other json unmarhsallers base64 decode byte slices automatically + body, err := base64.StdEncoding.DecodeString(response.Get("Body").MustString()) + Expect(err).ToNot(HaveOccurred()) + + actualRequest := testHTTPRequest{ + Method: response.Get("Method").MustString(), + Header: header, + Body: body, + RequestURI: response.Get("RequestURI").MustString(), + } + + Expect(actualRequest).To(Equal(expectedRequest)) + }) + }) +} + +func assertRequestError(builder func() Builder, errorMessage string) { + Context("Do", func() { + It("returns an error", func() { + resp, err := builder().Do() + Expect(err).To(MatchError(ContainSubstring(errorMessage))) + Expect(resp).To(BeNil()) + }) + }) + + Context("UnmarshalInto", func() { + It("returns an error", func() { + var actualRequest testHTTPRequest + err := builder().UnmarshalInto(&actualRequest) + Expect(err).To(MatchError(ContainSubstring(errorMessage))) + + // Should be empty + Expect(actualRequest).To(Equal(testHTTPRequest{})) + }) + }) + + Context("UnmarshalJSON", func() { + It("returns an error", func() { + resp, err := builder().UnmarshalJSON() + Expect(err).To(MatchError(ContainSubstring(errorMessage))) + Expect(resp).To(BeNil()) + }) + }) +} + +func assertJSONError(builder func() Builder, errorMessage string) { + Context("Do", func() { + It("does not return an error", func() { + resp, err := builder().Do() + Expect(err).To(BeNil()) + Expect(resp).ToNot(BeNil()) + }) + }) + + Context("UnmarshalInto", func() { + It("returns an error", func() { + var actualRequest testHTTPRequest + err := builder().UnmarshalInto(&actualRequest) + Expect(err).To(MatchError(ContainSubstring(errorMessage))) + + // Should be empty + Expect(actualRequest).To(Equal(testHTTPRequest{})) + }) + }) + + Context("UnmarshalJSON", func() { + It("returns an error", func() { + resp, err := builder().UnmarshalJSON() + Expect(err).To(MatchError(ContainSubstring(errorMessage))) + Expect(resp).To(BeNil()) + }) + }) +} diff --git a/pkg/requests/requests_suite_test.go b/pkg/requests/requests_suite_test.go new file mode 100644 index 00000000..83da733a --- /dev/null +++ b/pkg/requests/requests_suite_test.go @@ -0,0 +1,96 @@ +package requests + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "log" + "net/http" + "net/http/httptest" + "testing" + + "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var ( + server *httptest.Server + serverAddr string +) + +func TestRequetsSuite(t *testing.T) { + logger.SetOutput(GinkgoWriter) + log.SetOutput(GinkgoWriter) + + RegisterFailHandler(Fail) + RunSpecs(t, "Requests Suite") +} + +var _ = BeforeSuite(func() { + // Set up a webserver that reflects requests + mux := http.NewServeMux() + mux.Handle("/json/", &testHTTPUpstream{}) + mux.HandleFunc("/string/", func(rw http.ResponseWriter, _ *http.Request) { + rw.Write([]byte("OK")) + }) + server = httptest.NewServer(mux) + serverAddr = fmt.Sprintf("http://%s", server.Listener.Addr().String()) +}) + +var _ = AfterSuite(func() { + server.Close() +}) + +// testHTTPRequest is a struct used to capture the state of a request made to +// the test server +type testHTTPRequest struct { + Method string + Header http.Header + Body []byte + RequestURI string +} + +type testHTTPUpstream struct{} + +func (t *testHTTPUpstream) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + request, err := toTestHTTPRequest(req) + if err != nil { + t.writeError(rw, err) + return + } + + data, err := json.Marshal(request) + if err != nil { + t.writeError(rw, err) + return + } + + rw.Header().Set("Content-Type", "application/json") + rw.Write(data) +} + +func (t *testHTTPUpstream) writeError(rw http.ResponseWriter, err error) { + rw.WriteHeader(500) + if err != nil { + rw.Write([]byte(err.Error())) + } +} + +func toTestHTTPRequest(req *http.Request) (testHTTPRequest, error) { + requestBody := []byte{} + if req.Body != http.NoBody { + var err error + requestBody, err = ioutil.ReadAll(req.Body) + if err != nil { + return testHTTPRequest{}, err + } + } + + return testHTTPRequest{ + Method: req.Method, + Header: req.Header, + Body: requestBody, + RequestURI: req.RequestURI, + }, nil +} From 53142455b6ed997ee10f45e05db63cb52c9f5311 Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Fri, 3 Jul 2020 19:27:25 +0100 Subject: [PATCH 3/8] Migrate all requests to new builder pattern --- pkg/validation/options.go | 57 ++++++------- providers/azure.go | 49 +++-------- providers/bitbucket.go | 51 +++++------- providers/digitalocean.go | 10 +-- providers/facebook.go | 13 +-- providers/github.go | 153 +++++++++------------------------- providers/gitlab.go | 32 ++----- providers/google.go | 68 +++++---------- providers/internal_util.go | 6 +- providers/keycloak.go | 13 +-- providers/linkedin.go | 11 ++- providers/logingov.go | 84 ++++++------------- providers/nextcloud.go | 17 ++-- providers/oidc.go | 11 +-- providers/provider_default.go | 18 ++-- 15 files changed, 194 insertions(+), 399 deletions(-) diff --git a/pkg/validation/options.go b/pkg/validation/options.go index 50717b60..ae2ed065 100644 --- a/pkg/validation/options.go +++ b/pkg/validation/options.go @@ -83,34 +83,33 @@ func Validate(o *options.Options) error { logger.Printf("Performing OIDC Discovery...") - if req, err := http.NewRequest("GET", strings.TrimSuffix(o.OIDCIssuerURL, "/")+"/.well-known/openid-configuration", nil); err == nil { - if body, err := requests.Request(req); err == nil { - - // Prefer manually configured URLs. It's a bit unclear - // why you'd be doing discovery and also providing the URLs - // explicitly though... - if o.LoginURL == "" { - o.LoginURL = body.Get("authorization_endpoint").MustString() - } - - if o.RedeemURL == "" { - o.RedeemURL = body.Get("token_endpoint").MustString() - } - - if o.OIDCJwksURL == "" { - o.OIDCJwksURL = body.Get("jwks_uri").MustString() - } - - if o.ProfileURL == "" { - o.ProfileURL = body.Get("userinfo_endpoint").MustString() - } - - o.SkipOIDCDiscovery = true - } else { - logger.Printf("error: failed to discover OIDC configuration: %v", err) - } + requestURL := strings.TrimSuffix(o.OIDCIssuerURL, "/") + "/.well-known/openid-configuration" + body, err := requests.New(requestURL). + WithContext(ctx). + UnmarshalJSON() + if err != nil { + logger.Printf("error: failed to discover OIDC configuration: %v", err) } else { - logger.Printf("error: failed parsing OIDC discovery URL: %v", err) + // Prefer manually configured URLs. It's a bit unclear + // why you'd be doing discovery and also providing the URLs + // explicitly though... + if o.LoginURL == "" { + o.LoginURL = body.Get("authorization_endpoint").MustString() + } + + if o.RedeemURL == "" { + o.RedeemURL = body.Get("token_endpoint").MustString() + } + + if o.OIDCJwksURL == "" { + o.OIDCJwksURL = body.Get("jwks_uri").MustString() + } + + if o.ProfileURL == "" { + o.ProfileURL = body.Get("userinfo_endpoint").MustString() + } + + o.SkipOIDCDiscovery = true } } @@ -385,10 +384,12 @@ func newVerifierFromJwtIssuer(jwtIssuer jwtIssuer) (*oidc.IDTokenVerifier, error if err != nil { // Try as JWKS URI jwksURI := strings.TrimSuffix(jwtIssuer.issuerURI, "/") + "/.well-known/jwks.json" - _, err := http.NewRequest("GET", jwksURI, nil) + resp, err := requests.New(jwksURI).Do() if err != nil { return nil, err } + resp.Body.Close() + verifier = oidc.NewVerifier(jwtIssuer.issuerURI, oidc.NewRemoteKeySet(context.Background(), jwksURI), config) } else { verifier = provider.Verifier(config) diff --git a/providers/azure.go b/providers/azure.go index aea1b0e5..5e8df0a0 100644 --- a/providers/azure.go +++ b/providers/azure.go @@ -3,10 +3,8 @@ package providers import ( "bytes" "context" - "encoding/json" "errors" "fmt" - "io/ioutil" "net/http" "net/url" "time" @@ -91,39 +89,21 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s params.Add("resource", p.ProtectedResource.String()) } - var req *http.Request - req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) - if err != nil { - return - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - var resp *http.Response - resp, err = http.DefaultClient.Do(req) - if err != nil { - return nil, err - } - var body []byte - body, err = ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return - } - - if resp.StatusCode != 200 { - err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) - return - } - var jsonResponse struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` ExpiresOn int64 `json:"expires_on,string"` IDToken string `json:"id_token"` } - err = json.Unmarshal(body, &jsonResponse) + + err = requests.New(p.RedeemURL.String()). + WithContext(ctx). + WithMethod("POST"). + WithBody(bytes.NewBufferString(params.Encode())). + SetHeader("Content-Type", "application/x-www-form-urlencoded"). + UnmarshalInto(&jsonResponse) if err != nil { - return + return nil, err } created := time.Now() @@ -169,26 +149,21 @@ func (p *AzureProvider) GetEmailAddress(ctx context.Context, s *sessions.Session if s.AccessToken == "" { return "", errors.New("missing access token") } - req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String(), nil) - if err != nil { - return "", err - } - req.Header = getAzureHeader(s.AccessToken) - - json, err := requests.Request(req) + json, err := requests.New(p.ProfileURL.String()). + WithContext(ctx). + WithHeaders(getAzureHeader(s.AccessToken)). + UnmarshalJSON() if err != nil { return "", err } email, err = getEmailFromJSON(json) - if err == nil && email != "" { return email, err } email, err = json.Get("userPrincipalName").String() - if err != nil { logger.Printf("failed making request %s", err) return "", err diff --git a/providers/bitbucket.go b/providers/bitbucket.go index 2bb876cb..726d3562 100644 --- a/providers/bitbucket.go +++ b/providers/bitbucket.go @@ -2,7 +2,6 @@ package providers import ( "context" - "net/http" "net/url" "strings" @@ -85,15 +84,13 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses FullName string `json:"full_name"` } } - req, err := http.NewRequestWithContext(ctx, "GET", - p.ValidateURL.String()+"?access_token="+s.AccessToken, nil) + + requestURL := p.ValidateURL.String() + "?access_token=" + s.AccessToken + err := requests.New(requestURL). + WithContext(ctx). + UnmarshalInto(&emails) if err != nil { - logger.Printf("failed building request %s", err) - return "", err - } - err = requests.RequestJSON(req, &emails) - if err != nil { - logger.Printf("failed making request %s", err) + logger.Printf("failed making request: %v", err) return "", err } @@ -101,15 +98,14 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses teamURL := &url.URL{} *teamURL = *p.ValidateURL teamURL.Path = "/2.0/teams" - req, err = http.NewRequestWithContext(ctx, "GET", - teamURL.String()+"?role=member&access_token="+s.AccessToken, nil) + + requestURL := teamURL.String() + "?role=member&access_token=" + s.AccessToken + + err := requests.New(requestURL). + WithContext(ctx). + UnmarshalInto(&teams) if err != nil { - logger.Printf("failed building request %s", err) - return "", err - } - err = requests.RequestJSON(req, &teams) - if err != nil { - logger.Printf("failed requesting teams membership %s", err) + logger.Printf("failed requesting teams membership: %v", err) return "", err } var found = false @@ -129,20 +125,19 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses repositoriesURL := &url.URL{} *repositoriesURL = *p.ValidateURL repositoriesURL.Path = "/2.0/repositories/" + strings.Split(p.Repository, "/")[0] - req, err = http.NewRequestWithContext(ctx, "GET", - repositoriesURL.String()+"?role=contributor"+ - "&q=full_name="+url.QueryEscape("\""+p.Repository+"\"")+ - "&access_token="+s.AccessToken, - nil) + + requestURL := repositoriesURL.String() + "?role=contributor" + + "&q=full_name=" + url.QueryEscape("\""+p.Repository+"\"") + + "&access_token=" + s.AccessToken + + err := requests.New(requestURL). + WithContext(ctx). + UnmarshalInto(&repositories) if err != nil { - logger.Printf("failed building request %s", err) - return "", err - } - err = requests.RequestJSON(req, &repositories) - if err != nil { - logger.Printf("failed checking repository access %s", err) + logger.Printf("failed checking repository access: %v", err) return "", err } + var found = false for _, repository := range repositories.Values { if p.Repository == repository.FullName { diff --git a/providers/digitalocean.go b/providers/digitalocean.go index 25d37af9..306646d5 100644 --- a/providers/digitalocean.go +++ b/providers/digitalocean.go @@ -60,13 +60,11 @@ func (p *DigitalOceanProvider) GetEmailAddress(ctx context.Context, s *sessions. if s.AccessToken == "" { return "", errors.New("missing access token") } - req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String(), nil) - if err != nil { - return "", err - } - req.Header = getDigitalOceanHeader(s.AccessToken) - json, err := requests.Request(req) + json, err := requests.New(p.ProfileURL.String()). + WithContext(ctx). + WithHeaders(getDigitalOceanHeader(s.AccessToken)). + UnmarshalJSON() if err != nil { return "", err } diff --git a/providers/facebook.go b/providers/facebook.go index 0f9cc624..81973416 100644 --- a/providers/facebook.go +++ b/providers/facebook.go @@ -62,20 +62,21 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess if s.AccessToken == "" { return "", errors.New("missing access token") } - req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String()+"?fields=name,email", nil) - if err != nil { - return "", err - } - req.Header = getFacebookHeader(s.AccessToken) type result struct { Email string } var r result - err = requests.RequestJSON(req, &r) + + requestURL := p.ProfileURL.String() + "?fields=name,email" + err := requests.New(requestURL). + WithContext(ctx). + WithHeaders(getFacebookHeader(s.AccessToken)). + UnmarshalInto(&r) if err != nil { return "", err } + if r.Email == "" { return "", errors.New("no email") } diff --git a/providers/github.go b/providers/github.go index 6d3f8b02..e93a23e1 100644 --- a/providers/github.go +++ b/providers/github.go @@ -2,7 +2,6 @@ package providers import ( "context" - "encoding/json" "errors" "fmt" "io/ioutil" @@ -15,6 +14,7 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" + "github.com/oauth2-proxy/oauth2-proxy/pkg/requests" ) // GitHubProvider represents an GitHub based Identity Provider @@ -111,27 +111,16 @@ func (p *GitHubProvider) hasOrg(ctx context.Context, accessToken string) (bool, Path: path.Join(p.ValidateURL.Path, "/user/orgs"), RawQuery: params.Encode(), } - req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) - req.Header = getGitHubHeader(accessToken) - resp, err := http.DefaultClient.Do(req) - if err != nil { - return false, err - } - - body, err := ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return false, err - } - if resp.StatusCode != 200 { - return false, fmt.Errorf( - "got %d from %q %s", resp.StatusCode, endpoint.String(), body) - } var op orgsPage - if err := json.Unmarshal(body, &op); err != nil { + err := requests.New(endpoint.String()). + WithContext(ctx). + WithHeaders(getGitHubHeader(accessToken)). + UnmarshalInto(&op) + if err != nil { return false, err } + if len(op) == 0 { break } @@ -187,9 +176,13 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string) RawQuery: params.Encode(), } - req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) - req.Header = getGitHubHeader(accessToken) - resp, err := http.DefaultClient.Do(req) + // bodyclose cannot detect that the body is being closed later in requests.Into, + // so have to skip the linting for the next line. + // nolint:bodyclose + resp, err := requests.New(endpoint.String()). + WithContext(ctx). + WithHeaders(getGitHubHeader(accessToken)). + Do() if err != nil { return false, err } @@ -217,21 +210,9 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string) } } - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - resp.Body.Close() - return false, err - } - resp.Body.Close() - - if resp.StatusCode != 200 { - return false, fmt.Errorf( - "got %d from %q %s", resp.StatusCode, endpoint.String(), body) - } - var tp teamsPage - if err := json.Unmarshal(body, &tp); err != nil { - return false, fmt.Errorf("%s unmarshaling %s", err, body) + if err := requests.UnmarshalInto(resp, &tp); err != nil { + return false, err } if len(tp) == 0 { break @@ -297,25 +278,12 @@ func (p *GitHubProvider) hasRepo(ctx context.Context, accessToken string) (bool, Path: path.Join(p.ValidateURL.Path, "/repo/", p.Repo), } - req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) - req.Header = getGitHubHeader(accessToken) - resp, err := http.DefaultClient.Do(req) - if err != nil { - return false, err - } - - body, err := ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return false, err - } - if resp.StatusCode != 200 { - return false, fmt.Errorf( - "got %d from %q %s", resp.StatusCode, endpoint.String(), body) - } - var repo repository - if err := json.Unmarshal(body, &repo); err != nil { + err := requests.New(endpoint.String()). + WithContext(ctx). + WithHeaders(getGitHubHeader(accessToken)). + UnmarshalInto(&repo) + if err != nil { return false, err } @@ -337,26 +305,14 @@ func (p *GitHubProvider) hasUser(ctx context.Context, accessToken string) (bool, Host: p.ValidateURL.Host, Path: path.Join(p.ValidateURL.Path, "/user"), } - req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) - req.Header = getGitHubHeader(accessToken) - resp, err := http.DefaultClient.Do(req) - if err != nil { - return false, err - } - defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) + err := requests.New(endpoint.String()). + WithContext(ctx). + WithHeaders(getGitHubHeader(accessToken)). + UnmarshalInto(&user) if err != nil { return false, err } - if resp.StatusCode != 200 { - return false, fmt.Errorf("got %d from %q %s", - resp.StatusCode, stripToken(endpoint.String()), body) - } - - if err := json.Unmarshal(body, &user); err != nil { - return false, err - } if p.isVerifiedUser(user.Login) { return true, nil @@ -372,12 +328,14 @@ func (p *GitHubProvider) isCollaborator(ctx context.Context, username, accessTok Host: p.ValidateURL.Host, Path: path.Join(p.ValidateURL.Path, "/repos/", p.Repo, "/collaborators/", username), } - req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) - req.Header = getGitHubHeader(accessToken) - resp, err := http.DefaultClient.Do(req) + resp, err := requests.New(endpoint.String()). + WithContext(ctx). + WithHeaders(getGitHubHeader(accessToken)). + Do() if err != nil { return false, err } + body, err := ioutil.ReadAll(resp.Body) resp.Body.Close() if err != nil { @@ -440,28 +398,13 @@ func (p *GitHubProvider) GetEmailAddress(ctx context.Context, s *sessions.Sessio Host: p.ValidateURL.Host, Path: path.Join(p.ValidateURL.Path, "/user/emails"), } - req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) - req.Header = getGitHubHeader(s.AccessToken) - resp, err := http.DefaultClient.Do(req) + err := requests.New(endpoint.String()). + WithContext(ctx). + WithHeaders(getGitHubHeader(s.AccessToken)). + UnmarshalInto(&emails) if err != nil { return "", err } - body, err := ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return "", err - } - - if resp.StatusCode != 200 { - return "", fmt.Errorf("got %d from %q %s", - resp.StatusCode, endpoint.String(), body) - } - - logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) - - if err := json.Unmarshal(body, &emails); err != nil { - return "", fmt.Errorf("%s unmarshaling %s", err, body) - } returnEmail := "" for _, email := range emails { @@ -489,34 +432,14 @@ func (p *GitHubProvider) GetUserName(ctx context.Context, s *sessions.SessionSta Path: path.Join(p.ValidateURL.Path, "/user"), } - req, err := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) - if err != nil { - return "", fmt.Errorf("could not create new GET request: %v", err) - } - - req.Header = getGitHubHeader(s.AccessToken) - resp, err := http.DefaultClient.Do(req) + err := requests.New(endpoint.String()). + WithContext(ctx). + WithHeaders(getGitHubHeader(s.AccessToken)). + UnmarshalInto(&user) if err != nil { return "", err } - body, err := ioutil.ReadAll(resp.Body) - defer resp.Body.Close() - if err != nil { - return "", err - } - - if resp.StatusCode != 200 { - return "", fmt.Errorf("got %d from %q %s", - resp.StatusCode, endpoint.String(), body) - } - - logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) - - if err := json.Unmarshal(body, &user); err != nil { - return "", fmt.Errorf("%s unmarshaling %s", err, body) - } - // Now that we have the username we can check collaborator status if !p.isVerifiedUser(user.Login) && p.Org == "" && p.Repo != "" && p.Token != "" { if ok, err := p.isCollaborator(ctx, user.Login, p.Token); err != nil || !ok { diff --git a/providers/gitlab.go b/providers/gitlab.go index 8d959781..17c06f5e 100644 --- a/providers/gitlab.go +++ b/providers/gitlab.go @@ -2,15 +2,13 @@ package providers import ( "context" - "encoding/json" "fmt" - "io/ioutil" - "net/http" "strings" "time" oidc "github.com/coreos/go-oidc" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" + "github.com/oauth2-proxy/oauth2-proxy/pkg/requests" "golang.org/x/oauth2" ) @@ -131,31 +129,13 @@ func (p *GitLabProvider) getUserInfo(ctx context.Context, s *sessions.SessionSta userInfoURL := *p.LoginURL userInfoURL.Path = "/oauth/userinfo" - req, err := http.NewRequestWithContext(ctx, "GET", userInfoURL.String(), nil) - if err != nil { - return nil, fmt.Errorf("failed to create user info request: %v", err) - } - req.Header.Set("Authorization", "Bearer "+s.AccessToken) - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to perform user info request: %v", err) - } - var body []byte - body, err = ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return nil, fmt.Errorf("failed to read user info response: %v", err) - } - - if resp.StatusCode != 200 { - return nil, fmt.Errorf("got %d during user info request: %s", resp.StatusCode, body) - } - var userInfo gitlabUserInfo - err = json.Unmarshal(body, &userInfo) + err := requests.New(userInfoURL.String()). + WithContext(ctx). + SetHeader("Authorization", "Bearer "+s.AccessToken). + UnmarshalInto(&userInfo) if err != nil { - return nil, fmt.Errorf("failed to parse user info: %v", err) + return nil, fmt.Errorf("error getting user info: %v", err) } return &userInfo, nil diff --git a/providers/google.go b/providers/google.go index 5aeb6e2d..fbed12f1 100644 --- a/providers/google.go +++ b/providers/google.go @@ -9,13 +9,13 @@ import ( "fmt" "io" "io/ioutil" - "net/http" "net/url" "strings" "time" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" + "github.com/oauth2-proxy/oauth2-proxy/pkg/requests" "golang.org/x/oauth2/google" admin "google.golang.org/api/admin/directory/v1" "google.golang.org/api/googleapi" @@ -116,28 +116,6 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) ( params.Add("client_secret", clientSecret) params.Add("code", code) params.Add("grant_type", "authorization_code") - var req *http.Request - req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) - if err != nil { - return - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return - } - var body []byte - body, err = ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return - } - - if resp.StatusCode != 200 { - err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) - return - } var jsonResponse struct { AccessToken string `json:"access_token"` @@ -145,10 +123,17 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) ( ExpiresIn int64 `json:"expires_in"` IDToken string `json:"id_token"` } - err = json.Unmarshal(body, &jsonResponse) + + err = requests.New(p.RedeemURL.String()). + WithContext(ctx). + WithMethod("POST"). + WithBody(bytes.NewBufferString(params.Encode())). + SetHeader("Content-Type", "application/x-www-form-urlencoded"). + UnmarshalInto(&jsonResponse) if err != nil { - return + return nil, err } + c, err := claimsFromIDToken(jsonResponse.IDToken) if err != nil { return @@ -283,38 +268,23 @@ func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken st params.Add("client_secret", clientSecret) params.Add("refresh_token", refreshToken) params.Add("grant_type", "refresh_token") - var req *http.Request - req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) - if err != nil { - return - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return - } - var body []byte - body, err = ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return - } - - if resp.StatusCode != 200 { - err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) - return - } var data struct { AccessToken string `json:"access_token"` ExpiresIn int64 `json:"expires_in"` IDToken string `json:"id_token"` } - err = json.Unmarshal(body, &data) + + err = requests.New(p.RedeemURL.String()). + WithContext(ctx). + WithMethod("POST"). + WithBody(bytes.NewBufferString(params.Encode())). + SetHeader("Content-Type", "application/x-www-form-urlencoded"). + UnmarshalInto(&data) if err != nil { - return + return "", "", 0, err } + token = data.AccessToken idToken = data.IDToken expires = time.Duration(data.ExpiresIn) * time.Second diff --git a/providers/internal_util.go b/providers/internal_util.go index f9bdc304..361e27c7 100644 --- a/providers/internal_util.go +++ b/providers/internal_util.go @@ -56,7 +56,11 @@ func validateToken(ctx context.Context, p Provider, accessToken string, header h params := url.Values{"access_token": {accessToken}} endpoint = endpoint + "?" + params.Encode() } - resp, err := requests.RequestUnparsedResponse(ctx, endpoint, header) + + resp, err := requests.New(endpoint). + WithContext(ctx). + WithHeaders(header). + Do() if err != nil { logger.Printf("GET %s", stripToken(endpoint)) logger.Printf("token validation request failed: %s", err) diff --git a/providers/keycloak.go b/providers/keycloak.go index 414c4973..78206de8 100644 --- a/providers/keycloak.go +++ b/providers/keycloak.go @@ -2,7 +2,6 @@ package providers import ( "context" - "net/http" "net/url" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" @@ -51,14 +50,10 @@ func (p *KeycloakProvider) SetGroup(group string) { } func (p *KeycloakProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { - - req, err := http.NewRequestWithContext(ctx, "GET", p.ValidateURL.String(), nil) - req.Header.Set("Authorization", "Bearer "+s.AccessToken) - if err != nil { - logger.Printf("failed building request %s", err) - return "", err - } - json, err := requests.Request(req) + json, err := requests.New(p.ValidateURL.String()). + WithContext(ctx). + SetHeader("Authorization", "Bearer "+s.AccessToken). + UnmarshalJSON() if err != nil { logger.Printf("failed making request %s", err) return "", err diff --git a/providers/linkedin.go b/providers/linkedin.go index 6cc24239..8dcd3c9d 100644 --- a/providers/linkedin.go +++ b/providers/linkedin.go @@ -58,13 +58,12 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess if s.AccessToken == "" { return "", errors.New("missing access token") } - req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String()+"?format=json", nil) - if err != nil { - return "", err - } - req.Header = getLinkedInHeader(s.AccessToken) - json, err := requests.Request(req) + requestURL := p.ProfileURL.String() + "?format=json" + json, err := requests.New(requestURL). + WithContext(ctx). + WithHeaders(getLinkedInHeader(s.AccessToken)). + UnmarshalJSON() if err != nil { return "", err } diff --git a/providers/logingov.go b/providers/logingov.go index 46027172..8846a8f2 100644 --- a/providers/logingov.go +++ b/providers/logingov.go @@ -15,6 +15,7 @@ import ( "github.com/dgrijalva/jwt-go" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" + "github.com/oauth2-proxy/oauth2-proxy/pkg/requests" "gopkg.in/square/go-jose.v2" ) @@ -128,51 +129,33 @@ func checkNonce(idToken string, p *LoginGovProvider) (err error) { return } -func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint string) (email string, err error) { - // query the user info endpoint for user attributes - var req *http.Request - req, err = http.NewRequestWithContext(ctx, "GET", userInfoEndpoint, nil) - if err != nil { - return - } - req.Header.Set("Authorization", "Bearer "+accessToken) - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return - } - var body []byte - body, err = ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return - } - - if resp.StatusCode != 200 { - err = fmt.Errorf("got %d from %q %s", resp.StatusCode, userInfoEndpoint, body) - return - } - +func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint string) (string, error) { // parse the user attributes from the data we got and make sure that // the email address has been validated. var emailData struct { Email string `json:"email"` EmailVerified bool `json:"email_verified"` } - err = json.Unmarshal(body, &emailData) + + // query the user info endpoint for user attributes + err := requests.New(userInfoEndpoint). + WithContext(ctx). + SetHeader("Authorization", "Bearer "+accessToken). + UnmarshalInto(&emailData) if err != nil { - return + return "", err } - if emailData.Email == "" { - err = fmt.Errorf("missing email") - return + + email := emailData.Email + if email == "" { + return "", fmt.Errorf("missing email") } - email = emailData.Email + if !emailData.EmailVerified { - err = fmt.Errorf("email %s not listed as verified", email) - return + return "", fmt.Errorf("email %s not listed as verified", email) } - return + + return email, nil } // Redeem exchanges the OAuth2 authentication token for an ID token @@ -201,30 +184,6 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) params.Add("code", code) params.Add("grant_type", "authorization_code") - var req *http.Request - req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) - if err != nil { - return - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - var resp *http.Response - resp, err = http.DefaultClient.Do(req) - if err != nil { - return nil, err - } - var body []byte - body, err = ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return - } - - if resp.StatusCode != 200 { - err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) - return - } - // Get the token from the body that we got from the token endpoint. var jsonResponse struct { AccessToken string `json:"access_token"` @@ -232,9 +191,14 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) TokenType string `json:"token_type"` ExpiresIn int64 `json:"expires_in"` } - err = json.Unmarshal(body, &jsonResponse) + err = requests.New(p.RedeemURL.String()). + WithContext(ctx). + WithMethod("POST"). + WithBody(bytes.NewBufferString(params.Encode())). + SetHeader("Content-Type", "application/x-www-form-urlencoded"). + UnmarshalInto(&jsonResponse) if err != nil { - return + return nil, err } // check nonce here diff --git a/providers/nextcloud.go b/providers/nextcloud.go index d51b7183..844bbff3 100644 --- a/providers/nextcloud.go +++ b/providers/nextcloud.go @@ -6,7 +6,6 @@ import ( "net/http" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" - "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/pkg/requests" ) @@ -31,18 +30,14 @@ func getNextcloudHeader(accessToken string) http.Header { // GetEmailAddress returns the Account email address func (p *NextcloudProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { - req, err := http.NewRequestWithContext(ctx, "GET", - p.ValidateURL.String(), nil) + json, err := requests.New(p.ValidateURL.String()). + WithContext(ctx). + WithHeaders(getNextcloudHeader(s.AccessToken)). + UnmarshalJSON() if err != nil { - logger.Printf("failed building request %s", err) - return "", err - } - req.Header = getNextcloudHeader(s.AccessToken) - json, err := requests.Request(req) - if err != nil { - logger.Printf("failed making request %s", err) - return "", err + return "", fmt.Errorf("error making request: %v", err) } + email, err := json.Get("ocs").Get("data").Get("email").String() return email, err } diff --git a/providers/oidc.go b/providers/oidc.go index bc98dbb8..08e0c3e8 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -256,13 +256,10 @@ func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc. // If the userinfo endpoint profileURL is defined, then there is a chance the userinfo // contents at the profileURL contains the email. // Make a query to the userinfo endpoint, and attempt to locate the email from there. - req, err := http.NewRequestWithContext(ctx, "GET", profileURL, nil) - if err != nil { - return nil, err - } - req.Header = getOIDCHeader(accessToken) - - respJSON, err := requests.Request(req) + respJSON, err := requests.New(profileURL). + WithContext(ctx). + WithHeaders(getOIDCHeader(accessToken)). + UnmarshalJSON() if err != nil { return nil, err } diff --git a/providers/provider_default.go b/providers/provider_default.go index 14cec9fe..f29e8c4a 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -7,13 +7,13 @@ import ( "errors" "fmt" "io/ioutil" - "net/http" "net/url" "time" "github.com/coreos/go-oidc" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" + "github.com/oauth2-proxy/oauth2-proxy/pkg/requests" ) var _ Provider = (*ProviderData)(nil) @@ -39,18 +39,16 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s params.Add("resource", p.ProtectedResource.String()) } - var req *http.Request - req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) - if err != nil { - return - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - var resp *http.Response - resp, err = http.DefaultClient.Do(req) + resp, err := requests.New(p.RedeemURL.String()). + WithContext(ctx). + WithMethod("POST"). + WithBody(bytes.NewBufferString(params.Encode())). + SetHeader("Content-Type", "application/x-www-form-urlencoded"). + Do() if err != nil { return nil, err } + var body []byte body, err = ioutil.ReadAll(resp.Body) resp.Body.Close() From 028a0ed62e8c039caf5fae8a642468bb42b7eff7 Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Sat, 4 Jul 2020 06:29:11 +0100 Subject: [PATCH 4/8] Remove old requests code --- pkg/requests/requests.go | 74 ------------------ pkg/requests/requests_test.go | 136 ---------------------------------- 2 files changed, 210 deletions(-) delete mode 100644 pkg/requests/requests.go delete mode 100644 pkg/requests/requests_test.go diff --git a/pkg/requests/requests.go b/pkg/requests/requests.go deleted file mode 100644 index 64cacaa9..00000000 --- a/pkg/requests/requests.go +++ /dev/null @@ -1,74 +0,0 @@ -package requests - -import ( - "context" - "encoding/json" - "fmt" - "io/ioutil" - "net/http" - - "github.com/bitly/go-simplejson" - "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" -) - -// Request parses the request body into a simplejson.Json object -func Request(req *http.Request) (*simplejson.Json, error) { - resp, err := http.DefaultClient.Do(req) - if err != nil { - logger.Printf("%s %s %s", req.Method, req.URL, err) - return nil, err - } - body, err := ioutil.ReadAll(resp.Body) - if body != nil { - defer resp.Body.Close() - } - - logger.Printf("%d %s %s %s", resp.StatusCode, req.Method, req.URL, body) - - if err != nil { - return nil, fmt.Errorf("problem reading http request body: %w", err) - } - - if resp.StatusCode != 200 { - return nil, fmt.Errorf("got %d %s", resp.StatusCode, body) - } - - data, err := simplejson.NewJson(body) - if err != nil { - return nil, fmt.Errorf("error unmarshalling json: %w", err) - } - return data, nil -} - -// RequestJSON parses the request body into the given interface -func RequestJSON(req *http.Request, v interface{}) error { - resp, err := http.DefaultClient.Do(req) - if err != nil { - logger.Printf("%s %s %s", req.Method, req.URL, err) - return err - } - body, err := ioutil.ReadAll(resp.Body) - if body != nil { - defer resp.Body.Close() - } - - logger.Printf("%d %s %s %s", resp.StatusCode, req.Method, req.URL, body) - if err != nil { - return fmt.Errorf("error reading body from http response: %w", err) - } - if resp.StatusCode != 200 { - return fmt.Errorf("got %d %s", resp.StatusCode, body) - } - return json.Unmarshal(body, v) -} - -// RequestUnparsedResponse performs a GET and returns the raw response object -func RequestUnparsedResponse(ctx context.Context, url string, header http.Header) (resp *http.Response, err error) { - req, err := http.NewRequestWithContext(ctx, "GET", url, nil) - if err != nil { - return nil, fmt.Errorf("error performing get request: %w", err) - } - req.Header = header - - return http.DefaultClient.Do(req) -} diff --git a/pkg/requests/requests_test.go b/pkg/requests/requests_test.go deleted file mode 100644 index 0c3e4152..00000000 --- a/pkg/requests/requests_test.go +++ /dev/null @@ -1,136 +0,0 @@ -package requests - -import ( - "context" - "io/ioutil" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/bitly/go-simplejson" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func testBackend(t *testing.T, responseCode int, payload string) *httptest.Server { - return httptest.NewServer(http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(responseCode) - _, err := w.Write([]byte(payload)) - require.NoError(t, err) - })) -} - -func TestRequest(t *testing.T) { - backend := testBackend(t, 200, "{\"foo\": \"bar\"}") - defer backend.Close() - - req, _ := http.NewRequest("GET", backend.URL, nil) - response, err := Request(req) - assert.Equal(t, nil, err) - result, err := response.Get("foo").String() - assert.Equal(t, nil, err) - assert.Equal(t, "bar", result) -} - -func TestRequestFailure(t *testing.T) { - // Create a backend to generate a test URL, then close it to cause a - // connection error. - backend := testBackend(t, 200, "{\"foo\": \"bar\"}") - backend.Close() - - req, err := http.NewRequest("GET", backend.URL, nil) - assert.Equal(t, nil, err) - resp, err := Request(req) - assert.Equal(t, (*simplejson.Json)(nil), resp) - assert.NotEqual(t, nil, err) - if !strings.Contains(err.Error(), "refused") { - t.Error("expected error when a connection fails: ", err) - } -} - -func TestHttpErrorCode(t *testing.T) { - backend := testBackend(t, 404, "{\"foo\": \"bar\"}") - defer backend.Close() - - req, err := http.NewRequest("GET", backend.URL, nil) - assert.Equal(t, nil, err) - resp, err := Request(req) - assert.Equal(t, (*simplejson.Json)(nil), resp) - assert.NotEqual(t, nil, err) -} - -func TestJsonParsingError(t *testing.T) { - backend := testBackend(t, 200, "not well-formed JSON") - defer backend.Close() - - req, err := http.NewRequest("GET", backend.URL, nil) - assert.Equal(t, nil, err) - resp, err := Request(req) - assert.Equal(t, (*simplejson.Json)(nil), resp) - assert.NotEqual(t, nil, err) -} - -// Parsing a URL practically never fails, so we won't cover that test case. -func TestRequestUnparsedResponseUsingAccessTokenParameter(t *testing.T) { - backend := httptest.NewServer(http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - token := r.FormValue("access_token") - if r.URL.Path == "/" && token == "my_token" { - w.WriteHeader(200) - _, err := w.Write([]byte("some payload")) - require.NoError(t, err) - } else { - w.WriteHeader(403) - } - })) - defer backend.Close() - - response, err := RequestUnparsedResponse( - context.Background(), backend.URL+"?access_token=my_token", nil) - assert.Equal(t, nil, err) - defer response.Body.Close() - - assert.Equal(t, 200, response.StatusCode) - body, err := ioutil.ReadAll(response.Body) - assert.Equal(t, nil, err) - assert.Equal(t, "some payload", string(body)) -} - -func TestRequestUnparsedResponseUsingAccessTokenParameterFailedResponse(t *testing.T) { - backend := testBackend(t, 200, "some payload") - // Close the backend now to force a request failure. - backend.Close() - - response, err := RequestUnparsedResponse( - context.Background(), backend.URL+"?access_token=my_token", nil) - assert.NotEqual(t, nil, err) - assert.Equal(t, (*http.Response)(nil), response) -} - -func TestRequestUnparsedResponseUsingHeaders(t *testing.T) { - backend := httptest.NewServer(http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/" && r.Header["Auth"][0] == "my_token" { - w.WriteHeader(200) - _, err := w.Write([]byte("some payload")) - require.NoError(t, err) - } else { - w.WriteHeader(403) - } - })) - defer backend.Close() - - headers := make(http.Header) - headers.Set("Auth", "my_token") - response, err := RequestUnparsedResponse(context.Background(), backend.URL, headers) - assert.Equal(t, nil, err) - defer response.Body.Close() - - assert.Equal(t, 200, response.StatusCode) - body, err := ioutil.ReadAll(response.Body) - assert.Equal(t, nil, err) - - assert.Equal(t, "some payload", string(body)) -} From 02410d3919f35e9968c94484cb752dae9a9a6357 Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Sun, 5 Jul 2020 09:52:36 +0100 Subject: [PATCH 5/8] Update changelog to add request builder entry --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e2da49ce..bca36842 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ ## Changes since v6.0.0 +- [#660](https://github.com/oauth2-proxy/oauth2-proxy/pull/660) Use builder pattern to simplify requests to external endpoints (@JoelSpeed) - [#591](https://github.com/oauth2-proxy/oauth2-proxy/pull/591) Introduce upstream package with new reverse proxy implementation (@JoelSpeed) - [#576](https://github.com/oauth2-proxy/oauth2-proxy/pull/576) Separate Cookie validation out of main options validation (@JoelSpeed) - [#656](https://github.com/oauth2-proxy/oauth2-proxy/pull/656) Split long session cookies more precisely (@NickMeves) From fbf406324523601e0f233e9ac75e6f1c5250ae89 Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Mon, 6 Jul 2020 17:32:15 +0100 Subject: [PATCH 6/8] Switch Builder.Do() to return a Result --- pkg/requests/builder.go | 101 ++++++++--------------------------- pkg/requests/builder_test.go | 42 ++++++--------- pkg/requests/result.go | 98 +++++++++++++++++++++++++++++++++ 3 files changed, 138 insertions(+), 103 deletions(-) create mode 100644 pkg/requests/result.go diff --git a/pkg/requests/builder.go b/pkg/requests/builder.go index 2e1a358e..95d88101 100644 --- a/pkg/requests/builder.go +++ b/pkg/requests/builder.go @@ -2,27 +2,23 @@ package requests import ( "context" - "encoding/json" "fmt" "io" "io/ioutil" "net/http" - - "github.com/bitly/go-simplejson" ) -// Builder allows users to construct a request and then either get the requests -// response via Do(), parse the response into a simplejson.Json via JSON(), -// or to parse the json response into an object via UnmarshalInto(). +// Builder allows users to construct a request and then execute the +// request via Do(). +// Do returns a Result which allows the user to get the body, +// unmarshal the body into an interface, or into a simplejson.Json. type Builder interface { WithContext(context.Context) Builder WithBody(io.Reader) Builder WithMethod(string) Builder WithHeaders(http.Header) Builder SetHeader(key, value string) Builder - Do() (*http.Response, error) - UnmarshalInto(interface{}) error - UnmarshalJSON() (*simplejson.Json, error) + Do() Result } type builder struct { @@ -31,7 +27,7 @@ type builder struct { endpoint string body io.Reader header http.Header - response *http.Response + result *result } // New provides a new Builder for the given endpoint. @@ -80,10 +76,10 @@ func (r *builder) SetHeader(key, value string) Builder { // Do performs the request and returns the response in its raw form. // If the request has already been performed, returns the previous result. // This will not allow you to repeat a request. -func (r *builder) Do() (*http.Response, error) { - if r.response != nil { +func (r *builder) Do() Result { + if r.result != nil { // Request has already been done - return r.response, nil + return r.result } // Must provide a non-nil context to NewRequestWithContext @@ -91,83 +87,32 @@ func (r *builder) Do() (*http.Response, error) { r.context = context.Background() } + return r.do() +} + +// do creates the request, executes it with the default client and extracts the +// the body into the response +func (r *builder) do() Result { req, err := http.NewRequestWithContext(r.context, r.method, r.endpoint, r.body) if err != nil { - return nil, fmt.Errorf("error creating request: %v", err) + r.result = &result{err: fmt.Errorf("error creating request: %v", err)} + return r.result } req.Header = r.header resp, err := http.DefaultClient.Do(req) if err != nil { - return nil, fmt.Errorf("error performing request: %v", err) + r.result = &result{err: fmt.Errorf("error performing request: %v", err)} + return r.result } - r.response = resp - return resp, nil -} - -// UnmarshalInto performs the request and attempts to unmarshal the response into the -// the given interface. The response body is assumed to be JSON. -// The response must have a 200 status otherwise an error will be returned. -func (r *builder) UnmarshalInto(into interface{}) error { - resp, err := r.Do() - if err != nil { - return err - } - - return UnmarshalInto(resp, into) -} - -// UnmarshalJSON performs the request and attempts to unmarshal the response into a -// simplejson.Json. The response body is assume to be JSON. -// The response must have a 200 status otherwise an error will be returned. -func (r *builder) UnmarshalJSON() (*simplejson.Json, error) { - resp, err := r.Do() - if err != nil { - return nil, err - } - - body, err := getResponseBody(resp) - if err != nil { - return nil, err - } - - data, err := simplejson.NewJson(body) - if err != nil { - return nil, fmt.Errorf("error reading json: %v", err) - } - return data, nil -} - -// UnmarshalInto attempts to unmarshal the response into the the given interface. -// The response body is assumed to be JSON. -// The response must have a 200 status otherwise an error will be returned. -func UnmarshalInto(resp *http.Response, into interface{}) error { - body, err := getResponseBody(resp) - if err != nil { - return err - } - - if err := json.Unmarshal(body, into); err != nil { - return fmt.Errorf("error unmarshalling body: %v", err) - } - - return nil -} - -// getResponseBody extracts the response body, but will only return the body -// if the response was successful. -func getResponseBody(resp *http.Response) ([]byte, error) { defer resp.Body.Close() body, err := ioutil.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("error reading response body: %v", err) + r.result = &result{err: fmt.Errorf("error reading response body: %v", err)} + return r.result } - // Only unmarshal body if the response was successful - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("unexpected status \"%d\": %s", resp.StatusCode, body) - } - - return body, nil + r.result = &result{response: resp, body: body} + return r.result } diff --git a/pkg/requests/builder_test.go b/pkg/requests/builder_test.go index b859fa23..0c0f0d03 100644 --- a/pkg/requests/builder_test.go +++ b/pkg/requests/builder_test.go @@ -6,7 +6,6 @@ import ( "encoding/base64" "encoding/json" "fmt" - "io/ioutil" "net/http" "github.com/bitly/go-simplejson" @@ -215,8 +214,8 @@ var _ = Describe("Builder suite", func() { Context("if the request has been completed and then modified", func() { BeforeEach(func() { - _, err := b.Do() - Expect(err).ToNot(HaveOccurred()) + result := b.Do() + Expect(result.Error()).ToNot(HaveOccurred()) b.WithMethod("POST") }) @@ -250,25 +249,20 @@ var _ = Describe("Builder suite", func() { func assertSuccessfulRequest(builder func() Builder, expectedRequest testHTTPRequest) { Context("Do", func() { - var resp *http.Response + var result Result BeforeEach(func() { - var err error - resp, err = builder().Do() - Expect(err).ToNot(HaveOccurred()) + result = builder().Do() + Expect(result.Error()).ToNot(HaveOccurred()) }) It("returns a successful status", func() { - Expect(resp.StatusCode).To(Equal(http.StatusOK)) + Expect(result.StatusCode()).To(Equal(http.StatusOK)) }) It("made the expected request", func() { - body, err := ioutil.ReadAll(resp.Body) - Expect(err).ToNot(HaveOccurred()) - resp.Body.Close() - actualRequest := testHTTPRequest{} - Expect(json.Unmarshal(body, &actualRequest)).To(Succeed()) + Expect(json.Unmarshal(result.Body(), &actualRequest)).To(Succeed()) Expect(actualRequest).To(Equal(expectedRequest)) }) @@ -278,7 +272,7 @@ func assertSuccessfulRequest(builder func() Builder, expectedRequest testHTTPReq var actualRequest testHTTPRequest BeforeEach(func() { - Expect(builder().UnmarshalInto(&actualRequest)).To(Succeed()) + Expect(builder().Do().UnmarshalInto(&actualRequest)).To(Succeed()) }) It("made the expected request", func() { @@ -291,7 +285,7 @@ func assertSuccessfulRequest(builder func() Builder, expectedRequest testHTTPReq BeforeEach(func() { var err error - response, err = builder().UnmarshalJSON() + response, err = builder().Do().UnmarshalJSON() Expect(err).ToNot(HaveOccurred()) }) @@ -328,16 +322,15 @@ func assertSuccessfulRequest(builder func() Builder, expectedRequest testHTTPReq func assertRequestError(builder func() Builder, errorMessage string) { Context("Do", func() { It("returns an error", func() { - resp, err := builder().Do() - Expect(err).To(MatchError(ContainSubstring(errorMessage))) - Expect(resp).To(BeNil()) + result := builder().Do() + Expect(result.Error()).To(MatchError(ContainSubstring(errorMessage))) }) }) Context("UnmarshalInto", func() { It("returns an error", func() { var actualRequest testHTTPRequest - err := builder().UnmarshalInto(&actualRequest) + err := builder().Do().UnmarshalInto(&actualRequest) Expect(err).To(MatchError(ContainSubstring(errorMessage))) // Should be empty @@ -347,7 +340,7 @@ func assertRequestError(builder func() Builder, errorMessage string) { Context("UnmarshalJSON", func() { It("returns an error", func() { - resp, err := builder().UnmarshalJSON() + resp, err := builder().Do().UnmarshalJSON() Expect(err).To(MatchError(ContainSubstring(errorMessage))) Expect(resp).To(BeNil()) }) @@ -357,16 +350,15 @@ func assertRequestError(builder func() Builder, errorMessage string) { func assertJSONError(builder func() Builder, errorMessage string) { Context("Do", func() { It("does not return an error", func() { - resp, err := builder().Do() - Expect(err).To(BeNil()) - Expect(resp).ToNot(BeNil()) + result := builder().Do() + Expect(result.Error()).To(BeNil()) }) }) Context("UnmarshalInto", func() { It("returns an error", func() { var actualRequest testHTTPRequest - err := builder().UnmarshalInto(&actualRequest) + err := builder().Do().UnmarshalInto(&actualRequest) Expect(err).To(MatchError(ContainSubstring(errorMessage))) // Should be empty @@ -376,7 +368,7 @@ func assertJSONError(builder func() Builder, errorMessage string) { Context("UnmarshalJSON", func() { It("returns an error", func() { - resp, err := builder().UnmarshalJSON() + resp, err := builder().Do().UnmarshalJSON() Expect(err).To(MatchError(ContainSubstring(errorMessage))) Expect(resp).To(BeNil()) }) diff --git a/pkg/requests/result.go b/pkg/requests/result.go new file mode 100644 index 00000000..2574aad5 --- /dev/null +++ b/pkg/requests/result.go @@ -0,0 +1,98 @@ +package requests + +import ( + "encoding/json" + "fmt" + "net/http" + + "github.com/bitly/go-simplejson" +) + +// Result is the result of a request created by a Builder +type Result interface { + Error() error + StatusCode() int + Headers() http.Header + Body() []byte + UnmarshalInto(interface{}) error + UnmarshalJSON() (*simplejson.Json, error) +} + +type result struct { + err error + response *http.Response + body []byte +} + +// Error returns an error from the result if present +func (r *result) Error() error { + return r.err +} + +// StatusCode returns the response's status code +func (r *result) StatusCode() int { + if r.response != nil { + return r.response.StatusCode + } + return 0 +} + +// Headers returns the response's headers +func (r *result) Headers() http.Header { + if r.response != nil { + return r.response.Header + } + return nil +} + +// Body returns the response's body +func (r *result) Body() []byte { + return r.body +} + +// UnmarshalInto attempts to unmarshal the response into the the given interface. +// The response body is assumed to be JSON. +// The response must have a 200 status otherwise an error will be returned. +func (r *result) UnmarshalInto(into interface{}) error { + body, err := r.getBodyForUnmarshal() + if err != nil { + return err + } + + if err := json.Unmarshal(body, into); err != nil { + return fmt.Errorf("error unmarshalling body: %v", err) + } + + return nil +} + +// UnmarshalJSON performs the request and attempts to unmarshal the response into a +// simplejson.Json. The response body is assume to be JSON. +// The response must have a 200 status otherwise an error will be returned. +func (r *result) UnmarshalJSON() (*simplejson.Json, error) { + body, err := r.getBodyForUnmarshal() + if err != nil { + return nil, err + } + + data, err := simplejson.NewJson(body) + if err != nil { + return nil, fmt.Errorf("error reading json: %v", err) + } + return data, nil +} + +// getBodyForUnmarshal returns the body if there wasn't an error and the status +// code was 200. +func (r *result) getBodyForUnmarshal() ([]byte, error) { + if r.Error() != nil { + return nil, r.Error() + } + + // Only unmarshal body if the response was successful + if r.StatusCode() != http.StatusOK { + return nil, fmt.Errorf("unexpected status \"%d\": %s", r.StatusCode(), r.Body()) + } + + return r.Body(), nil +} From d0b6c0496084d10bf7a667937598f1db8f2405e4 Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Mon, 6 Jul 2020 18:26:35 +0100 Subject: [PATCH 7/8] Add tests for request result --- pkg/requests/result_test.go | 326 ++++++++++++++++++++++++++++++++++++ 1 file changed, 326 insertions(+) create mode 100644 pkg/requests/result_test.go diff --git a/pkg/requests/result_test.go b/pkg/requests/result_test.go new file mode 100644 index 00000000..2d5c95ca --- /dev/null +++ b/pkg/requests/result_test.go @@ -0,0 +1,326 @@ +package requests + +import ( + "errors" + "net/http" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +var _ = Describe("Result suite", func() { + Context("with a result", func() { + type resultTableInput struct { + result Result + expectedError error + expectedStatusCode int + expectedHeaders http.Header + expectedBody []byte + } + + DescribeTable("accessors should return expected results", + func(in resultTableInput) { + if in.expectedError != nil { + Expect(in.result.Error()).To(MatchError(in.expectedError)) + } else { + Expect(in.result.Error()).To(BeNil()) + } + + Expect(in.result.StatusCode()).To(Equal(in.expectedStatusCode)) + Expect(in.result.Headers()).To(Equal(in.expectedHeaders)) + Expect(in.result.Body()).To(Equal(in.expectedBody)) + }, + Entry("with an empty result", resultTableInput{ + result: &result{}, + expectedError: nil, + expectedStatusCode: 0, + expectedHeaders: nil, + expectedBody: nil, + }), + Entry("with an error", resultTableInput{ + result: &result{ + err: errors.New("error"), + }, + expectedError: errors.New("error"), + expectedStatusCode: 0, + expectedHeaders: nil, + expectedBody: nil, + }), + Entry("with a response with no headers", resultTableInput{ + result: &result{ + response: &http.Response{ + StatusCode: http.StatusTeapot, + }, + }, + expectedError: nil, + expectedStatusCode: http.StatusTeapot, + expectedHeaders: nil, + expectedBody: nil, + }), + Entry("with a response with no status code", resultTableInput{ + result: &result{ + response: &http.Response{ + Header: http.Header{ + "foo": []string{"bar"}, + }, + }, + }, + expectedError: nil, + expectedStatusCode: 0, + expectedHeaders: http.Header{ + "foo": []string{"bar"}, + }, + expectedBody: nil, + }), + Entry("with a response with a body", resultTableInput{ + result: &result{ + body: []byte("some body"), + }, + expectedError: nil, + expectedStatusCode: 0, + expectedHeaders: nil, + expectedBody: []byte("some body"), + }), + Entry("with all fields", resultTableInput{ + result: &result{ + err: errors.New("some error"), + response: &http.Response{ + StatusCode: http.StatusFound, + Header: http.Header{ + "header": []string{"value"}, + }, + }, + body: []byte("a body"), + }, + expectedError: errors.New("some error"), + expectedStatusCode: http.StatusFound, + expectedHeaders: http.Header{ + "header": []string{"value"}, + }, + expectedBody: []byte("a body"), + }), + ) + }) + + Context("UnmarshalInto", func() { + type testStruct struct { + A string `json:"a"` + B int `json:"b"` + } + + type unmarshalIntoTableInput struct { + result Result + expectedErr error + expectedOutput *testStruct + } + + DescribeTable("with a result", + func(in unmarshalIntoTableInput) { + input := &testStruct{} + err := in.result.UnmarshalInto(input) + if in.expectedErr != nil { + Expect(err).To(MatchError(in.expectedErr)) + } else { + Expect(err).ToNot(HaveOccurred()) + } + Expect(input).To(Equal(in.expectedOutput)) + }, + Entry("with an error", unmarshalIntoTableInput{ + result: &result{ + err: errors.New("got an error"), + response: &http.Response{ + StatusCode: http.StatusOK, + }, + body: []byte("{\"a\": \"foo\"}"), + }, + expectedErr: errors.New("got an error"), + expectedOutput: &testStruct{}, + }), + Entry("with a 409 status code", unmarshalIntoTableInput{ + result: &result{ + err: nil, + response: &http.Response{ + StatusCode: http.StatusConflict, + }, + body: []byte("{\"a\": \"foo\"}"), + }, + expectedErr: errors.New("unexpected status \"409\": {\"a\": \"foo\"}"), + expectedOutput: &testStruct{}, + }), + Entry("when the response has a valid json response", unmarshalIntoTableInput{ + result: &result{ + err: nil, + response: &http.Response{ + StatusCode: http.StatusOK, + }, + body: []byte("{\"a\": \"foo\", \"b\": 1}"), + }, + expectedErr: nil, + expectedOutput: &testStruct{A: "foo", B: 1}, + }), + Entry("when the response body is empty", unmarshalIntoTableInput{ + result: &result{ + err: nil, + response: &http.Response{ + StatusCode: http.StatusOK, + }, + body: []byte(""), + }, + expectedErr: errors.New("error unmarshalling body: unexpected end of JSON input"), + expectedOutput: &testStruct{}, + }), + Entry("when the response body is not json", unmarshalIntoTableInput{ + result: &result{ + err: nil, + response: &http.Response{ + StatusCode: http.StatusOK, + }, + body: []byte("not json"), + }, + expectedErr: errors.New("error unmarshalling body: invalid character 'o' in literal null (expecting 'u')"), + expectedOutput: &testStruct{}, + }), + ) + }) + + Context("UnmarshalJSON", func() { + type testStruct struct { + A string `json:"a"` + B int `json:"b"` + } + + type unmarshalJSONTableInput struct { + result Result + expectedErr error + expectedOutput *testStruct + } + + DescribeTable("with a result", + func(in unmarshalJSONTableInput) { + j, err := in.result.UnmarshalJSON() + if in.expectedErr != nil { + Expect(err).To(MatchError(in.expectedErr)) + Expect(j).To(BeNil()) + return + } + + // No error so j should not be nil + Expect(err).ToNot(HaveOccurred()) + + input := &testStruct{ + A: j.Get("a").MustString(), + B: j.Get("b").MustInt(), + } + Expect(input).To(Equal(in.expectedOutput)) + }, + Entry("with an error", unmarshalJSONTableInput{ + result: &result{ + err: errors.New("got an error"), + response: &http.Response{ + StatusCode: http.StatusOK, + }, + body: []byte("{\"a\": \"foo\"}"), + }, + expectedErr: errors.New("got an error"), + expectedOutput: &testStruct{}, + }), + Entry("with a 409 status code", unmarshalJSONTableInput{ + result: &result{ + err: nil, + response: &http.Response{ + StatusCode: http.StatusConflict, + }, + body: []byte("{\"a\": \"foo\"}"), + }, + expectedErr: errors.New("unexpected status \"409\": {\"a\": \"foo\"}"), + expectedOutput: &testStruct{}, + }), + Entry("when the response has a valid json response", unmarshalJSONTableInput{ + result: &result{ + err: nil, + response: &http.Response{ + StatusCode: http.StatusOK, + }, + body: []byte("{\"a\": \"foo\", \"b\": 1}"), + }, + expectedErr: nil, + expectedOutput: &testStruct{A: "foo", B: 1}, + }), + Entry("when the response body is empty", unmarshalJSONTableInput{ + result: &result{ + err: nil, + response: &http.Response{ + StatusCode: http.StatusOK, + }, + body: []byte(""), + }, + expectedErr: errors.New("error reading json: EOF"), + expectedOutput: &testStruct{}, + }), + Entry("when the response body is not json", unmarshalJSONTableInput{ + result: &result{ + err: nil, + response: &http.Response{ + StatusCode: http.StatusOK, + }, + body: []byte("not json"), + }, + expectedErr: errors.New("error reading json: invalid character 'o' in literal null (expecting 'u')"), + expectedOutput: &testStruct{}, + }), + ) + }) + + Context("getBodyForUnmarshal", func() { + type getBodyForUnmarshalTableInput struct { + result *result + expectedErr error + expectedBody []byte + } + + DescribeTable("when getting the body", func(in getBodyForUnmarshalTableInput) { + body, err := in.result.getBodyForUnmarshal() + if in.expectedErr != nil { + Expect(err).To(MatchError(in.expectedErr)) + } else { + Expect(err).ToNot(HaveOccurred()) + } + Expect(body).To(Equal(in.expectedBody)) + }, + Entry("when the result has an error", getBodyForUnmarshalTableInput{ + result: &result{ + err: errors.New("got an error"), + response: &http.Response{ + StatusCode: http.StatusOK, + }, + body: []byte("body"), + }, + expectedErr: errors.New("got an error"), + expectedBody: nil, + }), + Entry("when the response has a 409 status code", getBodyForUnmarshalTableInput{ + result: &result{ + err: nil, + response: &http.Response{ + StatusCode: http.StatusConflict, + }, + body: []byte("body"), + }, + expectedErr: errors.New("unexpected status \"409\": body"), + expectedBody: nil, + }), + Entry("when the response has a 200 status code", getBodyForUnmarshalTableInput{ + result: &result{ + err: nil, + response: &http.Response{ + StatusCode: http.StatusOK, + }, + body: []byte("body"), + }, + expectedErr: nil, + expectedBody: []byte("body"), + }), + ) + }) +}) From de9e65a63adf43778a5478c9e331a8fcd71b7f56 Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Mon, 6 Jul 2020 17:42:26 +0100 Subject: [PATCH 8/8] Migrate all requests to result pattern --- pkg/validation/options.go | 5 ++--- providers/azure.go | 2 ++ providers/bitbucket.go | 3 +++ providers/digitalocean.go | 1 + providers/facebook.go | 1 + providers/github.go | 34 ++++++++++++++++------------------ providers/gitlab.go | 1 + providers/google.go | 2 ++ providers/internal_util.go | 15 ++++++--------- providers/keycloak.go | 1 + providers/linkedin.go | 1 + providers/logingov.go | 2 ++ providers/nextcloud.go | 1 + providers/oidc.go | 1 + providers/provider_default.go | 26 ++++++-------------------- 15 files changed, 46 insertions(+), 50 deletions(-) diff --git a/pkg/validation/options.go b/pkg/validation/options.go index ae2ed065..301c5e90 100644 --- a/pkg/validation/options.go +++ b/pkg/validation/options.go @@ -86,6 +86,7 @@ func Validate(o *options.Options) error { requestURL := strings.TrimSuffix(o.OIDCIssuerURL, "/") + "/.well-known/openid-configuration" body, err := requests.New(requestURL). WithContext(ctx). + Do(). UnmarshalJSON() if err != nil { logger.Printf("error: failed to discover OIDC configuration: %v", err) @@ -384,11 +385,9 @@ func newVerifierFromJwtIssuer(jwtIssuer jwtIssuer) (*oidc.IDTokenVerifier, error if err != nil { // Try as JWKS URI jwksURI := strings.TrimSuffix(jwtIssuer.issuerURI, "/") + "/.well-known/jwks.json" - resp, err := requests.New(jwksURI).Do() - if err != nil { + if err := requests.New(jwksURI).Do().Error(); err != nil { return nil, err } - resp.Body.Close() verifier = oidc.NewVerifier(jwtIssuer.issuerURI, oidc.NewRemoteKeySet(context.Background(), jwksURI), config) } else { diff --git a/providers/azure.go b/providers/azure.go index 5e8df0a0..b38c1cc7 100644 --- a/providers/azure.go +++ b/providers/azure.go @@ -101,6 +101,7 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s WithMethod("POST"). WithBody(bytes.NewBufferString(params.Encode())). SetHeader("Content-Type", "application/x-www-form-urlencoded"). + Do(). UnmarshalInto(&jsonResponse) if err != nil { return nil, err @@ -153,6 +154,7 @@ func (p *AzureProvider) GetEmailAddress(ctx context.Context, s *sessions.Session json, err := requests.New(p.ProfileURL.String()). WithContext(ctx). WithHeaders(getAzureHeader(s.AccessToken)). + Do(). UnmarshalJSON() if err != nil { return "", err diff --git a/providers/bitbucket.go b/providers/bitbucket.go index 726d3562..ffc52c79 100644 --- a/providers/bitbucket.go +++ b/providers/bitbucket.go @@ -88,6 +88,7 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses requestURL := p.ValidateURL.String() + "?access_token=" + s.AccessToken err := requests.New(requestURL). WithContext(ctx). + Do(). UnmarshalInto(&emails) if err != nil { logger.Printf("failed making request: %v", err) @@ -103,6 +104,7 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses err := requests.New(requestURL). WithContext(ctx). + Do(). UnmarshalInto(&teams) if err != nil { logger.Printf("failed requesting teams membership: %v", err) @@ -132,6 +134,7 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses err := requests.New(requestURL). WithContext(ctx). + Do(). UnmarshalInto(&repositories) if err != nil { logger.Printf("failed checking repository access: %v", err) diff --git a/providers/digitalocean.go b/providers/digitalocean.go index 306646d5..27ac60d0 100644 --- a/providers/digitalocean.go +++ b/providers/digitalocean.go @@ -64,6 +64,7 @@ func (p *DigitalOceanProvider) GetEmailAddress(ctx context.Context, s *sessions. json, err := requests.New(p.ProfileURL.String()). WithContext(ctx). WithHeaders(getDigitalOceanHeader(s.AccessToken)). + Do(). UnmarshalJSON() if err != nil { return "", err diff --git a/providers/facebook.go b/providers/facebook.go index 81973416..d3d123f2 100644 --- a/providers/facebook.go +++ b/providers/facebook.go @@ -72,6 +72,7 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess err := requests.New(requestURL). WithContext(ctx). WithHeaders(getFacebookHeader(s.AccessToken)). + Do(). UnmarshalInto(&r) if err != nil { return "", err diff --git a/providers/github.go b/providers/github.go index e93a23e1..014ae3cb 100644 --- a/providers/github.go +++ b/providers/github.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "io/ioutil" "net/http" "net/url" "path" @@ -116,6 +115,7 @@ func (p *GitHubProvider) hasOrg(ctx context.Context, accessToken string) (bool, err := requests.New(endpoint.String()). WithContext(ctx). WithHeaders(getGitHubHeader(accessToken)). + Do(). UnmarshalInto(&op) if err != nil { return false, err @@ -179,12 +179,12 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string) // bodyclose cannot detect that the body is being closed later in requests.Into, // so have to skip the linting for the next line. // nolint:bodyclose - resp, err := requests.New(endpoint.String()). + result := requests.New(endpoint.String()). WithContext(ctx). WithHeaders(getGitHubHeader(accessToken)). Do() - if err != nil { - return false, err + if result.Error() != nil { + return false, result.Error() } if last == 0 { @@ -200,7 +200,7 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string) // link header at last page (doesn't exist last info) // ; rel="prev", ; rel="first" - link := resp.Header.Get("Link") + link := result.Headers().Get("Link") rep1 := regexp.MustCompile(`(?s).*\; rel="last".*`) i, converr := strconv.Atoi(rep1.ReplaceAllString(link, "$1")) @@ -211,7 +211,7 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string) } var tp teamsPage - if err := requests.UnmarshalInto(resp, &tp); err != nil { + if err := result.UnmarshalInto(&tp); err != nil { return false, err } if len(tp) == 0 { @@ -282,6 +282,7 @@ func (p *GitHubProvider) hasRepo(ctx context.Context, accessToken string) (bool, err := requests.New(endpoint.String()). WithContext(ctx). WithHeaders(getGitHubHeader(accessToken)). + Do(). UnmarshalInto(&repo) if err != nil { return false, err @@ -309,6 +310,7 @@ func (p *GitHubProvider) hasUser(ctx context.Context, accessToken string) (bool, err := requests.New(endpoint.String()). WithContext(ctx). WithHeaders(getGitHubHeader(accessToken)). + Do(). UnmarshalInto(&user) if err != nil { return false, err @@ -328,26 +330,20 @@ func (p *GitHubProvider) isCollaborator(ctx context.Context, username, accessTok Host: p.ValidateURL.Host, Path: path.Join(p.ValidateURL.Path, "/repos/", p.Repo, "/collaborators/", username), } - resp, err := requests.New(endpoint.String()). + result := requests.New(endpoint.String()). WithContext(ctx). WithHeaders(getGitHubHeader(accessToken)). Do() - if err != nil { - return false, err + if result.Error() != nil { + return false, result.Error() } - body, err := ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return false, err - } - - if resp.StatusCode != 204 { + if result.StatusCode() != 204 { return false, fmt.Errorf("got %d from %q %s", - resp.StatusCode, endpoint.String(), body) + result.StatusCode(), endpoint.String(), result.Body()) } - logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) + logger.Printf("got %d from %q %s", result.StatusCode(), endpoint.String(), result.Body()) return true, nil } @@ -401,6 +397,7 @@ func (p *GitHubProvider) GetEmailAddress(ctx context.Context, s *sessions.Sessio err := requests.New(endpoint.String()). WithContext(ctx). WithHeaders(getGitHubHeader(s.AccessToken)). + Do(). UnmarshalInto(&emails) if err != nil { return "", err @@ -435,6 +432,7 @@ func (p *GitHubProvider) GetUserName(ctx context.Context, s *sessions.SessionSta err := requests.New(endpoint.String()). WithContext(ctx). WithHeaders(getGitHubHeader(s.AccessToken)). + Do(). UnmarshalInto(&user) if err != nil { return "", err diff --git a/providers/gitlab.go b/providers/gitlab.go index 17c06f5e..8c1e1534 100644 --- a/providers/gitlab.go +++ b/providers/gitlab.go @@ -133,6 +133,7 @@ func (p *GitLabProvider) getUserInfo(ctx context.Context, s *sessions.SessionSta err := requests.New(userInfoURL.String()). WithContext(ctx). SetHeader("Authorization", "Bearer "+s.AccessToken). + Do(). UnmarshalInto(&userInfo) if err != nil { return nil, fmt.Errorf("error getting user info: %v", err) diff --git a/providers/google.go b/providers/google.go index fbed12f1..af2eebf3 100644 --- a/providers/google.go +++ b/providers/google.go @@ -129,6 +129,7 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) ( WithMethod("POST"). WithBody(bytes.NewBufferString(params.Encode())). SetHeader("Content-Type", "application/x-www-form-urlencoded"). + Do(). UnmarshalInto(&jsonResponse) if err != nil { return nil, err @@ -280,6 +281,7 @@ func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken st WithMethod("POST"). WithBody(bytes.NewBufferString(params.Encode())). SetHeader("Content-Type", "application/x-www-form-urlencoded"). + Do(). UnmarshalInto(&data) if err != nil { return "", "", 0, err diff --git a/providers/internal_util.go b/providers/internal_util.go index 361e27c7..42948408 100644 --- a/providers/internal_util.go +++ b/providers/internal_util.go @@ -2,7 +2,6 @@ package providers import ( "context" - "io/ioutil" "net/http" "net/url" @@ -57,23 +56,21 @@ func validateToken(ctx context.Context, p Provider, accessToken string, header h endpoint = endpoint + "?" + params.Encode() } - resp, err := requests.New(endpoint). + result := requests.New(endpoint). WithContext(ctx). WithHeaders(header). Do() - if err != nil { + if result.Error() != nil { logger.Printf("GET %s", stripToken(endpoint)) - logger.Printf("token validation request failed: %s", err) + logger.Printf("token validation request failed: %s", result.Error()) return false } - body, _ := ioutil.ReadAll(resp.Body) - resp.Body.Close() - logger.Printf("%d GET %s %s", resp.StatusCode, stripToken(endpoint), body) + logger.Printf("%d GET %s %s", result.StatusCode(), stripToken(endpoint), result.Body()) - if resp.StatusCode == 200 { + if result.StatusCode() == 200 { return true } - logger.Printf("token validation request failed: status %d - %s", resp.StatusCode, body) + logger.Printf("token validation request failed: status %d - %s", result.StatusCode(), result.Body()) return false } diff --git a/providers/keycloak.go b/providers/keycloak.go index 78206de8..77efc0c7 100644 --- a/providers/keycloak.go +++ b/providers/keycloak.go @@ -53,6 +53,7 @@ func (p *KeycloakProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess json, err := requests.New(p.ValidateURL.String()). WithContext(ctx). SetHeader("Authorization", "Bearer "+s.AccessToken). + Do(). UnmarshalJSON() if err != nil { logger.Printf("failed making request %s", err) diff --git a/providers/linkedin.go b/providers/linkedin.go index 8dcd3c9d..7328dbbb 100644 --- a/providers/linkedin.go +++ b/providers/linkedin.go @@ -63,6 +63,7 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess json, err := requests.New(requestURL). WithContext(ctx). WithHeaders(getLinkedInHeader(s.AccessToken)). + Do(). UnmarshalJSON() if err != nil { return "", err diff --git a/providers/logingov.go b/providers/logingov.go index 8846a8f2..79eb1827 100644 --- a/providers/logingov.go +++ b/providers/logingov.go @@ -141,6 +141,7 @@ func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint err := requests.New(userInfoEndpoint). WithContext(ctx). SetHeader("Authorization", "Bearer "+accessToken). + Do(). UnmarshalInto(&emailData) if err != nil { return "", err @@ -196,6 +197,7 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) WithMethod("POST"). WithBody(bytes.NewBufferString(params.Encode())). SetHeader("Content-Type", "application/x-www-form-urlencoded"). + Do(). UnmarshalInto(&jsonResponse) if err != nil { return nil, err diff --git a/providers/nextcloud.go b/providers/nextcloud.go index 844bbff3..b70fd07c 100644 --- a/providers/nextcloud.go +++ b/providers/nextcloud.go @@ -33,6 +33,7 @@ func (p *NextcloudProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses json, err := requests.New(p.ValidateURL.String()). WithContext(ctx). WithHeaders(getNextcloudHeader(s.AccessToken)). + Do(). UnmarshalJSON() if err != nil { return "", fmt.Errorf("error making request: %v", err) diff --git a/providers/oidc.go b/providers/oidc.go index 08e0c3e8..e456db76 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -259,6 +259,7 @@ func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc. respJSON, err := requests.New(profileURL). WithContext(ctx). WithHeaders(getOIDCHeader(accessToken)). + Do(). UnmarshalJSON() if err != nil { return nil, err diff --git a/providers/provider_default.go b/providers/provider_default.go index f29e8c4a..598b91e8 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -3,10 +3,8 @@ package providers import ( "bytes" "context" - "encoding/json" "errors" "fmt" - "io/ioutil" "net/url" "time" @@ -39,33 +37,21 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s params.Add("resource", p.ProtectedResource.String()) } - resp, err := requests.New(p.RedeemURL.String()). + result := requests.New(p.RedeemURL.String()). WithContext(ctx). WithMethod("POST"). WithBody(bytes.NewBufferString(params.Encode())). SetHeader("Content-Type", "application/x-www-form-urlencoded"). Do() - if err != nil { - return nil, err - } - - var body []byte - body, err = ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return - } - - if resp.StatusCode != 200 { - err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) - return + if result.Error() != nil { + return nil, result.Error() } // blindly try json and x-www-form-urlencoded var jsonResponse struct { AccessToken string `json:"access_token"` } - err = json.Unmarshal(body, &jsonResponse) + err = result.UnmarshalInto(&jsonResponse) if err == nil { s = &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, @@ -74,7 +60,7 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s } var v url.Values - v, err = url.ParseQuery(string(body)) + v, err = url.ParseQuery(string(result.Body())) if err != nil { return } @@ -82,7 +68,7 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s created := time.Now() s = &sessions.SessionState{AccessToken: a, CreatedAt: &created} } else { - err = fmt.Errorf("no access token found %s", body) + err = fmt.Errorf("no access token found %s", result.Body()) } return }