1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-06-15 00:15:00 +02:00

Support context in providers (#519)

Co-authored-by: Henry Jenkins <henry@henryjenkins.name>
This commit is contained in:
Mitsuo Heijo
2020-05-06 00:53:33 +09:00
committed by Henry Jenkins
parent 53d8e99f05
commit e642daef4e
33 changed files with 223 additions and 173 deletions

View File

@ -1,6 +1,7 @@
package providers
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
@ -22,6 +23,8 @@ type GitHubProvider struct {
Team string
}
var _ Provider = (*GitHubProvider)(nil)
// NewGitHubProvider initiates a new GitHubProvider
func NewGitHubProvider(p *ProviderData) *GitHubProvider {
p.ProviderName = "GitHub"
@ -69,7 +72,7 @@ func (p *GitHubProvider) SetOrgTeam(org, team string) {
}
}
func (p *GitHubProvider) hasOrg(accessToken string) (bool, error) {
func (p *GitHubProvider) hasOrg(ctx context.Context, accessToken string) (bool, error) {
// https://developer.github.com/v3/orgs/#list-your-organizations
var orgs []struct {
@ -93,7 +96,7 @@ func (p *GitHubProvider) hasOrg(accessToken string) (bool, error) {
Path: path.Join(p.ValidateURL.Path, "/user/orgs"),
RawQuery: params.Encode(),
}
req, _ := http.NewRequest("GET", endpoint.String(), nil)
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
req.Header = getGitHubHeader(accessToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
@ -135,7 +138,7 @@ func (p *GitHubProvider) hasOrg(accessToken string) (bool, error) {
return false, nil
}
func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) {
func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string) (bool, error) {
// https://developer.github.com/v3/orgs/teams/#list-user-teams
var teams []struct {
@ -169,7 +172,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) {
RawQuery: params.Encode(),
}
req, _ := http.NewRequest("GET", endpoint.String(), nil)
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
req.Header = getGitHubHeader(accessToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
@ -261,7 +264,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) {
}
// GetEmailAddress returns the Account email address
func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
func (p *GitHubProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
var emails []struct {
Email string `json:"email"`
@ -272,11 +275,11 @@ func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, erro
// if we require an Org or Team, check that first
if p.Org != "" {
if p.Team != "" {
if ok, err := p.hasOrgAndTeam(s.AccessToken); err != nil || !ok {
if ok, err := p.hasOrgAndTeam(ctx, s.AccessToken); err != nil || !ok {
return "", err
}
} else {
if ok, err := p.hasOrg(s.AccessToken); err != nil || !ok {
if ok, err := p.hasOrg(ctx, s.AccessToken); err != nil || !ok {
return "", err
}
}
@ -287,7 +290,7 @@ func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, erro
Host: p.ValidateURL.Host,
Path: path.Join(p.ValidateURL.Path, "/user/emails"),
}
req, _ := http.NewRequest("GET", endpoint.String(), nil)
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
req.Header = getGitHubHeader(s.AccessToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
@ -324,7 +327,7 @@ func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, erro
}
// GetUserName returns the Account user name
func (p *GitHubProvider) GetUserName(s *sessions.SessionState) (string, error) {
func (p *GitHubProvider) GetUserName(ctx context.Context, s *sessions.SessionState) (string, error) {
var user struct {
Login string `json:"login"`
Email string `json:"email"`
@ -336,7 +339,7 @@ func (p *GitHubProvider) GetUserName(s *sessions.SessionState) (string, error) {
Path: path.Join(p.ValidateURL.Path, "/user"),
}
req, err := http.NewRequest("GET", endpoint.String(), nil)
req, err := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
if err != nil {
return "", fmt.Errorf("could not create new GET request: %v", err)
}
@ -368,6 +371,6 @@ func (p *GitHubProvider) GetUserName(s *sessions.SessionState) (string, error) {
}
// ValidateSessionState validates the AccessToken
func (p *GitHubProvider) ValidateSessionState(s *sessions.SessionState) bool {
return validateToken(p, s.AccessToken, getGitHubHeader(s.AccessToken))
func (p *GitHubProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool {
return validateToken(ctx, p, s.AccessToken, getGitHubHeader(s.AccessToken))
}