1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2024-11-28 09:08:44 +02:00

Migrate all requests to result pattern

This commit is contained in:
Joel Speed 2020-07-06 17:42:26 +01:00
parent d0b6c04960
commit de9e65a63a
No known key found for this signature in database
GPG Key ID: 6E80578D6751DEFB
15 changed files with 46 additions and 50 deletions

View File

@ -86,6 +86,7 @@ func Validate(o *options.Options) error {
requestURL := strings.TrimSuffix(o.OIDCIssuerURL, "/") + "/.well-known/openid-configuration"
body, err := requests.New(requestURL).
WithContext(ctx).
Do().
UnmarshalJSON()
if err != nil {
logger.Printf("error: failed to discover OIDC configuration: %v", err)
@ -384,11 +385,9 @@ func newVerifierFromJwtIssuer(jwtIssuer jwtIssuer) (*oidc.IDTokenVerifier, error
if err != nil {
// Try as JWKS URI
jwksURI := strings.TrimSuffix(jwtIssuer.issuerURI, "/") + "/.well-known/jwks.json"
resp, err := requests.New(jwksURI).Do()
if err != nil {
if err := requests.New(jwksURI).Do().Error(); err != nil {
return nil, err
}
resp.Body.Close()
verifier = oidc.NewVerifier(jwtIssuer.issuerURI, oidc.NewRemoteKeySet(context.Background(), jwksURI), config)
} else {

View File

@ -101,6 +101,7 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s
WithMethod("POST").
WithBody(bytes.NewBufferString(params.Encode())).
SetHeader("Content-Type", "application/x-www-form-urlencoded").
Do().
UnmarshalInto(&jsonResponse)
if err != nil {
return nil, err
@ -153,6 +154,7 @@ func (p *AzureProvider) GetEmailAddress(ctx context.Context, s *sessions.Session
json, err := requests.New(p.ProfileURL.String()).
WithContext(ctx).
WithHeaders(getAzureHeader(s.AccessToken)).
Do().
UnmarshalJSON()
if err != nil {
return "", err

View File

@ -88,6 +88,7 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses
requestURL := p.ValidateURL.String() + "?access_token=" + s.AccessToken
err := requests.New(requestURL).
WithContext(ctx).
Do().
UnmarshalInto(&emails)
if err != nil {
logger.Printf("failed making request: %v", err)
@ -103,6 +104,7 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses
err := requests.New(requestURL).
WithContext(ctx).
Do().
UnmarshalInto(&teams)
if err != nil {
logger.Printf("failed requesting teams membership: %v", err)
@ -132,6 +134,7 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses
err := requests.New(requestURL).
WithContext(ctx).
Do().
UnmarshalInto(&repositories)
if err != nil {
logger.Printf("failed checking repository access: %v", err)

View File

@ -64,6 +64,7 @@ func (p *DigitalOceanProvider) GetEmailAddress(ctx context.Context, s *sessions.
json, err := requests.New(p.ProfileURL.String()).
WithContext(ctx).
WithHeaders(getDigitalOceanHeader(s.AccessToken)).
Do().
UnmarshalJSON()
if err != nil {
return "", err

View File

@ -72,6 +72,7 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess
err := requests.New(requestURL).
WithContext(ctx).
WithHeaders(getFacebookHeader(s.AccessToken)).
Do().
UnmarshalInto(&r)
if err != nil {
return "", err

View File

@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"path"
@ -116,6 +115,7 @@ func (p *GitHubProvider) hasOrg(ctx context.Context, accessToken string) (bool,
err := requests.New(endpoint.String()).
WithContext(ctx).
WithHeaders(getGitHubHeader(accessToken)).
Do().
UnmarshalInto(&op)
if err != nil {
return false, err
@ -179,12 +179,12 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string)
// bodyclose cannot detect that the body is being closed later in requests.Into,
// so have to skip the linting for the next line.
// nolint:bodyclose
resp, err := requests.New(endpoint.String()).
result := requests.New(endpoint.String()).
WithContext(ctx).
WithHeaders(getGitHubHeader(accessToken)).
Do()
if err != nil {
return false, err
if result.Error() != nil {
return false, result.Error()
}
if last == 0 {
@ -200,7 +200,7 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string)
// link header at last page (doesn't exist last info)
// <https://api.github.com/user/teams?page=3&per_page=10>; rel="prev", <https://api.github.com/user/teams?page=1&per_page=10>; rel="first"
link := resp.Header.Get("Link")
link := result.Headers().Get("Link")
rep1 := regexp.MustCompile(`(?s).*\<https://api.github.com/user/teams\?page=(.)&per_page=[0-9]+\>; rel="last".*`)
i, converr := strconv.Atoi(rep1.ReplaceAllString(link, "$1"))
@ -211,7 +211,7 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string)
}
var tp teamsPage
if err := requests.UnmarshalInto(resp, &tp); err != nil {
if err := result.UnmarshalInto(&tp); err != nil {
return false, err
}
if len(tp) == 0 {
@ -282,6 +282,7 @@ func (p *GitHubProvider) hasRepo(ctx context.Context, accessToken string) (bool,
err := requests.New(endpoint.String()).
WithContext(ctx).
WithHeaders(getGitHubHeader(accessToken)).
Do().
UnmarshalInto(&repo)
if err != nil {
return false, err
@ -309,6 +310,7 @@ func (p *GitHubProvider) hasUser(ctx context.Context, accessToken string) (bool,
err := requests.New(endpoint.String()).
WithContext(ctx).
WithHeaders(getGitHubHeader(accessToken)).
Do().
UnmarshalInto(&user)
if err != nil {
return false, err
@ -328,26 +330,20 @@ func (p *GitHubProvider) isCollaborator(ctx context.Context, username, accessTok
Host: p.ValidateURL.Host,
Path: path.Join(p.ValidateURL.Path, "/repos/", p.Repo, "/collaborators/", username),
}
resp, err := requests.New(endpoint.String()).
result := requests.New(endpoint.String()).
WithContext(ctx).
WithHeaders(getGitHubHeader(accessToken)).
Do()
if err != nil {
return false, err
if result.Error() != nil {
return false, result.Error()
}
body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return false, err
}
if resp.StatusCode != 204 {
if result.StatusCode() != 204 {
return false, fmt.Errorf("got %d from %q %s",
resp.StatusCode, endpoint.String(), body)
result.StatusCode(), endpoint.String(), result.Body())
}
logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
logger.Printf("got %d from %q %s", result.StatusCode(), endpoint.String(), result.Body())
return true, nil
}
@ -401,6 +397,7 @@ func (p *GitHubProvider) GetEmailAddress(ctx context.Context, s *sessions.Sessio
err := requests.New(endpoint.String()).
WithContext(ctx).
WithHeaders(getGitHubHeader(s.AccessToken)).
Do().
UnmarshalInto(&emails)
if err != nil {
return "", err
@ -435,6 +432,7 @@ func (p *GitHubProvider) GetUserName(ctx context.Context, s *sessions.SessionSta
err := requests.New(endpoint.String()).
WithContext(ctx).
WithHeaders(getGitHubHeader(s.AccessToken)).
Do().
UnmarshalInto(&user)
if err != nil {
return "", err

View File

@ -133,6 +133,7 @@ func (p *GitLabProvider) getUserInfo(ctx context.Context, s *sessions.SessionSta
err := requests.New(userInfoURL.String()).
WithContext(ctx).
SetHeader("Authorization", "Bearer "+s.AccessToken).
Do().
UnmarshalInto(&userInfo)
if err != nil {
return nil, fmt.Errorf("error getting user info: %v", err)

View File

@ -129,6 +129,7 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (
WithMethod("POST").
WithBody(bytes.NewBufferString(params.Encode())).
SetHeader("Content-Type", "application/x-www-form-urlencoded").
Do().
UnmarshalInto(&jsonResponse)
if err != nil {
return nil, err
@ -280,6 +281,7 @@ func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken st
WithMethod("POST").
WithBody(bytes.NewBufferString(params.Encode())).
SetHeader("Content-Type", "application/x-www-form-urlencoded").
Do().
UnmarshalInto(&data)
if err != nil {
return "", "", 0, err

View File

@ -2,7 +2,6 @@ package providers
import (
"context"
"io/ioutil"
"net/http"
"net/url"
@ -57,23 +56,21 @@ func validateToken(ctx context.Context, p Provider, accessToken string, header h
endpoint = endpoint + "?" + params.Encode()
}
resp, err := requests.New(endpoint).
result := requests.New(endpoint).
WithContext(ctx).
WithHeaders(header).
Do()
if err != nil {
if result.Error() != nil {
logger.Printf("GET %s", stripToken(endpoint))
logger.Printf("token validation request failed: %s", err)
logger.Printf("token validation request failed: %s", result.Error())
return false
}
body, _ := ioutil.ReadAll(resp.Body)
resp.Body.Close()
logger.Printf("%d GET %s %s", resp.StatusCode, stripToken(endpoint), body)
logger.Printf("%d GET %s %s", result.StatusCode(), stripToken(endpoint), result.Body())
if resp.StatusCode == 200 {
if result.StatusCode() == 200 {
return true
}
logger.Printf("token validation request failed: status %d - %s", resp.StatusCode, body)
logger.Printf("token validation request failed: status %d - %s", result.StatusCode(), result.Body())
return false
}

View File

@ -53,6 +53,7 @@ func (p *KeycloakProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess
json, err := requests.New(p.ValidateURL.String()).
WithContext(ctx).
SetHeader("Authorization", "Bearer "+s.AccessToken).
Do().
UnmarshalJSON()
if err != nil {
logger.Printf("failed making request %s", err)

View File

@ -63,6 +63,7 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess
json, err := requests.New(requestURL).
WithContext(ctx).
WithHeaders(getLinkedInHeader(s.AccessToken)).
Do().
UnmarshalJSON()
if err != nil {
return "", err

View File

@ -141,6 +141,7 @@ func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint
err := requests.New(userInfoEndpoint).
WithContext(ctx).
SetHeader("Authorization", "Bearer "+accessToken).
Do().
UnmarshalInto(&emailData)
if err != nil {
return "", err
@ -196,6 +197,7 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
WithMethod("POST").
WithBody(bytes.NewBufferString(params.Encode())).
SetHeader("Content-Type", "application/x-www-form-urlencoded").
Do().
UnmarshalInto(&jsonResponse)
if err != nil {
return nil, err

View File

@ -33,6 +33,7 @@ func (p *NextcloudProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses
json, err := requests.New(p.ValidateURL.String()).
WithContext(ctx).
WithHeaders(getNextcloudHeader(s.AccessToken)).
Do().
UnmarshalJSON()
if err != nil {
return "", fmt.Errorf("error making request: %v", err)

View File

@ -259,6 +259,7 @@ func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc.
respJSON, err := requests.New(profileURL).
WithContext(ctx).
WithHeaders(getOIDCHeader(accessToken)).
Do().
UnmarshalJSON()
if err != nil {
return nil, err

View File

@ -3,10 +3,8 @@ package providers
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/url"
"time"
@ -39,33 +37,21 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s
params.Add("resource", p.ProtectedResource.String())
}
resp, err := requests.New(p.RedeemURL.String()).
result := requests.New(p.RedeemURL.String()).
WithContext(ctx).
WithMethod("POST").
WithBody(bytes.NewBufferString(params.Encode())).
SetHeader("Content-Type", "application/x-www-form-urlencoded").
Do()
if err != nil {
return nil, err
}
var body []byte
body, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return
}
if resp.StatusCode != 200 {
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
return
if result.Error() != nil {
return nil, result.Error()
}
// blindly try json and x-www-form-urlencoded
var jsonResponse struct {
AccessToken string `json:"access_token"`
}
err = json.Unmarshal(body, &jsonResponse)
err = result.UnmarshalInto(&jsonResponse)
if err == nil {
s = &sessions.SessionState{
AccessToken: jsonResponse.AccessToken,
@ -74,7 +60,7 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s
}
var v url.Values
v, err = url.ParseQuery(string(body))
v, err = url.ParseQuery(string(result.Body()))
if err != nil {
return
}
@ -82,7 +68,7 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s
created := time.Now()
s = &sessions.SessionState{AccessToken: a, CreatedAt: &created}
} else {
err = fmt.Errorf("no access token found %s", body)
err = fmt.Errorf("no access token found %s", result.Body())
}
return
}