diff --git a/CHANGELOG.md b/CHANGELOG.md index bf13cd8b..cf78b3fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,7 @@ - [#514](https://github.com/oauth2-proxy/oauth2-proxy/pull/514) Add basic string functions to templates - [#524](https://github.com/oauth2-proxy/oauth2-proxy/pull/524) Sign cookies with SHA256 (@NickMeves) - [#515](https://github.com/oauth2-proxy/oauth2-proxy/pull/515) Drop configure script in favour of native Makefile env and checks (@JoelSpeed) +- [#519](https://github.com/oauth2-proxy/oauth2-proxy/pull/519) Support context in providers (@johejo) - [#487](https://github.com/oauth2-proxy/oauth2-proxy/pull/487) Switch flags to PFlag to remove StringArray (@JoelSpeed) - [#484](https://github.com/oauth2-proxy/oauth2-proxy/pull/484) Replace configuration loading with Viper (@JoelSpeed) - [#499](https://github.com/oauth2-proxy/oauth2-proxy/pull/499) Add `-user-id-claim` to support generic claims in addition to email (@holyjak) diff --git a/oauthproxy.go b/oauthproxy.go index a75fb3c2..86d33c59 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -347,29 +347,29 @@ func (p *OAuthProxy) displayCustomLoginForm() bool { return p.HtpasswdFile != nil && p.DisplayHtpasswdForm } -func (p *OAuthProxy) redeemCode(host, code string) (s *sessionsapi.SessionState, err error) { +func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (s *sessionsapi.SessionState, err error) { if code == "" { return nil, errors.New("missing code") } redirectURI := p.GetRedirectURI(host) - s, err = p.provider.Redeem(redirectURI, code) + s, err = p.provider.Redeem(ctx, redirectURI, code) if err != nil { return } if s.Email == "" { - s.Email, err = p.provider.GetEmailAddress(s) + s.Email, err = p.provider.GetEmailAddress(ctx, s) } if s.PreferredUsername == "" { - s.PreferredUsername, err = p.provider.GetPreferredUsername(s) + s.PreferredUsername, err = p.provider.GetPreferredUsername(ctx, s) if err != nil && err.Error() == "not implemented" { err = nil } } if s.User == "" { - s.User, err = p.provider.GetUserName(s) + s.User, err = p.provider.GetUserName(ctx, s) if err != nil && err.Error() == "not implemented" { err = nil } @@ -782,7 +782,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { return } - session, err := p.redeemCode(req.Host, req.Form.Get("code")) + session, err := p.redeemCode(req.Context(), req.Host, req.Form.Get("code")) if err != nil { logger.Printf("Error redeeming code during OAuth2 callback: %s ", err.Error()) p.ErrorPage(rw, 500, "Internal Error", "Internal Error") @@ -907,7 +907,7 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R saveSession = true } - if ok, err := p.provider.RefreshSessionIfNeeded(session); err != nil { + if ok, err := p.provider.RefreshSessionIfNeeded(req.Context(), session); err != nil { logger.Printf("%s removing session. error refreshing access token %s %s", remoteAddr, err, session) clearSession = true session = nil @@ -926,7 +926,7 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R } if saveSession && !revalidated && session != nil && session.AccessToken != "" { - if !p.provider.ValidateSessionState(session) { + if !p.provider.ValidateSessionState(req.Context(), session) { logger.Printf("Removing session: error validating %s", session) saveSession = false session = nil @@ -1126,16 +1126,15 @@ func (p *OAuthProxy) GetJwtSession(req *http.Request) (*sessionsapi.SessionState return nil, err } - ctx := context.Background() for _, verifier := range p.jwtBearerVerifiers { - bearerToken, err := verifier.Verify(ctx, rawBearerToken) + bearerToken, err := verifier.Verify(req.Context(), rawBearerToken) if err != nil { logger.Printf("failed to verify bearer token: %v", err) continue } - return p.provider.CreateSessionStateFromBearerToken(rawBearerToken, bearerToken) + return p.provider.CreateSessionStateFromBearerToken(req.Context(), rawBearerToken, bearerToken) } return nil, fmt.Errorf("unable to verify jwt token %s", req.Header.Get("Authorization")) } diff --git a/oauthproxy_test.go b/oauthproxy_test.go index a216f2cb..8957e18a 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -397,6 +397,8 @@ type TestProvider struct { GroupValidator func(string) bool } +var _ providers.Provider = (*TestProvider)(nil) + func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider { return &TestProvider{ ProviderData: &providers.ProviderData{ @@ -425,11 +427,11 @@ func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider { } } -func (tp *TestProvider) GetEmailAddress(session *sessions.SessionState) (string, error) { +func (tp *TestProvider) GetEmailAddress(ctx context.Context, session *sessions.SessionState) (string, error) { return tp.EmailAddress, nil } -func (tp *TestProvider) ValidateSessionState(session *sessions.SessionState) bool { +func (tp *TestProvider) ValidateSessionState(ctx context.Context, session *sessions.SessionState) bool { return tp.ValidToken } diff --git a/pkg/requests/requests.go b/pkg/requests/requests.go index 36a8bf8c..64cacaa9 100644 --- a/pkg/requests/requests.go +++ b/pkg/requests/requests.go @@ -1,6 +1,7 @@ package requests import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -62,8 +63,8 @@ func RequestJSON(req *http.Request, v interface{}) error { } // RequestUnparsedResponse performs a GET and returns the raw response object -func RequestUnparsedResponse(url string, header http.Header) (resp *http.Response, err error) { - req, err := http.NewRequest("GET", url, nil) +func RequestUnparsedResponse(ctx context.Context, url string, header http.Header) (resp *http.Response, err error) { + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return nil, fmt.Errorf("error performing get request: %w", err) } diff --git a/pkg/requests/requests_test.go b/pkg/requests/requests_test.go index acd9b0b8..0c3e4152 100644 --- a/pkg/requests/requests_test.go +++ b/pkg/requests/requests_test.go @@ -1,6 +1,7 @@ package requests import ( + "context" "io/ioutil" "net/http" "net/http/httptest" @@ -87,7 +88,7 @@ func TestRequestUnparsedResponseUsingAccessTokenParameter(t *testing.T) { defer backend.Close() response, err := RequestUnparsedResponse( - backend.URL+"?access_token=my_token", nil) + context.Background(), backend.URL+"?access_token=my_token", nil) assert.Equal(t, nil, err) defer response.Body.Close() @@ -103,7 +104,7 @@ func TestRequestUnparsedResponseUsingAccessTokenParameterFailedResponse(t *testi backend.Close() response, err := RequestUnparsedResponse( - backend.URL+"?access_token=my_token", nil) + context.Background(), backend.URL+"?access_token=my_token", nil) assert.NotEqual(t, nil, err) assert.Equal(t, (*http.Response)(nil), response) } @@ -123,7 +124,7 @@ func TestRequestUnparsedResponseUsingHeaders(t *testing.T) { headers := make(http.Header) headers.Set("Auth", "my_token") - response, err := RequestUnparsedResponse(backend.URL, headers) + response, err := RequestUnparsedResponse(context.Background(), backend.URL, headers) assert.Equal(t, nil, err) defer response.Body.Close() diff --git a/providers/azure.go b/providers/azure.go index 393416e3..961ff908 100644 --- a/providers/azure.go +++ b/providers/azure.go @@ -2,6 +2,7 @@ package providers import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -22,6 +23,8 @@ type AzureProvider struct { Tenant string } +var _ Provider = (*AzureProvider)(nil) + // NewAzureProvider initiates a new AzureProvider func NewAzureProvider(p *ProviderData) *AzureProvider { p.ProviderName = "Azure" @@ -68,7 +71,7 @@ func (p *AzureProvider) Configure(tenant string) { } } -func (p *AzureProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) { +func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { if code == "" { err = errors.New("missing code") return @@ -89,7 +92,7 @@ func (p *AzureProvider) Redeem(redirectURL, code string) (s *sessions.SessionSta } var req *http.Request - req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) + req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) if err != nil { return } @@ -157,14 +160,14 @@ func getEmailFromJSON(json *simplejson.Json) (string, error) { } // GetEmailAddress returns the Account email address -func (p *AzureProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { +func (p *AzureProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { var email string var err error if s.AccessToken == "" { return "", errors.New("missing access token") } - req, err := http.NewRequest("GET", p.ProfileURL.String(), nil) + req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String(), nil) if err != nil { return "", err } diff --git a/providers/azure_test.go b/providers/azure_test.go index 6a38ce10..af364b77 100644 --- a/providers/azure_test.go +++ b/providers/azure_test.go @@ -1,6 +1,7 @@ package providers import ( + "context" "net/http" "net/http/httptest" "net/url" @@ -133,7 +134,7 @@ func TestAzureProviderGetEmailAddress(t *testing.T) { p := testAzureProvider(bURL.Host) session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(session) + email, err := p.GetEmailAddress(context.Background(), session) assert.Equal(t, nil, err) assert.Equal(t, "user@windows.net", email) } @@ -146,7 +147,7 @@ func TestAzureProviderGetEmailAddressMailNull(t *testing.T) { p := testAzureProvider(bURL.Host) session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(session) + email, err := p.GetEmailAddress(context.Background(), session) assert.Equal(t, nil, err) assert.Equal(t, "user@windows.net", email) } @@ -159,7 +160,7 @@ func TestAzureProviderGetEmailAddressGetUserPrincipalName(t *testing.T) { p := testAzureProvider(bURL.Host) session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(session) + email, err := p.GetEmailAddress(context.Background(), session) assert.Equal(t, nil, err) assert.Equal(t, "user@windows.net", email) } @@ -172,7 +173,7 @@ func TestAzureProviderGetEmailAddressFailToGetEmailAddress(t *testing.T) { p := testAzureProvider(bURL.Host) session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(session) + email, err := p.GetEmailAddress(context.Background(), session) assert.Equal(t, "type assertion to string failed", err.Error()) assert.Equal(t, "", email) } @@ -185,7 +186,7 @@ func TestAzureProviderGetEmailAddressEmptyUserPrincipalName(t *testing.T) { p := testAzureProvider(bURL.Host) session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(session) + email, err := p.GetEmailAddress(context.Background(), session) assert.Equal(t, nil, err) assert.Equal(t, "", email) } @@ -198,7 +199,7 @@ func TestAzureProviderGetEmailAddressIncorrectOtherMails(t *testing.T) { p := testAzureProvider(bURL.Host) session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(session) + email, err := p.GetEmailAddress(context.Background(), session) assert.Equal(t, "type assertion to string failed", err.Error()) assert.Equal(t, "", email) } @@ -212,7 +213,7 @@ func TestAzureProviderRedeemReturnsIdToken(t *testing.T) { bURL, _ := url.Parse(b.URL) p := testAzureProvider(bURL.Host) p.Data().RedeemURL.Path = "/common/oauth2/token" - s, err := p.Redeem("https://localhost", "1234") + s, err := p.Redeem(context.Background(), "https://localhost", "1234") assert.Equal(t, nil, err) assert.Equal(t, "testtoken1234", s.IDToken) assert.Equal(t, timestamp, s.ExpiresOn.UTC()) diff --git a/providers/bitbucket.go b/providers/bitbucket.go index f67e48bc..2bb876cb 100644 --- a/providers/bitbucket.go +++ b/providers/bitbucket.go @@ -1,6 +1,7 @@ package providers import ( + "context" "net/http" "net/url" "strings" @@ -17,6 +18,8 @@ type BitbucketProvider struct { Repository string } +var _ Provider = (*BitbucketProvider)(nil) + // NewBitbucketProvider initiates a new BitbucketProvider func NewBitbucketProvider(p *ProviderData) *BitbucketProvider { p.ProviderName = "Bitbucket" @@ -64,7 +67,7 @@ func (p *BitbucketProvider) SetRepository(repository string) { } // GetEmailAddress returns the email of the authenticated user -func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { +func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { var emails struct { Values []struct { @@ -82,7 +85,7 @@ func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, e FullName string `json:"full_name"` } } - req, err := http.NewRequest("GET", + req, err := http.NewRequestWithContext(ctx, "GET", p.ValidateURL.String()+"?access_token="+s.AccessToken, nil) if err != nil { logger.Printf("failed building request %s", err) @@ -98,7 +101,7 @@ func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, e teamURL := &url.URL{} *teamURL = *p.ValidateURL teamURL.Path = "/2.0/teams" - req, err = http.NewRequest("GET", + req, err = http.NewRequestWithContext(ctx, "GET", teamURL.String()+"?role=member&access_token="+s.AccessToken, nil) if err != nil { logger.Printf("failed building request %s", err) @@ -126,7 +129,7 @@ func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, e repositoriesURL := &url.URL{} *repositoriesURL = *p.ValidateURL repositoriesURL.Path = "/2.0/repositories/" + strings.Split(p.Repository, "/")[0] - req, err = http.NewRequest("GET", + req, err = http.NewRequestWithContext(ctx, "GET", repositoriesURL.String()+"?role=contributor"+ "&q=full_name="+url.QueryEscape("\""+p.Repository+"\"")+ "&access_token="+s.AccessToken, diff --git a/providers/bitbucket_test.go b/providers/bitbucket_test.go index da3182a0..e788b81e 100644 --- a/providers/bitbucket_test.go +++ b/providers/bitbucket_test.go @@ -1,6 +1,7 @@ package providers import ( + "context" "log" "net/http" "net/http/httptest" @@ -120,7 +121,7 @@ func TestBitbucketProviderGetEmailAddress(t *testing.T) { p := testBitbucketProvider(bURL.Host, "", "") session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(session) + email, err := p.GetEmailAddress(context.Background(), session) assert.Equal(t, nil, err) assert.Equal(t, "michael.bland@gsa.gov", email) } @@ -133,7 +134,7 @@ func TestBitbucketProviderGetEmailAddressAndGroup(t *testing.T) { p := testBitbucketProvider(bURL.Host, "bioinformatics", "") session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(session) + email, err := p.GetEmailAddress(context.Background(), session) assert.Equal(t, nil, err) assert.Equal(t, "michael.bland@gsa.gov", email) } @@ -151,7 +152,7 @@ func TestBitbucketProviderGetEmailAddressFailedRequest(t *testing.T) { // token. Alternatively, we could allow the parsing of the payload as // JSON to fail. session := &sessions.SessionState{AccessToken: "unexpected_access_token"} - email, err := p.GetEmailAddress(session) + email, err := p.GetEmailAddress(context.Background(), session) assert.NotEqual(t, nil, err) assert.Equal(t, "", email) } @@ -164,7 +165,7 @@ func TestBitbucketProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) p := testBitbucketProvider(bURL.Host, "", "") session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(session) + email, err := p.GetEmailAddress(context.Background(), session) assert.Equal(t, "", email) assert.Equal(t, nil, err) } diff --git a/providers/digitalocean.go b/providers/digitalocean.go index 564d32b0..25d37af9 100644 --- a/providers/digitalocean.go +++ b/providers/digitalocean.go @@ -1,6 +1,7 @@ package providers import ( + "context" "errors" "fmt" "net/http" @@ -15,6 +16,8 @@ type DigitalOceanProvider struct { *ProviderData } +var _ Provider = (*DigitalOceanProvider)(nil) + // NewDigitalOceanProvider initiates a new DigitalOceanProvider func NewDigitalOceanProvider(p *ProviderData) *DigitalOceanProvider { p.ProviderName = "DigitalOcean" @@ -53,11 +56,11 @@ func getDigitalOceanHeader(accessToken string) http.Header { } // GetEmailAddress returns the Account email address -func (p *DigitalOceanProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { +func (p *DigitalOceanProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { if s.AccessToken == "" { return "", errors.New("missing access token") } - req, err := http.NewRequest("GET", p.ProfileURL.String(), nil) + req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String(), nil) if err != nil { return "", err } @@ -76,6 +79,6 @@ func (p *DigitalOceanProvider) GetEmailAddress(s *sessions.SessionState) (string } // ValidateSessionState validates the AccessToken -func (p *DigitalOceanProvider) ValidateSessionState(s *sessions.SessionState) bool { - return validateToken(p, s.AccessToken, getDigitalOceanHeader(s.AccessToken)) +func (p *DigitalOceanProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool { + return validateToken(ctx, p, s.AccessToken, getDigitalOceanHeader(s.AccessToken)) } diff --git a/providers/digitalocean_test.go b/providers/digitalocean_test.go index 2b3fede3..e7907eba 100644 --- a/providers/digitalocean_test.go +++ b/providers/digitalocean_test.go @@ -1,6 +1,7 @@ package providers import ( + "context" "net/http" "net/http/httptest" "net/url" @@ -99,7 +100,7 @@ func TestDigitalOceanProviderGetEmailAddress(t *testing.T) { p := testDigitalOceanProvider(bURL.Host) session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(session) + email, err := p.GetEmailAddress(context.Background(), session) assert.Equal(t, nil, err) assert.Equal(t, "user@example.com", email) } @@ -115,7 +116,7 @@ func TestDigitalOceanProviderGetEmailAddressFailedRequest(t *testing.T) { // token. Alternatively, we could allow the parsing of the payload as // JSON to fail. session := &sessions.SessionState{AccessToken: "unexpected_access_token"} - email, err := p.GetEmailAddress(session) + email, err := p.GetEmailAddress(context.Background(), session) assert.NotEqual(t, nil, err) assert.Equal(t, "", email) } @@ -128,7 +129,7 @@ func TestDigitalOceanProviderGetEmailAddressEmailNotPresentInPayload(t *testing. p := testDigitalOceanProvider(bURL.Host) session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(session) + email, err := p.GetEmailAddress(context.Background(), session) assert.NotEqual(t, nil, err) assert.Equal(t, "", email) } diff --git a/providers/facebook.go b/providers/facebook.go index 94f3e271..0f9cc624 100644 --- a/providers/facebook.go +++ b/providers/facebook.go @@ -1,6 +1,7 @@ package providers import ( + "context" "errors" "fmt" "net/http" @@ -15,6 +16,8 @@ type FacebookProvider struct { *ProviderData } +var _ Provider = (*FacebookProvider)(nil) + // NewFacebookProvider initiates a new FacebookProvider func NewFacebookProvider(p *ProviderData) *FacebookProvider { p.ProviderName = "Facebook" @@ -55,11 +58,11 @@ func getFacebookHeader(accessToken string) http.Header { } // GetEmailAddress returns the Account email address -func (p *FacebookProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { +func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { if s.AccessToken == "" { return "", errors.New("missing access token") } - req, err := http.NewRequest("GET", p.ProfileURL.String()+"?fields=name,email", nil) + req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String()+"?fields=name,email", nil) if err != nil { return "", err } @@ -80,6 +83,6 @@ func (p *FacebookProvider) GetEmailAddress(s *sessions.SessionState) (string, er } // ValidateSessionState validates the AccessToken -func (p *FacebookProvider) ValidateSessionState(s *sessions.SessionState) bool { - return validateToken(p, s.AccessToken, getFacebookHeader(s.AccessToken)) +func (p *FacebookProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool { + return validateToken(ctx, p, s.AccessToken, getFacebookHeader(s.AccessToken)) } diff --git a/providers/github.go b/providers/github.go index cb522b13..153373cb 100644 --- a/providers/github.go +++ b/providers/github.go @@ -1,6 +1,7 @@ package providers import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -22,6 +23,8 @@ type GitHubProvider struct { Team string } +var _ Provider = (*GitHubProvider)(nil) + // NewGitHubProvider initiates a new GitHubProvider func NewGitHubProvider(p *ProviderData) *GitHubProvider { p.ProviderName = "GitHub" @@ -69,7 +72,7 @@ func (p *GitHubProvider) SetOrgTeam(org, team string) { } } -func (p *GitHubProvider) hasOrg(accessToken string) (bool, error) { +func (p *GitHubProvider) hasOrg(ctx context.Context, accessToken string) (bool, error) { // https://developer.github.com/v3/orgs/#list-your-organizations var orgs []struct { @@ -93,7 +96,7 @@ func (p *GitHubProvider) hasOrg(accessToken string) (bool, error) { Path: path.Join(p.ValidateURL.Path, "/user/orgs"), RawQuery: params.Encode(), } - req, _ := http.NewRequest("GET", endpoint.String(), nil) + req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) req.Header = getGitHubHeader(accessToken) resp, err := http.DefaultClient.Do(req) if err != nil { @@ -135,7 +138,7 @@ func (p *GitHubProvider) hasOrg(accessToken string) (bool, error) { return false, nil } -func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) { +func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string) (bool, error) { // https://developer.github.com/v3/orgs/teams/#list-user-teams var teams []struct { @@ -169,7 +172,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) { RawQuery: params.Encode(), } - req, _ := http.NewRequest("GET", endpoint.String(), nil) + req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) req.Header = getGitHubHeader(accessToken) resp, err := http.DefaultClient.Do(req) if err != nil { @@ -261,7 +264,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) { } // GetEmailAddress returns the Account email address -func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { +func (p *GitHubProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { var emails []struct { Email string `json:"email"` @@ -272,11 +275,11 @@ func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, erro // if we require an Org or Team, check that first if p.Org != "" { if p.Team != "" { - if ok, err := p.hasOrgAndTeam(s.AccessToken); err != nil || !ok { + if ok, err := p.hasOrgAndTeam(ctx, s.AccessToken); err != nil || !ok { return "", err } } else { - if ok, err := p.hasOrg(s.AccessToken); err != nil || !ok { + if ok, err := p.hasOrg(ctx, s.AccessToken); err != nil || !ok { return "", err } } @@ -287,7 +290,7 @@ func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, erro Host: p.ValidateURL.Host, Path: path.Join(p.ValidateURL.Path, "/user/emails"), } - req, _ := http.NewRequest("GET", endpoint.String(), nil) + req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) req.Header = getGitHubHeader(s.AccessToken) resp, err := http.DefaultClient.Do(req) if err != nil { @@ -324,7 +327,7 @@ func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, erro } // GetUserName returns the Account user name -func (p *GitHubProvider) GetUserName(s *sessions.SessionState) (string, error) { +func (p *GitHubProvider) GetUserName(ctx context.Context, s *sessions.SessionState) (string, error) { var user struct { Login string `json:"login"` Email string `json:"email"` @@ -336,7 +339,7 @@ func (p *GitHubProvider) GetUserName(s *sessions.SessionState) (string, error) { Path: path.Join(p.ValidateURL.Path, "/user"), } - req, err := http.NewRequest("GET", endpoint.String(), nil) + req, err := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) if err != nil { return "", fmt.Errorf("could not create new GET request: %v", err) } @@ -368,6 +371,6 @@ func (p *GitHubProvider) GetUserName(s *sessions.SessionState) (string, error) { } // ValidateSessionState validates the AccessToken -func (p *GitHubProvider) ValidateSessionState(s *sessions.SessionState) bool { - return validateToken(p, s.AccessToken, getGitHubHeader(s.AccessToken)) +func (p *GitHubProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool { + return validateToken(ctx, p, s.AccessToken, getGitHubHeader(s.AccessToken)) } diff --git a/providers/github_test.go b/providers/github_test.go index a454cb48..7a1c4723 100644 --- a/providers/github_test.go +++ b/providers/github_test.go @@ -1,6 +1,7 @@ package providers import ( + "context" "net/http" "net/http/httptest" "net/url" @@ -105,7 +106,7 @@ func TestGitHubProviderGetEmailAddress(t *testing.T) { p := testGitHubProvider(bURL.Host) session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(session) + email, err := p.GetEmailAddress(context.Background(), session) assert.Equal(t, nil, err) assert.Equal(t, "michael.bland@gsa.gov", email) } @@ -118,7 +119,7 @@ func TestGitHubProviderGetEmailAddressNotVerified(t *testing.T) { p := testGitHubProvider(bURL.Host) session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(session) + email, err := p.GetEmailAddress(context.Background(), session) assert.Equal(t, nil, err) assert.Empty(t, "", email) } @@ -136,7 +137,7 @@ func TestGitHubProviderGetEmailAddressWithOrg(t *testing.T) { p.Org = "testorg1" session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(session) + email, err := p.GetEmailAddress(context.Background(), session) assert.Equal(t, nil, err) assert.Equal(t, "michael.bland@gsa.gov", email) } @@ -154,7 +155,7 @@ func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) { // token. Alternatively, we could allow the parsing of the payload as // JSON to fail. session := &sessions.SessionState{AccessToken: "unexpected_access_token"} - email, err := p.GetEmailAddress(session) + email, err := p.GetEmailAddress(context.Background(), session) assert.NotEqual(t, nil, err) assert.Equal(t, "", email) } @@ -167,7 +168,7 @@ func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { p := testGitHubProvider(bURL.Host) session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(session) + email, err := p.GetEmailAddress(context.Background(), session) assert.NotEqual(t, nil, err) assert.Equal(t, "", email) } @@ -180,7 +181,7 @@ func TestGitHubProviderGetUserName(t *testing.T) { p := testGitHubProvider(bURL.Host) session := CreateAuthorizedSession() - email, err := p.GetUserName(session) + email, err := p.GetUserName(context.Background(), session) assert.Equal(t, nil, err) assert.Equal(t, "mbland", email) } diff --git a/providers/gitlab.go b/providers/gitlab.go index 6034115c..beeb6b98 100644 --- a/providers/gitlab.go +++ b/providers/gitlab.go @@ -25,6 +25,8 @@ type GitLabProvider struct { AllowUnverifiedEmail bool } +var _ Provider = (*GitLabProvider)(nil) + // NewGitLabProvider initiates a new GitLabProvider func NewGitLabProvider(p *ProviderData) *GitLabProvider { p.ProviderName = "GitLab" @@ -37,13 +39,12 @@ func NewGitLabProvider(p *ProviderData) *GitLabProvider { } // Redeem exchanges the OAuth2 authentication token for an ID token -func (p *GitLabProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) { +func (p *GitLabProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { clientSecret, err := p.GetClientSecret() if err != nil { return } - ctx := context.Background() c := oauth2.Config{ ClientID: p.ClientID, ClientSecret: clientSecret, @@ -65,14 +66,14 @@ func (p *GitLabProvider) Redeem(redirectURL, code string) (s *sessions.SessionSt // RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshToken to fetch a new ID token if required -func (p *GitLabProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { +func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { return false, nil } origExpiration := s.ExpiresOn - err := p.redeemRefreshToken(s) + err := p.redeemRefreshToken(ctx, s) if err != nil { return false, fmt.Errorf("unable to redeem refresh token: %v", err) } @@ -81,7 +82,7 @@ func (p *GitLabProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, return true, nil } -func (p *GitLabProvider) redeemRefreshToken(s *sessions.SessionState) (err error) { +func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) { clientSecret, err := p.GetClientSecret() if err != nil { return @@ -94,7 +95,6 @@ func (p *GitLabProvider) redeemRefreshToken(s *sessions.SessionState) (err error TokenURL: p.RedeemURL.String(), }, } - ctx := context.Background() t := &oauth2.Token{ RefreshToken: s.RefreshToken, Expiry: time.Now().Add(-time.Hour), @@ -123,7 +123,7 @@ type gitlabUserInfo struct { Groups []string `json:"groups"` } -func (p *GitLabProvider) getUserInfo(s *sessions.SessionState) (*gitlabUserInfo, error) { +func (p *GitLabProvider) getUserInfo(ctx context.Context, s *sessions.SessionState) (*gitlabUserInfo, error) { // Retrieve user info JSON // https://docs.gitlab.com/ee/integration/openid_connect_provider.html#shared-information @@ -131,7 +131,7 @@ func (p *GitLabProvider) getUserInfo(s *sessions.SessionState) (*gitlabUserInfo, userInfoURL := *p.LoginURL userInfoURL.Path = "/oauth/userinfo" - req, err := http.NewRequest("GET", userInfoURL.String(), nil) + req, err := http.NewRequestWithContext(ctx, "GET", userInfoURL.String(), nil) if err != nil { return nil, fmt.Errorf("failed to create user info request: %v", err) } @@ -219,16 +219,15 @@ func (p *GitLabProvider) createSessionState(ctx context.Context, token *oauth2.T } // ValidateSessionState checks that the session's IDToken is still valid -func (p *GitLabProvider) ValidateSessionState(s *sessions.SessionState) bool { - ctx := context.Background() +func (p *GitLabProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool { _, err := p.Verifier.Verify(ctx, s.IDToken) return err == nil } // GetEmailAddress returns the Account email address -func (p *GitLabProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { +func (p *GitLabProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { // Retrieve user info - userInfo, err := p.getUserInfo(s) + userInfo, err := p.getUserInfo(ctx, s) if err != nil { return "", fmt.Errorf("failed to retrieve user info: %v", err) } @@ -254,8 +253,8 @@ func (p *GitLabProvider) GetEmailAddress(s *sessions.SessionState) (string, erro } // GetUserName returns the Account user name -func (p *GitLabProvider) GetUserName(s *sessions.SessionState) (string, error) { - userInfo, err := p.getUserInfo(s) +func (p *GitLabProvider) GetUserName(ctx context.Context, s *sessions.SessionState) (string, error) { + userInfo, err := p.getUserInfo(ctx, s) if err != nil { return "", fmt.Errorf("failed to retrieve user info: %v", err) } diff --git a/providers/gitlab_test.go b/providers/gitlab_test.go index 30ce16e5..4a353ce8 100644 --- a/providers/gitlab_test.go +++ b/providers/gitlab_test.go @@ -1,6 +1,7 @@ package providers import ( + "context" "net/http" "net/http/httptest" "net/url" @@ -63,7 +64,7 @@ func TestGitLabProviderBadToken(t *testing.T) { p := testGitLabProvider(bURL.Host) session := &sessions.SessionState{AccessToken: "unexpected_gitlab_access_token"} - _, err := p.GetEmailAddress(session) + _, err := p.GetEmailAddress(context.Background(), session) assert.NotEqual(t, nil, err) } @@ -75,7 +76,7 @@ func TestGitLabProviderUnverifiedEmailDenied(t *testing.T) { p := testGitLabProvider(bURL.Host) session := &sessions.SessionState{AccessToken: "gitlab_access_token"} - _, err := p.GetEmailAddress(session) + _, err := p.GetEmailAddress(context.Background(), session) assert.NotEqual(t, nil, err) } @@ -88,7 +89,7 @@ func TestGitLabProviderUnverifiedEmailAllowed(t *testing.T) { p.AllowUnverifiedEmail = true session := &sessions.SessionState{AccessToken: "gitlab_access_token"} - email, err := p.GetEmailAddress(session) + email, err := p.GetEmailAddress(context.Background(), session) assert.Equal(t, nil, err) assert.Equal(t, "foo@bar.com", email) } @@ -102,7 +103,7 @@ func TestGitLabProviderUsername(t *testing.T) { p.AllowUnverifiedEmail = true session := &sessions.SessionState{AccessToken: "gitlab_access_token"} - username, err := p.GetUserName(session) + username, err := p.GetUserName(context.Background(), session) assert.Equal(t, nil, err) assert.Equal(t, "FooBar", username) } @@ -117,7 +118,7 @@ func TestGitLabProviderGroupMembershipValid(t *testing.T) { p.Group = "foo" session := &sessions.SessionState{AccessToken: "gitlab_access_token"} - email, err := p.GetEmailAddress(session) + email, err := p.GetEmailAddress(context.Background(), session) assert.Equal(t, nil, err) assert.Equal(t, "foo@bar.com", email) } @@ -132,7 +133,7 @@ func TestGitLabProviderGroupMembershipMissing(t *testing.T) { p.Group = "baz" session := &sessions.SessionState{AccessToken: "gitlab_access_token"} - _, err := p.GetEmailAddress(session) + _, err := p.GetEmailAddress(context.Background(), session) assert.NotEqual(t, nil, err) } @@ -146,7 +147,7 @@ func TestGitLabProviderEmailDomainValid(t *testing.T) { p.EmailDomains = []string{"bar.com"} session := &sessions.SessionState{AccessToken: "gitlab_access_token"} - email, err := p.GetEmailAddress(session) + email, err := p.GetEmailAddress(context.Background(), session) assert.Equal(t, nil, err) assert.Equal(t, "foo@bar.com", email) } @@ -161,6 +162,6 @@ func TestGitLabProviderEmailDomainInvalid(t *testing.T) { p.EmailDomains = []string{"baz.com"} session := &sessions.SessionState{AccessToken: "gitlab_access_token"} - _, err := p.GetEmailAddress(session) + _, err := p.GetEmailAddress(context.Background(), session) assert.NotEqual(t, nil, err) } diff --git a/providers/google.go b/providers/google.go index a93b8e08..1406855b 100644 --- a/providers/google.go +++ b/providers/google.go @@ -31,6 +31,8 @@ type GoogleProvider struct { GroupValidator func(string) bool } +var _ Provider = (*GoogleProvider)(nil) + type claims struct { Subject string `json:"sub"` Email string `json:"email"` @@ -98,7 +100,7 @@ func claimsFromIDToken(idToken string) (*claims, error) { } // Redeem exchanges the OAuth2 authentication token for an ID token -func (p *GoogleProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) { +func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { if code == "" { err = errors.New("missing code") return @@ -115,7 +117,7 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *sessions.SessionSt params.Add("code", code) params.Add("grant_type", "authorization_code") var req *http.Request - req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) + req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) if err != nil { return } @@ -242,12 +244,12 @@ func (p *GoogleProvider) ValidateGroup(email string) bool { // RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshToken to fetch a new ID token if required -func (p *GoogleProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { +func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { return false, nil } - newToken, newIDToken, duration, err := p.redeemRefreshToken(s.RefreshToken) + newToken, newIDToken, duration, err := p.redeemRefreshToken(ctx, s.RefreshToken) if err != nil { return false, err } @@ -265,7 +267,7 @@ func (p *GoogleProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, return true, nil } -func (p *GoogleProvider) redeemRefreshToken(refreshToken string) (token string, idToken string, expires time.Duration, err error) { +func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken string) (token string, idToken string, expires time.Duration, err error) { // https://developers.google.com/identity/protocols/OAuth2WebServer#refresh clientSecret, err := p.GetClientSecret() if err != nil { @@ -278,7 +280,7 @@ func (p *GoogleProvider) redeemRefreshToken(refreshToken string) (token string, params.Add("refresh_token", refreshToken) params.Add("grant_type", "refresh_token") var req *http.Request - req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) + req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) if err != nil { return } diff --git a/providers/google_test.go b/providers/google_test.go index 0e1de914..63e5a9a8 100644 --- a/providers/google_test.go +++ b/providers/google_test.go @@ -102,7 +102,7 @@ func TestGoogleProviderGetEmailAddress(t *testing.T) { p.RedeemURL, server = newRedeemServer(body) defer server.Close() - session, err := p.Redeem("http://redirect/", "code1234") + session, err := p.Redeem(context.Background(), "http://redirect/", "code1234") assert.Equal(t, nil, err) assert.NotEqual(t, session, nil) assert.Equal(t, "michael.bland@gsa.gov", session.Email) @@ -139,7 +139,7 @@ func TestGoogleProviderGetEmailAddressInvalidEncoding(t *testing.T) { p.RedeemURL, server = newRedeemServer(body) defer server.Close() - session, err := p.Redeem("http://redirect/", "code1234") + session, err := p.Redeem(context.Background(), "http://redirect/", "code1234") assert.NotEqual(t, nil, err) if session != nil { t.Errorf("expect nill session %#v", session) @@ -150,7 +150,7 @@ func TestGoogleProviderRedeemFailsNoCLientSecret(t *testing.T) { p := newGoogleProvider() p.ProviderData.ClientSecretFile = "srvnoerre" - session, err := p.Redeem("http://redirect/", "code1234") + session, err := p.Redeem(context.Background(), "http://redirect/", "code1234") assert.NotEqual(t, nil, err) if session != nil { t.Errorf("expect nill session %#v", session) @@ -170,7 +170,7 @@ func TestGoogleProviderGetEmailAddressInvalidJson(t *testing.T) { p.RedeemURL, server = newRedeemServer(body) defer server.Close() - session, err := p.Redeem("http://redirect/", "code1234") + session, err := p.Redeem(context.Background(), "http://redirect/", "code1234") assert.NotEqual(t, nil, err) if session != nil { t.Errorf("expect nill session %#v", session) @@ -189,7 +189,7 @@ func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) { p.RedeemURL, server = newRedeemServer(body) defer server.Close() - session, err := p.Redeem("http://redirect/", "code1234") + session, err := p.Redeem(context.Background(), "http://redirect/", "code1234") assert.NotEqual(t, nil, err) if session != nil { t.Errorf("expect nill session %#v", session) diff --git a/providers/internal_util.go b/providers/internal_util.go index 4cc502ed..f9bdc304 100644 --- a/providers/internal_util.go +++ b/providers/internal_util.go @@ -1,6 +1,7 @@ package providers import ( + "context" "io/ioutil" "net/http" "net/url" @@ -46,7 +47,7 @@ func stripParam(param, endpoint string) string { } // validateToken returns true if token is valid -func validateToken(p Provider, accessToken string, header http.Header) bool { +func validateToken(ctx context.Context, p Provider, accessToken string, header http.Header) bool { if accessToken == "" || p.Data().ValidateURL == nil || p.Data().ValidateURL.String() == "" { return false } @@ -55,7 +56,7 @@ func validateToken(p Provider, accessToken string, header http.Header) bool { params := url.Values{"access_token": {accessToken}} endpoint = endpoint + "?" + params.Encode() } - resp, err := requests.RequestUnparsedResponse(endpoint, header) + resp, err := requests.RequestUnparsedResponse(ctx, endpoint, header) if err != nil { logger.Printf("GET %s", stripToken(endpoint)) logger.Printf("token validation request failed: %s", err) diff --git a/providers/internal_util_test.go b/providers/internal_util_test.go index 591c7f30..0f6aa437 100644 --- a/providers/internal_util_test.go +++ b/providers/internal_util_test.go @@ -1,6 +1,7 @@ package providers import ( + "context" "errors" "net/http" "net/http/httptest" @@ -20,13 +21,15 @@ type ValidateSessionStateTestProvider struct { *ProviderData } -func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { +var _ Provider = (*ValidateSessionStateTestProvider)(nil) + +func (tp *ValidateSessionStateTestProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { return "", errors.New("not implemented") } // Note that we're testing the internal validateToken() used to implement // several Provider's ValidateSessionState() implementations -func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *sessions.SessionState) bool { +func (tp *ValidateSessionStateTestProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool { return false } @@ -87,7 +90,7 @@ func (vtTest *ValidateSessionStateTest) Close() { func TestValidateSessionStateValidToken(t *testing.T) { vtTest := NewValidateSessionStateTest() defer vtTest.Close() - assert.Equal(t, true, validateToken(vtTest.provider, "foobar", nil)) + assert.Equal(t, true, validateToken(context.Background(), vtTest.provider, "foobar", nil)) } func TestValidateSessionStateValidTokenWithHeaders(t *testing.T) { @@ -96,34 +99,34 @@ func TestValidateSessionStateValidTokenWithHeaders(t *testing.T) { vtTest.header = make(http.Header) vtTest.header.Set("Authorization", "Bearer foobar") assert.Equal(t, true, - validateToken(vtTest.provider, "foobar", vtTest.header)) + validateToken(context.Background(), vtTest.provider, "foobar", vtTest.header)) } func TestValidateSessionStateEmptyToken(t *testing.T) { vtTest := NewValidateSessionStateTest() defer vtTest.Close() - assert.Equal(t, false, validateToken(vtTest.provider, "", nil)) + assert.Equal(t, false, validateToken(context.Background(), vtTest.provider, "", nil)) } func TestValidateSessionStateEmptyValidateURL(t *testing.T) { vtTest := NewValidateSessionStateTest() defer vtTest.Close() vtTest.provider.Data().ValidateURL = nil - assert.Equal(t, false, validateToken(vtTest.provider, "foobar", nil)) + assert.Equal(t, false, validateToken(context.Background(), vtTest.provider, "foobar", nil)) } func TestValidateSessionStateRequestNetworkFailure(t *testing.T) { vtTest := NewValidateSessionStateTest() // Close immediately to simulate a network failure vtTest.Close() - assert.Equal(t, false, validateToken(vtTest.provider, "foobar", nil)) + assert.Equal(t, false, validateToken(context.Background(), vtTest.provider, "foobar", nil)) } func TestValidateSessionStateExpiredToken(t *testing.T) { vtTest := NewValidateSessionStateTest() defer vtTest.Close() vtTest.responseCode = 401 - assert.Equal(t, false, validateToken(vtTest.provider, "foobar", nil)) + assert.Equal(t, false, validateToken(context.Background(), vtTest.provider, "foobar", nil)) } func TestStripTokenNotPresent(t *testing.T) { diff --git a/providers/keycloak.go b/providers/keycloak.go index d8d41801..414c4973 100644 --- a/providers/keycloak.go +++ b/providers/keycloak.go @@ -1,6 +1,7 @@ package providers import ( + "context" "net/http" "net/url" @@ -14,6 +15,8 @@ type KeycloakProvider struct { Group string } +var _ Provider = (*KeycloakProvider)(nil) + func NewKeycloakProvider(p *ProviderData) *KeycloakProvider { p.ProviderName = "Keycloak" if p.LoginURL == nil || p.LoginURL.String() == "" { @@ -47,9 +50,9 @@ func (p *KeycloakProvider) SetGroup(group string) { p.Group = group } -func (p *KeycloakProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { +func (p *KeycloakProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { - req, err := http.NewRequest("GET", p.ValidateURL.String(), nil) + req, err := http.NewRequestWithContext(ctx, "GET", p.ValidateURL.String(), nil) req.Header.Set("Authorization", "Bearer "+s.AccessToken) if err != nil { logger.Printf("failed building request %s", err) diff --git a/providers/keycloak_test.go b/providers/keycloak_test.go index 2ac9d67f..239d727f 100644 --- a/providers/keycloak_test.go +++ b/providers/keycloak_test.go @@ -1,6 +1,7 @@ package providers import ( + "context" "net/http" "net/http/httptest" "net/url" @@ -99,7 +100,7 @@ func TestKeycloakProviderGetEmailAddress(t *testing.T) { p := testKeycloakProvider(bURL.Host, "") session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(session) + email, err := p.GetEmailAddress(context.Background(), session) assert.Equal(t, nil, err) assert.Equal(t, "michael.bland@gsa.gov", email) } @@ -112,7 +113,7 @@ func TestKeycloakProviderGetEmailAddressAndGroup(t *testing.T) { p := testKeycloakProvider(bURL.Host, "test-grp1") session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(session) + email, err := p.GetEmailAddress(context.Background(), session) assert.Equal(t, nil, err) assert.Equal(t, "michael.bland@gsa.gov", email) } @@ -130,7 +131,7 @@ func TestKeycloakProviderGetEmailAddressFailedRequest(t *testing.T) { // token. Alternatively, we could allow the parsing of the payload as // JSON to fail. session := &sessions.SessionState{AccessToken: "unexpected_access_token"} - email, err := p.GetEmailAddress(session) + email, err := p.GetEmailAddress(context.Background(), session) assert.NotEqual(t, nil, err) assert.Equal(t, "", email) } @@ -143,7 +144,7 @@ func TestKeycloakProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { p := testKeycloakProvider(bURL.Host, "") session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(session) + email, err := p.GetEmailAddress(context.Background(), session) assert.NotEqual(t, nil, err) assert.Equal(t, "", email) } diff --git a/providers/linkedin.go b/providers/linkedin.go index b69ae933..6cc24239 100644 --- a/providers/linkedin.go +++ b/providers/linkedin.go @@ -1,6 +1,7 @@ package providers import ( + "context" "errors" "fmt" "net/http" @@ -15,6 +16,8 @@ type LinkedInProvider struct { *ProviderData } +var _ Provider = (*LinkedInProvider)(nil) + // NewLinkedInProvider initiates a new LinkedInProvider func NewLinkedInProvider(p *ProviderData) *LinkedInProvider { p.ProviderName = "LinkedIn" @@ -51,11 +54,11 @@ func getLinkedInHeader(accessToken string) http.Header { } // GetEmailAddress returns the Account email address -func (p *LinkedInProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { +func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { if s.AccessToken == "" { return "", errors.New("missing access token") } - req, err := http.NewRequest("GET", p.ProfileURL.String()+"?format=json", nil) + req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String()+"?format=json", nil) if err != nil { return "", err } @@ -74,6 +77,6 @@ func (p *LinkedInProvider) GetEmailAddress(s *sessions.SessionState) (string, er } // ValidateSessionState validates the AccessToken -func (p *LinkedInProvider) ValidateSessionState(s *sessions.SessionState) bool { - return validateToken(p, s.AccessToken, getLinkedInHeader(s.AccessToken)) +func (p *LinkedInProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool { + return validateToken(ctx, p, s.AccessToken, getLinkedInHeader(s.AccessToken)) } diff --git a/providers/linkedin_test.go b/providers/linkedin_test.go index 9f325eae..6d70d57c 100644 --- a/providers/linkedin_test.go +++ b/providers/linkedin_test.go @@ -1,6 +1,7 @@ package providers import ( + "context" "net/http" "net/http/httptest" "net/url" @@ -99,7 +100,7 @@ func TestLinkedInProviderGetEmailAddress(t *testing.T) { p := testLinkedInProvider(bURL.Host) session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(session) + email, err := p.GetEmailAddress(context.Background(), session) assert.Equal(t, nil, err) assert.Equal(t, "user@linkedin.com", email) } @@ -115,7 +116,7 @@ func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) { // token. Alternatively, we could allow the parsing of the payload as // JSON to fail. session := &sessions.SessionState{AccessToken: "unexpected_access_token"} - email, err := p.GetEmailAddress(session) + email, err := p.GetEmailAddress(context.Background(), session) assert.NotEqual(t, nil, err) assert.Equal(t, "", email) } @@ -128,7 +129,7 @@ func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { p := testLinkedInProvider(bURL.Host) session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(session) + email, err := p.GetEmailAddress(context.Background(), session) assert.NotEqual(t, nil, err) assert.Equal(t, "", email) } diff --git a/providers/logingov.go b/providers/logingov.go index d5a34a41..6f98d0cc 100644 --- a/providers/logingov.go +++ b/providers/logingov.go @@ -2,6 +2,7 @@ package providers import ( "bytes" + "context" "crypto/rsa" "encoding/json" "errors" @@ -28,6 +29,8 @@ type LoginGovProvider struct { PubJWKURL *url.URL } +var _ Provider = (*LoginGovProvider)(nil) + // For generating a nonce var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") @@ -125,10 +128,10 @@ func checkNonce(idToken string, p *LoginGovProvider) (err error) { return } -func emailFromUserInfo(accessToken string, userInfoEndpoint string) (email string, err error) { +func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint string) (email string, err error) { // query the user info endpoint for user attributes var req *http.Request - req, err = http.NewRequest("GET", userInfoEndpoint, nil) + req, err = http.NewRequestWithContext(ctx, "GET", userInfoEndpoint, nil) if err != nil { return } @@ -173,7 +176,7 @@ func emailFromUserInfo(accessToken string, userInfoEndpoint string) (email strin } // Redeem exchanges the OAuth2 authentication token for an ID token -func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) { +func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { if code == "" { err = errors.New("missing code") return @@ -199,7 +202,7 @@ func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *sessions.Session params.Add("grant_type", "authorization_code") var req *http.Request - req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) + req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) if err != nil { return } @@ -242,7 +245,7 @@ func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *sessions.Session // Get the email address var email string - email, err = emailFromUserInfo(jsonResponse.AccessToken, p.ProfileURL.String()) + email, err = emailFromUserInfo(ctx, jsonResponse.AccessToken, p.ProfileURL.String()) if err != nil { return } diff --git a/providers/logingov_test.go b/providers/logingov_test.go index 29808d02..96934c68 100644 --- a/providers/logingov_test.go +++ b/providers/logingov_test.go @@ -1,6 +1,7 @@ package providers import ( + "context" "crypto" "crypto/rand" "crypto/rsa" @@ -189,7 +190,7 @@ func TestLoginGovProviderSessionData(t *testing.T) { p.PubJWKURL, pubjwkserver = newLoginGovServer(pubjwkbody) defer pubjwkserver.Close() - session, err := p.Redeem("http://redirect/", "code1234") + session, err := p.Redeem(context.Background(), "http://redirect/", "code1234") assert.NoError(t, err) assert.NotEqual(t, session, nil) assert.Equal(t, "timothy.spencer@gsa.gov", session.Email) @@ -283,7 +284,7 @@ func TestLoginGovProviderBadNonce(t *testing.T) { p.PubJWKURL, pubjwkserver = newLoginGovServer(pubjwkbody) defer pubjwkserver.Close() - _, err = p.Redeem("http://redirect/", "code1234") + _, err = p.Redeem(context.Background(), "http://redirect/", "code1234") // The "badfakenonce" in the idtoken above should cause this to error out assert.Error(t, err) diff --git a/providers/nextcloud.go b/providers/nextcloud.go index 6b2932e3..d51b7183 100644 --- a/providers/nextcloud.go +++ b/providers/nextcloud.go @@ -1,6 +1,7 @@ package providers import ( + "context" "fmt" "net/http" @@ -14,6 +15,8 @@ type NextcloudProvider struct { *ProviderData } +var _ Provider = (*NextcloudProvider)(nil) + // NewNextcloudProvider initiates a new NextcloudProvider func NewNextcloudProvider(p *ProviderData) *NextcloudProvider { p.ProviderName = "Nextcloud" @@ -27,8 +30,8 @@ func getNextcloudHeader(accessToken string) http.Header { } // GetEmailAddress returns the Account email address -func (p *NextcloudProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { - req, err := http.NewRequest("GET", +func (p *NextcloudProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { + req, err := http.NewRequestWithContext(ctx, "GET", p.ValidateURL.String(), nil) if err != nil { logger.Printf("failed building request %s", err) diff --git a/providers/nextcloud_test.go b/providers/nextcloud_test.go index 0f3a8293..ac93d877 100644 --- a/providers/nextcloud_test.go +++ b/providers/nextcloud_test.go @@ -1,6 +1,7 @@ package providers import ( + "context" "net/http" "net/http/httptest" "net/url" @@ -97,7 +98,7 @@ func TestNextcloudProviderGetEmailAddress(t *testing.T) { p.ValidateURL.RawQuery = formatJSON session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(session) + email, err := p.GetEmailAddress(context.Background(), session) assert.Equal(t, nil, err) assert.Equal(t, "michael.bland@gsa.gov", email) } @@ -117,7 +118,7 @@ func TestNextcloudProviderGetEmailAddressFailedRequest(t *testing.T) { // token. Alternatively, we could allow the parsing of the payload as // JSON to fail. session := &sessions.SessionState{AccessToken: "unexpected_access_token"} - email, err := p.GetEmailAddress(session) + email, err := p.GetEmailAddress(context.Background(), session) assert.NotEqual(t, nil, err) assert.Equal(t, "", email) } @@ -132,7 +133,7 @@ func TestNextcloudProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) p.ValidateURL.RawQuery = formatJSON session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(session) + email, err := p.GetEmailAddress(context.Background(), session) assert.NotEqual(t, nil, err) assert.Equal(t, "", email) } diff --git a/providers/oidc.go b/providers/oidc.go index ac27c8aa..1b6758b9 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -31,14 +31,15 @@ func NewOIDCProvider(p *ProviderData) *OIDCProvider { return &OIDCProvider{ProviderData: p} } +var _ Provider = (*OIDCProvider)(nil) + // Redeem exchanges the OAuth2 authentication token for an ID token -func (p *OIDCProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) { +func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { clientSecret, err := p.GetClientSecret() if err != nil { return } - ctx := context.Background() c := oauth2.Config{ ClientID: p.ClientID, ClientSecret: clientSecret, @@ -60,7 +61,7 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *sessions.SessionStat return nil, fmt.Errorf("token response did not contain an id_token") } - s, err = p.createSessionState(token, idToken) + s, err = p.createSessionState(ctx, token, idToken) if err != nil { return nil, fmt.Errorf("unable to update session: %v", err) } @@ -70,12 +71,12 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *sessions.SessionStat // RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshToken to fetch a new Access Token (and optional ID token) if required -func (p *OIDCProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { +func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { return false, nil } - err := p.redeemRefreshToken(s) + err := p.redeemRefreshToken(ctx, s) if err != nil { return false, fmt.Errorf("unable to redeem refresh token: %v", err) } @@ -84,7 +85,7 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, e return true, nil } -func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error) { +func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) { clientSecret, err := p.GetClientSecret() if err != nil { return @@ -97,7 +98,6 @@ func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error) TokenURL: p.RedeemURL.String(), }, } - ctx := context.Background() t := &oauth2.Token{ RefreshToken: s.RefreshToken, Expiry: time.Now().Add(-time.Hour), @@ -113,7 +113,7 @@ func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error) return fmt.Errorf("unable to extract id_token from response: %v", err) } - newSession, err := p.createSessionState(token, idToken) + newSession, err := p.createSessionState(ctx, token, idToken) if err != nil { return fmt.Errorf("unable create new session state from response: %v", err) } @@ -149,7 +149,7 @@ func (p *OIDCProvider) findVerifiedIDToken(ctx context.Context, token *oauth2.To return nil, nil } -func (p *OIDCProvider) createSessionState(token *oauth2.Token, idToken *oidc.IDToken) (*sessions.SessionState, error) { +func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token, idToken *oidc.IDToken) (*sessions.SessionState, error) { var newSession *sessions.SessionState @@ -157,7 +157,7 @@ func (p *OIDCProvider) createSessionState(token *oauth2.Token, idToken *oidc.IDT newSession = &sessions.SessionState{} } else { var err error - newSession, err = p.createSessionStateInternal(token.Extra("id_token").(string), idToken, token) + newSession, err = p.createSessionStateInternal(ctx, token.Extra("id_token").(string), idToken, token) if err != nil { return nil, err } @@ -170,8 +170,8 @@ func (p *OIDCProvider) createSessionState(token *oauth2.Token, idToken *oidc.IDT return newSession, nil } -func (p *OIDCProvider) CreateSessionStateFromBearerToken(rawIDToken string, idToken *oidc.IDToken) (*sessions.SessionState, error) { - newSession, err := p.createSessionStateInternal(rawIDToken, idToken, nil) +func (p *OIDCProvider) CreateSessionStateFromBearerToken(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessions.SessionState, error) { + newSession, err := p.createSessionStateInternal(ctx, rawIDToken, idToken, nil) if err != nil { return nil, err } @@ -184,7 +184,7 @@ func (p *OIDCProvider) CreateSessionStateFromBearerToken(rawIDToken string, idTo return newSession, nil } -func (p *OIDCProvider) createSessionStateInternal(rawIDToken string, idToken *oidc.IDToken, token *oauth2.Token) (*sessions.SessionState, error) { +func (p *OIDCProvider) createSessionStateInternal(ctx context.Context, rawIDToken string, idToken *oidc.IDToken, token *oauth2.Token) (*sessions.SessionState, error) { newSession := &sessions.SessionState{} @@ -196,7 +196,7 @@ func (p *OIDCProvider) createSessionStateInternal(rawIDToken string, idToken *oi accessToken = token.AccessToken } - claims, err := p.findClaimsFromIDToken(idToken, accessToken, p.ProfileURL.String()) + claims, err := p.findClaimsFromIDToken(ctx, idToken, accessToken, p.ProfileURL.String()) if err != nil { return nil, fmt.Errorf("couldn't extract claims from id_token (%e)", err) } @@ -217,8 +217,7 @@ func (p *OIDCProvider) createSessionStateInternal(rawIDToken string, idToken *oi } // ValidateSessionState checks that the session's IDToken is still valid -func (p *OIDCProvider) ValidateSessionState(s *sessions.SessionState) bool { - ctx := context.Background() +func (p *OIDCProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool { _, err := p.Verifier.Verify(ctx, s.IDToken) return err == nil } @@ -230,7 +229,7 @@ func getOIDCHeader(accessToken string) http.Header { return header } -func (p *OIDCProvider) findClaimsFromIDToken(idToken *oidc.IDToken, accessToken string, profileURL string) (*OIDCClaims, error) { +func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc.IDToken, accessToken string, profileURL string) (*OIDCClaims, error) { claims := &OIDCClaims{} // Extract default claims. @@ -257,7 +256,7 @@ func (p *OIDCProvider) findClaimsFromIDToken(idToken *oidc.IDToken, accessToken // contents at the profileURL contains the email. // Make a query to the userinfo endpoint, and attempt to locate the email from there. - req, err := http.NewRequest("GET", profileURL, nil) + req, err := http.NewRequestWithContext(ctx, "GET", profileURL, nil) if err != nil { return nil, err } diff --git a/providers/oidc_test.go b/providers/oidc_test.go index b0596e36..823af30c 100644 --- a/providers/oidc_test.go +++ b/providers/oidc_test.go @@ -159,7 +159,7 @@ func TestOIDCProviderRedeem(t *testing.T) { server, provider := newTestSetup(body) defer server.Close() - session, err := provider.Redeem(provider.RedeemURL.String(), "code1234") + session, err := provider.Redeem(context.Background(), provider.RedeemURL.String(), "code1234") assert.Equal(t, nil, err) assert.Equal(t, defaultIDToken.Email, session.Email) assert.Equal(t, accessToken, session.AccessToken) @@ -183,7 +183,7 @@ func TestOIDCProviderRedeem_custom_userid(t *testing.T) { provider.UserIDClaim = "phone_number" defer server.Close() - session, err := provider.Redeem(provider.RedeemURL.String(), "code1234") + session, err := provider.Redeem(context.Background(), provider.RedeemURL.String(), "code1234") assert.Equal(t, nil, err) assert.Equal(t, defaultIDToken.Phone, session.Email) } @@ -211,7 +211,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) { User: "11223344", } - refreshed, err := provider.RefreshSessionIfNeeded(existingSession) + refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession) assert.Equal(t, nil, err) assert.Equal(t, refreshed, true) assert.Equal(t, "janedoe@example.com", existingSession.Email) @@ -244,7 +244,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) { Email: "changeit", User: "changeit", } - refreshed, err := provider.RefreshSessionIfNeeded(existingSession) + refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession) assert.Equal(t, nil, err) assert.Equal(t, refreshed, true) assert.Equal(t, defaultIDToken.Email, existingSession.Email) diff --git a/providers/provider_default.go b/providers/provider_default.go index 720dd580..d90ad164 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -2,6 +2,7 @@ package providers import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -16,8 +17,10 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/pkg/encryption" ) +var _ Provider = (*ProviderData)(nil) + // Redeem provides a default implementation of the OAuth2 token redemption process -func (p *ProviderData) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) { +func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { if code == "" { err = errors.New("missing code") return @@ -38,7 +41,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *sessions.SessionStat } var req *http.Request - req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) + req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) if err != nil { return } @@ -116,17 +119,17 @@ func (p *ProviderData) SessionFromCookie(v string, c *encryption.Cipher) (s *ses } // GetEmailAddress returns the Account email address -func (p *ProviderData) GetEmailAddress(s *sessions.SessionState) (string, error) { +func (p *ProviderData) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { return "", errors.New("not implemented") } // GetUserName returns the Account username -func (p *ProviderData) GetUserName(s *sessions.SessionState) (string, error) { +func (p *ProviderData) GetUserName(ctx context.Context, s *sessions.SessionState) (string, error) { return "", errors.New("not implemented") } // GetPreferredUsername returns the Account preferred username -func (p *ProviderData) GetPreferredUsername(s *sessions.SessionState) (string, error) { +func (p *ProviderData) GetPreferredUsername(ctx context.Context, s *sessions.SessionState) (string, error) { return "", errors.New("not implemented") } @@ -137,17 +140,17 @@ func (p *ProviderData) ValidateGroup(email string) bool { } // ValidateSessionState validates the AccessToken -func (p *ProviderData) ValidateSessionState(s *sessions.SessionState) bool { - return validateToken(p, s.AccessToken, nil) +func (p *ProviderData) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool { + return validateToken(ctx, p, s.AccessToken, nil) } // RefreshSessionIfNeeded should refresh the user's session if required and // do nothing if a refresh is not required -func (p *ProviderData) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { +func (p *ProviderData) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { return false, nil } -func (p *ProviderData) CreateSessionStateFromBearerToken(rawIDToken string, idToken *oidc.IDToken) (*sessions.SessionState, error) { +func (p *ProviderData) CreateSessionStateFromBearerToken(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessions.SessionState, error) { var claims struct { Subject string `json:"sub"` Email string `json:"email"` diff --git a/providers/provider_default_test.go b/providers/provider_default_test.go index e8a51f51..4d8a8306 100644 --- a/providers/provider_default_test.go +++ b/providers/provider_default_test.go @@ -1,6 +1,7 @@ package providers import ( + "context" "testing" "time" @@ -10,7 +11,7 @@ import ( func TestRefresh(t *testing.T) { p := &ProviderData{} - refreshed, err := p.RefreshSessionIfNeeded(&sessions.SessionState{ + refreshed, err := p.RefreshSessionIfNeeded(context.Background(), &sessions.SessionState{ ExpiresOn: time.Now().Add(time.Duration(-11) * time.Minute), }) assert.Equal(t, false, refreshed) diff --git a/providers/providers.go b/providers/providers.go index 20c42489..87ba9103 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -1,6 +1,8 @@ package providers import ( + "context" + "github.com/coreos/go-oidc" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/pkg/encryption" @@ -9,17 +11,17 @@ import ( // Provider represents an upstream identity provider implementation type Provider interface { Data() *ProviderData - GetEmailAddress(*sessions.SessionState) (string, error) - GetUserName(*sessions.SessionState) (string, error) - GetPreferredUsername(*sessions.SessionState) (string, error) - Redeem(string, string) (*sessions.SessionState, error) + GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) + GetUserName(ctx context.Context, s *sessions.SessionState) (string, error) + GetPreferredUsername(ctx context.Context, s *sessions.SessionState) (string, error) + Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error) ValidateGroup(string) bool - ValidateSessionState(*sessions.SessionState) bool + ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool GetLoginURL(redirectURI, finalRedirect string) string - RefreshSessionIfNeeded(*sessions.SessionState) (bool, error) + RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) SessionFromCookie(string, *encryption.Cipher) (*sessions.SessionState, error) CookieForSession(*sessions.SessionState, *encryption.Cipher) (string, error) - CreateSessionStateFromBearerToken(rawIDToken string, idToken *oidc.IDToken) (*sessions.SessionState, error) + CreateSessionStateFromBearerToken(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessions.SessionState, error) } // New provides a new Provider based on the configured provider string