mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2024-11-28 09:08:44 +02:00
Migrate all requests to result pattern
This commit is contained in:
parent
d0b6c04960
commit
de9e65a63a
@ -86,6 +86,7 @@ func Validate(o *options.Options) error {
|
||||
requestURL := strings.TrimSuffix(o.OIDCIssuerURL, "/") + "/.well-known/openid-configuration"
|
||||
body, err := requests.New(requestURL).
|
||||
WithContext(ctx).
|
||||
Do().
|
||||
UnmarshalJSON()
|
||||
if err != nil {
|
||||
logger.Printf("error: failed to discover OIDC configuration: %v", err)
|
||||
@ -384,11 +385,9 @@ func newVerifierFromJwtIssuer(jwtIssuer jwtIssuer) (*oidc.IDTokenVerifier, error
|
||||
if err != nil {
|
||||
// Try as JWKS URI
|
||||
jwksURI := strings.TrimSuffix(jwtIssuer.issuerURI, "/") + "/.well-known/jwks.json"
|
||||
resp, err := requests.New(jwksURI).Do()
|
||||
if err != nil {
|
||||
if err := requests.New(jwksURI).Do().Error(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
verifier = oidc.NewVerifier(jwtIssuer.issuerURI, oidc.NewRemoteKeySet(context.Background(), jwksURI), config)
|
||||
} else {
|
||||
|
@ -101,6 +101,7 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s
|
||||
WithMethod("POST").
|
||||
WithBody(bytes.NewBufferString(params.Encode())).
|
||||
SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
||||
Do().
|
||||
UnmarshalInto(&jsonResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -153,6 +154,7 @@ func (p *AzureProvider) GetEmailAddress(ctx context.Context, s *sessions.Session
|
||||
json, err := requests.New(p.ProfileURL.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getAzureHeader(s.AccessToken)).
|
||||
Do().
|
||||
UnmarshalJSON()
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
@ -88,6 +88,7 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses
|
||||
requestURL := p.ValidateURL.String() + "?access_token=" + s.AccessToken
|
||||
err := requests.New(requestURL).
|
||||
WithContext(ctx).
|
||||
Do().
|
||||
UnmarshalInto(&emails)
|
||||
if err != nil {
|
||||
logger.Printf("failed making request: %v", err)
|
||||
@ -103,6 +104,7 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses
|
||||
|
||||
err := requests.New(requestURL).
|
||||
WithContext(ctx).
|
||||
Do().
|
||||
UnmarshalInto(&teams)
|
||||
if err != nil {
|
||||
logger.Printf("failed requesting teams membership: %v", err)
|
||||
@ -132,6 +134,7 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses
|
||||
|
||||
err := requests.New(requestURL).
|
||||
WithContext(ctx).
|
||||
Do().
|
||||
UnmarshalInto(&repositories)
|
||||
if err != nil {
|
||||
logger.Printf("failed checking repository access: %v", err)
|
||||
|
@ -64,6 +64,7 @@ func (p *DigitalOceanProvider) GetEmailAddress(ctx context.Context, s *sessions.
|
||||
json, err := requests.New(p.ProfileURL.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getDigitalOceanHeader(s.AccessToken)).
|
||||
Do().
|
||||
UnmarshalJSON()
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
@ -72,6 +72,7 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess
|
||||
err := requests.New(requestURL).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getFacebookHeader(s.AccessToken)).
|
||||
Do().
|
||||
UnmarshalInto(&r)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
@ -116,6 +115,7 @@ func (p *GitHubProvider) hasOrg(ctx context.Context, accessToken string) (bool,
|
||||
err := requests.New(endpoint.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getGitHubHeader(accessToken)).
|
||||
Do().
|
||||
UnmarshalInto(&op)
|
||||
if err != nil {
|
||||
return false, err
|
||||
@ -179,12 +179,12 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string)
|
||||
// bodyclose cannot detect that the body is being closed later in requests.Into,
|
||||
// so have to skip the linting for the next line.
|
||||
// nolint:bodyclose
|
||||
resp, err := requests.New(endpoint.String()).
|
||||
result := requests.New(endpoint.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getGitHubHeader(accessToken)).
|
||||
Do()
|
||||
if err != nil {
|
||||
return false, err
|
||||
if result.Error() != nil {
|
||||
return false, result.Error()
|
||||
}
|
||||
|
||||
if last == 0 {
|
||||
@ -200,7 +200,7 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string)
|
||||
// link header at last page (doesn't exist last info)
|
||||
// <https://api.github.com/user/teams?page=3&per_page=10>; rel="prev", <https://api.github.com/user/teams?page=1&per_page=10>; rel="first"
|
||||
|
||||
link := resp.Header.Get("Link")
|
||||
link := result.Headers().Get("Link")
|
||||
rep1 := regexp.MustCompile(`(?s).*\<https://api.github.com/user/teams\?page=(.)&per_page=[0-9]+\>; rel="last".*`)
|
||||
i, converr := strconv.Atoi(rep1.ReplaceAllString(link, "$1"))
|
||||
|
||||
@ -211,7 +211,7 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string)
|
||||
}
|
||||
|
||||
var tp teamsPage
|
||||
if err := requests.UnmarshalInto(resp, &tp); err != nil {
|
||||
if err := result.UnmarshalInto(&tp); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if len(tp) == 0 {
|
||||
@ -282,6 +282,7 @@ func (p *GitHubProvider) hasRepo(ctx context.Context, accessToken string) (bool,
|
||||
err := requests.New(endpoint.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getGitHubHeader(accessToken)).
|
||||
Do().
|
||||
UnmarshalInto(&repo)
|
||||
if err != nil {
|
||||
return false, err
|
||||
@ -309,6 +310,7 @@ func (p *GitHubProvider) hasUser(ctx context.Context, accessToken string) (bool,
|
||||
err := requests.New(endpoint.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getGitHubHeader(accessToken)).
|
||||
Do().
|
||||
UnmarshalInto(&user)
|
||||
if err != nil {
|
||||
return false, err
|
||||
@ -328,26 +330,20 @@ func (p *GitHubProvider) isCollaborator(ctx context.Context, username, accessTok
|
||||
Host: p.ValidateURL.Host,
|
||||
Path: path.Join(p.ValidateURL.Path, "/repos/", p.Repo, "/collaborators/", username),
|
||||
}
|
||||
resp, err := requests.New(endpoint.String()).
|
||||
result := requests.New(endpoint.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getGitHubHeader(accessToken)).
|
||||
Do()
|
||||
if err != nil {
|
||||
return false, err
|
||||
if result.Error() != nil {
|
||||
return false, result.Error()
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != 204 {
|
||||
if result.StatusCode() != 204 {
|
||||
return false, fmt.Errorf("got %d from %q %s",
|
||||
resp.StatusCode, endpoint.String(), body)
|
||||
result.StatusCode(), endpoint.String(), result.Body())
|
||||
}
|
||||
|
||||
logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
|
||||
logger.Printf("got %d from %q %s", result.StatusCode(), endpoint.String(), result.Body())
|
||||
|
||||
return true, nil
|
||||
}
|
||||
@ -401,6 +397,7 @@ func (p *GitHubProvider) GetEmailAddress(ctx context.Context, s *sessions.Sessio
|
||||
err := requests.New(endpoint.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getGitHubHeader(s.AccessToken)).
|
||||
Do().
|
||||
UnmarshalInto(&emails)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@ -435,6 +432,7 @@ func (p *GitHubProvider) GetUserName(ctx context.Context, s *sessions.SessionSta
|
||||
err := requests.New(endpoint.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getGitHubHeader(s.AccessToken)).
|
||||
Do().
|
||||
UnmarshalInto(&user)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
@ -133,6 +133,7 @@ func (p *GitLabProvider) getUserInfo(ctx context.Context, s *sessions.SessionSta
|
||||
err := requests.New(userInfoURL.String()).
|
||||
WithContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+s.AccessToken).
|
||||
Do().
|
||||
UnmarshalInto(&userInfo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting user info: %v", err)
|
||||
|
@ -129,6 +129,7 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (
|
||||
WithMethod("POST").
|
||||
WithBody(bytes.NewBufferString(params.Encode())).
|
||||
SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
||||
Do().
|
||||
UnmarshalInto(&jsonResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -280,6 +281,7 @@ func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken st
|
||||
WithMethod("POST").
|
||||
WithBody(bytes.NewBufferString(params.Encode())).
|
||||
SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
||||
Do().
|
||||
UnmarshalInto(&data)
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
|
@ -2,7 +2,6 @@ package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
@ -57,23 +56,21 @@ func validateToken(ctx context.Context, p Provider, accessToken string, header h
|
||||
endpoint = endpoint + "?" + params.Encode()
|
||||
}
|
||||
|
||||
resp, err := requests.New(endpoint).
|
||||
result := requests.New(endpoint).
|
||||
WithContext(ctx).
|
||||
WithHeaders(header).
|
||||
Do()
|
||||
if err != nil {
|
||||
if result.Error() != nil {
|
||||
logger.Printf("GET %s", stripToken(endpoint))
|
||||
logger.Printf("token validation request failed: %s", err)
|
||||
logger.Printf("token validation request failed: %s", result.Error())
|
||||
return false
|
||||
}
|
||||
|
||||
body, _ := ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
logger.Printf("%d GET %s %s", resp.StatusCode, stripToken(endpoint), body)
|
||||
logger.Printf("%d GET %s %s", result.StatusCode(), stripToken(endpoint), result.Body())
|
||||
|
||||
if resp.StatusCode == 200 {
|
||||
if result.StatusCode() == 200 {
|
||||
return true
|
||||
}
|
||||
logger.Printf("token validation request failed: status %d - %s", resp.StatusCode, body)
|
||||
logger.Printf("token validation request failed: status %d - %s", result.StatusCode(), result.Body())
|
||||
return false
|
||||
}
|
||||
|
@ -53,6 +53,7 @@ func (p *KeycloakProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess
|
||||
json, err := requests.New(p.ValidateURL.String()).
|
||||
WithContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+s.AccessToken).
|
||||
Do().
|
||||
UnmarshalJSON()
|
||||
if err != nil {
|
||||
logger.Printf("failed making request %s", err)
|
||||
|
@ -63,6 +63,7 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess
|
||||
json, err := requests.New(requestURL).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getLinkedInHeader(s.AccessToken)).
|
||||
Do().
|
||||
UnmarshalJSON()
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
@ -141,6 +141,7 @@ func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint
|
||||
err := requests.New(userInfoEndpoint).
|
||||
WithContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+accessToken).
|
||||
Do().
|
||||
UnmarshalInto(&emailData)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@ -196,6 +197,7 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
|
||||
WithMethod("POST").
|
||||
WithBody(bytes.NewBufferString(params.Encode())).
|
||||
SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
||||
Do().
|
||||
UnmarshalInto(&jsonResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -33,6 +33,7 @@ func (p *NextcloudProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses
|
||||
json, err := requests.New(p.ValidateURL.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getNextcloudHeader(s.AccessToken)).
|
||||
Do().
|
||||
UnmarshalJSON()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error making request: %v", err)
|
||||
|
@ -259,6 +259,7 @@ func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc.
|
||||
respJSON, err := requests.New(profileURL).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getOIDCHeader(accessToken)).
|
||||
Do().
|
||||
UnmarshalJSON()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -3,10 +3,8 @@ package providers
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
@ -39,33 +37,21 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s
|
||||
params.Add("resource", p.ProtectedResource.String())
|
||||
}
|
||||
|
||||
resp, err := requests.New(p.RedeemURL.String()).
|
||||
result := requests.New(p.RedeemURL.String()).
|
||||
WithContext(ctx).
|
||||
WithMethod("POST").
|
||||
WithBody(bytes.NewBufferString(params.Encode())).
|
||||
SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
||||
Do()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var body []byte
|
||||
body, err = ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
|
||||
return
|
||||
if result.Error() != nil {
|
||||
return nil, result.Error()
|
||||
}
|
||||
|
||||
// blindly try json and x-www-form-urlencoded
|
||||
var jsonResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
}
|
||||
err = json.Unmarshal(body, &jsonResponse)
|
||||
err = result.UnmarshalInto(&jsonResponse)
|
||||
if err == nil {
|
||||
s = &sessions.SessionState{
|
||||
AccessToken: jsonResponse.AccessToken,
|
||||
@ -74,7 +60,7 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s
|
||||
}
|
||||
|
||||
var v url.Values
|
||||
v, err = url.ParseQuery(string(body))
|
||||
v, err = url.ParseQuery(string(result.Body()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -82,7 +68,7 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s
|
||||
created := time.Now()
|
||||
s = &sessions.SessionState{AccessToken: a, CreatedAt: &created}
|
||||
} else {
|
||||
err = fmt.Errorf("no access token found %s", body)
|
||||
err = fmt.Errorf("no access token found %s", result.Body())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user