1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-04-21 12:17:22 +02:00

Migrate all requests to result pattern

This commit is contained in:
Joel Speed 2020-07-06 17:42:26 +01:00
parent d0b6c04960
commit de9e65a63a
No known key found for this signature in database
GPG Key ID: 6E80578D6751DEFB
15 changed files with 46 additions and 50 deletions

View File

@ -86,6 +86,7 @@ func Validate(o *options.Options) error {
requestURL := strings.TrimSuffix(o.OIDCIssuerURL, "/") + "/.well-known/openid-configuration" requestURL := strings.TrimSuffix(o.OIDCIssuerURL, "/") + "/.well-known/openid-configuration"
body, err := requests.New(requestURL). body, err := requests.New(requestURL).
WithContext(ctx). WithContext(ctx).
Do().
UnmarshalJSON() UnmarshalJSON()
if err != nil { if err != nil {
logger.Printf("error: failed to discover OIDC configuration: %v", err) logger.Printf("error: failed to discover OIDC configuration: %v", err)
@ -384,11 +385,9 @@ func newVerifierFromJwtIssuer(jwtIssuer jwtIssuer) (*oidc.IDTokenVerifier, error
if err != nil { if err != nil {
// Try as JWKS URI // Try as JWKS URI
jwksURI := strings.TrimSuffix(jwtIssuer.issuerURI, "/") + "/.well-known/jwks.json" jwksURI := strings.TrimSuffix(jwtIssuer.issuerURI, "/") + "/.well-known/jwks.json"
resp, err := requests.New(jwksURI).Do() if err := requests.New(jwksURI).Do().Error(); err != nil {
if err != nil {
return nil, err return nil, err
} }
resp.Body.Close()
verifier = oidc.NewVerifier(jwtIssuer.issuerURI, oidc.NewRemoteKeySet(context.Background(), jwksURI), config) verifier = oidc.NewVerifier(jwtIssuer.issuerURI, oidc.NewRemoteKeySet(context.Background(), jwksURI), config)
} else { } else {

View File

@ -101,6 +101,7 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s
WithMethod("POST"). WithMethod("POST").
WithBody(bytes.NewBufferString(params.Encode())). WithBody(bytes.NewBufferString(params.Encode())).
SetHeader("Content-Type", "application/x-www-form-urlencoded"). SetHeader("Content-Type", "application/x-www-form-urlencoded").
Do().
UnmarshalInto(&jsonResponse) UnmarshalInto(&jsonResponse)
if err != nil { if err != nil {
return nil, err return nil, err
@ -153,6 +154,7 @@ func (p *AzureProvider) GetEmailAddress(ctx context.Context, s *sessions.Session
json, err := requests.New(p.ProfileURL.String()). json, err := requests.New(p.ProfileURL.String()).
WithContext(ctx). WithContext(ctx).
WithHeaders(getAzureHeader(s.AccessToken)). WithHeaders(getAzureHeader(s.AccessToken)).
Do().
UnmarshalJSON() UnmarshalJSON()
if err != nil { if err != nil {
return "", err return "", err

View File

@ -88,6 +88,7 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses
requestURL := p.ValidateURL.String() + "?access_token=" + s.AccessToken requestURL := p.ValidateURL.String() + "?access_token=" + s.AccessToken
err := requests.New(requestURL). err := requests.New(requestURL).
WithContext(ctx). WithContext(ctx).
Do().
UnmarshalInto(&emails) UnmarshalInto(&emails)
if err != nil { if err != nil {
logger.Printf("failed making request: %v", err) logger.Printf("failed making request: %v", err)
@ -103,6 +104,7 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses
err := requests.New(requestURL). err := requests.New(requestURL).
WithContext(ctx). WithContext(ctx).
Do().
UnmarshalInto(&teams) UnmarshalInto(&teams)
if err != nil { if err != nil {
logger.Printf("failed requesting teams membership: %v", err) logger.Printf("failed requesting teams membership: %v", err)
@ -132,6 +134,7 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses
err := requests.New(requestURL). err := requests.New(requestURL).
WithContext(ctx). WithContext(ctx).
Do().
UnmarshalInto(&repositories) UnmarshalInto(&repositories)
if err != nil { if err != nil {
logger.Printf("failed checking repository access: %v", err) logger.Printf("failed checking repository access: %v", err)

View File

@ -64,6 +64,7 @@ func (p *DigitalOceanProvider) GetEmailAddress(ctx context.Context, s *sessions.
json, err := requests.New(p.ProfileURL.String()). json, err := requests.New(p.ProfileURL.String()).
WithContext(ctx). WithContext(ctx).
WithHeaders(getDigitalOceanHeader(s.AccessToken)). WithHeaders(getDigitalOceanHeader(s.AccessToken)).
Do().
UnmarshalJSON() UnmarshalJSON()
if err != nil { if err != nil {
return "", err return "", err

View File

@ -72,6 +72,7 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess
err := requests.New(requestURL). err := requests.New(requestURL).
WithContext(ctx). WithContext(ctx).
WithHeaders(getFacebookHeader(s.AccessToken)). WithHeaders(getFacebookHeader(s.AccessToken)).
Do().
UnmarshalInto(&r) UnmarshalInto(&r)
if err != nil { if err != nil {
return "", err return "", err

View File

@ -4,7 +4,6 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"path" "path"
@ -116,6 +115,7 @@ func (p *GitHubProvider) hasOrg(ctx context.Context, accessToken string) (bool,
err := requests.New(endpoint.String()). err := requests.New(endpoint.String()).
WithContext(ctx). WithContext(ctx).
WithHeaders(getGitHubHeader(accessToken)). WithHeaders(getGitHubHeader(accessToken)).
Do().
UnmarshalInto(&op) UnmarshalInto(&op)
if err != nil { if err != nil {
return false, err return false, err
@ -179,12 +179,12 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string)
// bodyclose cannot detect that the body is being closed later in requests.Into, // bodyclose cannot detect that the body is being closed later in requests.Into,
// so have to skip the linting for the next line. // so have to skip the linting for the next line.
// nolint:bodyclose // nolint:bodyclose
resp, err := requests.New(endpoint.String()). result := requests.New(endpoint.String()).
WithContext(ctx). WithContext(ctx).
WithHeaders(getGitHubHeader(accessToken)). WithHeaders(getGitHubHeader(accessToken)).
Do() Do()
if err != nil { if result.Error() != nil {
return false, err return false, result.Error()
} }
if last == 0 { if last == 0 {
@ -200,7 +200,7 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string)
// link header at last page (doesn't exist last info) // link header at last page (doesn't exist last info)
// <https://api.github.com/user/teams?page=3&per_page=10>; rel="prev", <https://api.github.com/user/teams?page=1&per_page=10>; rel="first" // <https://api.github.com/user/teams?page=3&per_page=10>; rel="prev", <https://api.github.com/user/teams?page=1&per_page=10>; rel="first"
link := resp.Header.Get("Link") link := result.Headers().Get("Link")
rep1 := regexp.MustCompile(`(?s).*\<https://api.github.com/user/teams\?page=(.)&per_page=[0-9]+\>; rel="last".*`) rep1 := regexp.MustCompile(`(?s).*\<https://api.github.com/user/teams\?page=(.)&per_page=[0-9]+\>; rel="last".*`)
i, converr := strconv.Atoi(rep1.ReplaceAllString(link, "$1")) i, converr := strconv.Atoi(rep1.ReplaceAllString(link, "$1"))
@ -211,7 +211,7 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string)
} }
var tp teamsPage var tp teamsPage
if err := requests.UnmarshalInto(resp, &tp); err != nil { if err := result.UnmarshalInto(&tp); err != nil {
return false, err return false, err
} }
if len(tp) == 0 { if len(tp) == 0 {
@ -282,6 +282,7 @@ func (p *GitHubProvider) hasRepo(ctx context.Context, accessToken string) (bool,
err := requests.New(endpoint.String()). err := requests.New(endpoint.String()).
WithContext(ctx). WithContext(ctx).
WithHeaders(getGitHubHeader(accessToken)). WithHeaders(getGitHubHeader(accessToken)).
Do().
UnmarshalInto(&repo) UnmarshalInto(&repo)
if err != nil { if err != nil {
return false, err return false, err
@ -309,6 +310,7 @@ func (p *GitHubProvider) hasUser(ctx context.Context, accessToken string) (bool,
err := requests.New(endpoint.String()). err := requests.New(endpoint.String()).
WithContext(ctx). WithContext(ctx).
WithHeaders(getGitHubHeader(accessToken)). WithHeaders(getGitHubHeader(accessToken)).
Do().
UnmarshalInto(&user) UnmarshalInto(&user)
if err != nil { if err != nil {
return false, err return false, err
@ -328,26 +330,20 @@ func (p *GitHubProvider) isCollaborator(ctx context.Context, username, accessTok
Host: p.ValidateURL.Host, Host: p.ValidateURL.Host,
Path: path.Join(p.ValidateURL.Path, "/repos/", p.Repo, "/collaborators/", username), Path: path.Join(p.ValidateURL.Path, "/repos/", p.Repo, "/collaborators/", username),
} }
resp, err := requests.New(endpoint.String()). result := requests.New(endpoint.String()).
WithContext(ctx). WithContext(ctx).
WithHeaders(getGitHubHeader(accessToken)). WithHeaders(getGitHubHeader(accessToken)).
Do() Do()
if err != nil { if result.Error() != nil {
return false, err return false, result.Error()
} }
body, err := ioutil.ReadAll(resp.Body) if result.StatusCode() != 204 {
resp.Body.Close()
if err != nil {
return false, err
}
if resp.StatusCode != 204 {
return false, fmt.Errorf("got %d from %q %s", return false, fmt.Errorf("got %d from %q %s",
resp.StatusCode, endpoint.String(), body) result.StatusCode(), endpoint.String(), result.Body())
} }
logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) logger.Printf("got %d from %q %s", result.StatusCode(), endpoint.String(), result.Body())
return true, nil return true, nil
} }
@ -401,6 +397,7 @@ func (p *GitHubProvider) GetEmailAddress(ctx context.Context, s *sessions.Sessio
err := requests.New(endpoint.String()). err := requests.New(endpoint.String()).
WithContext(ctx). WithContext(ctx).
WithHeaders(getGitHubHeader(s.AccessToken)). WithHeaders(getGitHubHeader(s.AccessToken)).
Do().
UnmarshalInto(&emails) UnmarshalInto(&emails)
if err != nil { if err != nil {
return "", err return "", err
@ -435,6 +432,7 @@ func (p *GitHubProvider) GetUserName(ctx context.Context, s *sessions.SessionSta
err := requests.New(endpoint.String()). err := requests.New(endpoint.String()).
WithContext(ctx). WithContext(ctx).
WithHeaders(getGitHubHeader(s.AccessToken)). WithHeaders(getGitHubHeader(s.AccessToken)).
Do().
UnmarshalInto(&user) UnmarshalInto(&user)
if err != nil { if err != nil {
return "", err return "", err

View File

@ -133,6 +133,7 @@ func (p *GitLabProvider) getUserInfo(ctx context.Context, s *sessions.SessionSta
err := requests.New(userInfoURL.String()). err := requests.New(userInfoURL.String()).
WithContext(ctx). WithContext(ctx).
SetHeader("Authorization", "Bearer "+s.AccessToken). SetHeader("Authorization", "Bearer "+s.AccessToken).
Do().
UnmarshalInto(&userInfo) UnmarshalInto(&userInfo)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting user info: %v", err) return nil, fmt.Errorf("error getting user info: %v", err)

View File

@ -129,6 +129,7 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (
WithMethod("POST"). WithMethod("POST").
WithBody(bytes.NewBufferString(params.Encode())). WithBody(bytes.NewBufferString(params.Encode())).
SetHeader("Content-Type", "application/x-www-form-urlencoded"). SetHeader("Content-Type", "application/x-www-form-urlencoded").
Do().
UnmarshalInto(&jsonResponse) UnmarshalInto(&jsonResponse)
if err != nil { if err != nil {
return nil, err return nil, err
@ -280,6 +281,7 @@ func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken st
WithMethod("POST"). WithMethod("POST").
WithBody(bytes.NewBufferString(params.Encode())). WithBody(bytes.NewBufferString(params.Encode())).
SetHeader("Content-Type", "application/x-www-form-urlencoded"). SetHeader("Content-Type", "application/x-www-form-urlencoded").
Do().
UnmarshalInto(&data) UnmarshalInto(&data)
if err != nil { if err != nil {
return "", "", 0, err return "", "", 0, err

View File

@ -2,7 +2,6 @@ package providers
import ( import (
"context" "context"
"io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
@ -57,23 +56,21 @@ func validateToken(ctx context.Context, p Provider, accessToken string, header h
endpoint = endpoint + "?" + params.Encode() endpoint = endpoint + "?" + params.Encode()
} }
resp, err := requests.New(endpoint). result := requests.New(endpoint).
WithContext(ctx). WithContext(ctx).
WithHeaders(header). WithHeaders(header).
Do() Do()
if err != nil { if result.Error() != nil {
logger.Printf("GET %s", stripToken(endpoint)) logger.Printf("GET %s", stripToken(endpoint))
logger.Printf("token validation request failed: %s", err) logger.Printf("token validation request failed: %s", result.Error())
return false return false
} }
body, _ := ioutil.ReadAll(resp.Body) logger.Printf("%d GET %s %s", result.StatusCode(), stripToken(endpoint), result.Body())
resp.Body.Close()
logger.Printf("%d GET %s %s", resp.StatusCode, stripToken(endpoint), body)
if resp.StatusCode == 200 { if result.StatusCode() == 200 {
return true return true
} }
logger.Printf("token validation request failed: status %d - %s", resp.StatusCode, body) logger.Printf("token validation request failed: status %d - %s", result.StatusCode(), result.Body())
return false return false
} }

View File

@ -53,6 +53,7 @@ func (p *KeycloakProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess
json, err := requests.New(p.ValidateURL.String()). json, err := requests.New(p.ValidateURL.String()).
WithContext(ctx). WithContext(ctx).
SetHeader("Authorization", "Bearer "+s.AccessToken). SetHeader("Authorization", "Bearer "+s.AccessToken).
Do().
UnmarshalJSON() UnmarshalJSON()
if err != nil { if err != nil {
logger.Printf("failed making request %s", err) logger.Printf("failed making request %s", err)

View File

@ -63,6 +63,7 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess
json, err := requests.New(requestURL). json, err := requests.New(requestURL).
WithContext(ctx). WithContext(ctx).
WithHeaders(getLinkedInHeader(s.AccessToken)). WithHeaders(getLinkedInHeader(s.AccessToken)).
Do().
UnmarshalJSON() UnmarshalJSON()
if err != nil { if err != nil {
return "", err return "", err

View File

@ -141,6 +141,7 @@ func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint
err := requests.New(userInfoEndpoint). err := requests.New(userInfoEndpoint).
WithContext(ctx). WithContext(ctx).
SetHeader("Authorization", "Bearer "+accessToken). SetHeader("Authorization", "Bearer "+accessToken).
Do().
UnmarshalInto(&emailData) UnmarshalInto(&emailData)
if err != nil { if err != nil {
return "", err return "", err
@ -196,6 +197,7 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
WithMethod("POST"). WithMethod("POST").
WithBody(bytes.NewBufferString(params.Encode())). WithBody(bytes.NewBufferString(params.Encode())).
SetHeader("Content-Type", "application/x-www-form-urlencoded"). SetHeader("Content-Type", "application/x-www-form-urlencoded").
Do().
UnmarshalInto(&jsonResponse) UnmarshalInto(&jsonResponse)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -33,6 +33,7 @@ func (p *NextcloudProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses
json, err := requests.New(p.ValidateURL.String()). json, err := requests.New(p.ValidateURL.String()).
WithContext(ctx). WithContext(ctx).
WithHeaders(getNextcloudHeader(s.AccessToken)). WithHeaders(getNextcloudHeader(s.AccessToken)).
Do().
UnmarshalJSON() UnmarshalJSON()
if err != nil { if err != nil {
return "", fmt.Errorf("error making request: %v", err) return "", fmt.Errorf("error making request: %v", err)

View File

@ -259,6 +259,7 @@ func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc.
respJSON, err := requests.New(profileURL). respJSON, err := requests.New(profileURL).
WithContext(ctx). WithContext(ctx).
WithHeaders(getOIDCHeader(accessToken)). WithHeaders(getOIDCHeader(accessToken)).
Do().
UnmarshalJSON() UnmarshalJSON()
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -3,10 +3,8 @@ package providers
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"net/url" "net/url"
"time" "time"
@ -39,33 +37,21 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s
params.Add("resource", p.ProtectedResource.String()) params.Add("resource", p.ProtectedResource.String())
} }
resp, err := requests.New(p.RedeemURL.String()). result := requests.New(p.RedeemURL.String()).
WithContext(ctx). WithContext(ctx).
WithMethod("POST"). WithMethod("POST").
WithBody(bytes.NewBufferString(params.Encode())). WithBody(bytes.NewBufferString(params.Encode())).
SetHeader("Content-Type", "application/x-www-form-urlencoded"). SetHeader("Content-Type", "application/x-www-form-urlencoded").
Do() Do()
if err != nil { if result.Error() != nil {
return nil, err return nil, result.Error()
}
var body []byte
body, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return
}
if resp.StatusCode != 200 {
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
return
} }
// blindly try json and x-www-form-urlencoded // blindly try json and x-www-form-urlencoded
var jsonResponse struct { var jsonResponse struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
} }
err = json.Unmarshal(body, &jsonResponse) err = result.UnmarshalInto(&jsonResponse)
if err == nil { if err == nil {
s = &sessions.SessionState{ s = &sessions.SessionState{
AccessToken: jsonResponse.AccessToken, AccessToken: jsonResponse.AccessToken,
@ -74,7 +60,7 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s
} }
var v url.Values var v url.Values
v, err = url.ParseQuery(string(body)) v, err = url.ParseQuery(string(result.Body()))
if err != nil { if err != nil {
return return
} }
@ -82,7 +68,7 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s
created := time.Now() created := time.Now()
s = &sessions.SessionState{AccessToken: a, CreatedAt: &created} s = &sessions.SessionState{AccessToken: a, CreatedAt: &created}
} else { } else {
err = fmt.Errorf("no access token found %s", body) err = fmt.Errorf("no access token found %s", result.Body())
} }
return return
} }