1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-06-06 23:46:28 +02:00

Support context in providers (#519)

Co-authored-by: Henry Jenkins <henry@henryjenkins.name>
This commit is contained in:
Mitsuo Heijo 2020-05-06 00:53:33 +09:00 committed by Henry Jenkins
parent 53d8e99f05
commit e642daef4e
33 changed files with 223 additions and 173 deletions

View File

@ -33,6 +33,7 @@
- [#514](https://github.com/oauth2-proxy/oauth2-proxy/pull/514) Add basic string functions to templates - [#514](https://github.com/oauth2-proxy/oauth2-proxy/pull/514) Add basic string functions to templates
- [#524](https://github.com/oauth2-proxy/oauth2-proxy/pull/524) Sign cookies with SHA256 (@NickMeves) - [#524](https://github.com/oauth2-proxy/oauth2-proxy/pull/524) Sign cookies with SHA256 (@NickMeves)
- [#515](https://github.com/oauth2-proxy/oauth2-proxy/pull/515) Drop configure script in favour of native Makefile env and checks (@JoelSpeed) - [#515](https://github.com/oauth2-proxy/oauth2-proxy/pull/515) Drop configure script in favour of native Makefile env and checks (@JoelSpeed)
- [#519](https://github.com/oauth2-proxy/oauth2-proxy/pull/519) Support context in providers (@johejo)
- [#487](https://github.com/oauth2-proxy/oauth2-proxy/pull/487) Switch flags to PFlag to remove StringArray (@JoelSpeed) - [#487](https://github.com/oauth2-proxy/oauth2-proxy/pull/487) Switch flags to PFlag to remove StringArray (@JoelSpeed)
- [#484](https://github.com/oauth2-proxy/oauth2-proxy/pull/484) Replace configuration loading with Viper (@JoelSpeed) - [#484](https://github.com/oauth2-proxy/oauth2-proxy/pull/484) Replace configuration loading with Viper (@JoelSpeed)
- [#499](https://github.com/oauth2-proxy/oauth2-proxy/pull/499) Add `-user-id-claim` to support generic claims in addition to email (@holyjak) - [#499](https://github.com/oauth2-proxy/oauth2-proxy/pull/499) Add `-user-id-claim` to support generic claims in addition to email (@holyjak)

View File

@ -347,29 +347,29 @@ func (p *OAuthProxy) displayCustomLoginForm() bool {
return p.HtpasswdFile != nil && p.DisplayHtpasswdForm return p.HtpasswdFile != nil && p.DisplayHtpasswdForm
} }
func (p *OAuthProxy) redeemCode(host, code string) (s *sessionsapi.SessionState, err error) { func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (s *sessionsapi.SessionState, err error) {
if code == "" { if code == "" {
return nil, errors.New("missing code") return nil, errors.New("missing code")
} }
redirectURI := p.GetRedirectURI(host) redirectURI := p.GetRedirectURI(host)
s, err = p.provider.Redeem(redirectURI, code) s, err = p.provider.Redeem(ctx, redirectURI, code)
if err != nil { if err != nil {
return return
} }
if s.Email == "" { if s.Email == "" {
s.Email, err = p.provider.GetEmailAddress(s) s.Email, err = p.provider.GetEmailAddress(ctx, s)
} }
if s.PreferredUsername == "" { if s.PreferredUsername == "" {
s.PreferredUsername, err = p.provider.GetPreferredUsername(s) s.PreferredUsername, err = p.provider.GetPreferredUsername(ctx, s)
if err != nil && err.Error() == "not implemented" { if err != nil && err.Error() == "not implemented" {
err = nil err = nil
} }
} }
if s.User == "" { if s.User == "" {
s.User, err = p.provider.GetUserName(s) s.User, err = p.provider.GetUserName(ctx, s)
if err != nil && err.Error() == "not implemented" { if err != nil && err.Error() == "not implemented" {
err = nil err = nil
} }
@ -782,7 +782,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
return return
} }
session, err := p.redeemCode(req.Host, req.Form.Get("code")) session, err := p.redeemCode(req.Context(), req.Host, req.Form.Get("code"))
if err != nil { if err != nil {
logger.Printf("Error redeeming code during OAuth2 callback: %s ", err.Error()) logger.Printf("Error redeeming code during OAuth2 callback: %s ", err.Error())
p.ErrorPage(rw, 500, "Internal Error", "Internal Error") p.ErrorPage(rw, 500, "Internal Error", "Internal Error")
@ -907,7 +907,7 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R
saveSession = true saveSession = true
} }
if ok, err := p.provider.RefreshSessionIfNeeded(session); err != nil { if ok, err := p.provider.RefreshSessionIfNeeded(req.Context(), session); err != nil {
logger.Printf("%s removing session. error refreshing access token %s %s", remoteAddr, err, session) logger.Printf("%s removing session. error refreshing access token %s %s", remoteAddr, err, session)
clearSession = true clearSession = true
session = nil session = nil
@ -926,7 +926,7 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R
} }
if saveSession && !revalidated && session != nil && session.AccessToken != "" { if saveSession && !revalidated && session != nil && session.AccessToken != "" {
if !p.provider.ValidateSessionState(session) { if !p.provider.ValidateSessionState(req.Context(), session) {
logger.Printf("Removing session: error validating %s", session) logger.Printf("Removing session: error validating %s", session)
saveSession = false saveSession = false
session = nil session = nil
@ -1126,16 +1126,15 @@ func (p *OAuthProxy) GetJwtSession(req *http.Request) (*sessionsapi.SessionState
return nil, err return nil, err
} }
ctx := context.Background()
for _, verifier := range p.jwtBearerVerifiers { for _, verifier := range p.jwtBearerVerifiers {
bearerToken, err := verifier.Verify(ctx, rawBearerToken) bearerToken, err := verifier.Verify(req.Context(), rawBearerToken)
if err != nil { if err != nil {
logger.Printf("failed to verify bearer token: %v", err) logger.Printf("failed to verify bearer token: %v", err)
continue continue
} }
return p.provider.CreateSessionStateFromBearerToken(rawBearerToken, bearerToken) return p.provider.CreateSessionStateFromBearerToken(req.Context(), rawBearerToken, bearerToken)
} }
return nil, fmt.Errorf("unable to verify jwt token %s", req.Header.Get("Authorization")) return nil, fmt.Errorf("unable to verify jwt token %s", req.Header.Get("Authorization"))
} }

View File

@ -397,6 +397,8 @@ type TestProvider struct {
GroupValidator func(string) bool GroupValidator func(string) bool
} }
var _ providers.Provider = (*TestProvider)(nil)
func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider { func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider {
return &TestProvider{ return &TestProvider{
ProviderData: &providers.ProviderData{ ProviderData: &providers.ProviderData{
@ -425,11 +427,11 @@ func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider {
} }
} }
func (tp *TestProvider) GetEmailAddress(session *sessions.SessionState) (string, error) { func (tp *TestProvider) GetEmailAddress(ctx context.Context, session *sessions.SessionState) (string, error) {
return tp.EmailAddress, nil return tp.EmailAddress, nil
} }
func (tp *TestProvider) ValidateSessionState(session *sessions.SessionState) bool { func (tp *TestProvider) ValidateSessionState(ctx context.Context, session *sessions.SessionState) bool {
return tp.ValidToken return tp.ValidToken
} }

View File

@ -1,6 +1,7 @@
package requests package requests
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
@ -62,8 +63,8 @@ func RequestJSON(req *http.Request, v interface{}) error {
} }
// RequestUnparsedResponse performs a GET and returns the raw response object // RequestUnparsedResponse performs a GET and returns the raw response object
func RequestUnparsedResponse(url string, header http.Header) (resp *http.Response, err error) { func RequestUnparsedResponse(ctx context.Context, url string, header http.Header) (resp *http.Response, err error) {
req, err := http.NewRequest("GET", url, nil) req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("error performing get request: %w", err) return nil, fmt.Errorf("error performing get request: %w", err)
} }

View File

@ -1,6 +1,7 @@
package requests package requests
import ( import (
"context"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -87,7 +88,7 @@ func TestRequestUnparsedResponseUsingAccessTokenParameter(t *testing.T) {
defer backend.Close() defer backend.Close()
response, err := RequestUnparsedResponse( response, err := RequestUnparsedResponse(
backend.URL+"?access_token=my_token", nil) context.Background(), backend.URL+"?access_token=my_token", nil)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
defer response.Body.Close() defer response.Body.Close()
@ -103,7 +104,7 @@ func TestRequestUnparsedResponseUsingAccessTokenParameterFailedResponse(t *testi
backend.Close() backend.Close()
response, err := RequestUnparsedResponse( response, err := RequestUnparsedResponse(
backend.URL+"?access_token=my_token", nil) context.Background(), backend.URL+"?access_token=my_token", nil)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, (*http.Response)(nil), response) assert.Equal(t, (*http.Response)(nil), response)
} }
@ -123,7 +124,7 @@ func TestRequestUnparsedResponseUsingHeaders(t *testing.T) {
headers := make(http.Header) headers := make(http.Header)
headers.Set("Auth", "my_token") headers.Set("Auth", "my_token")
response, err := RequestUnparsedResponse(backend.URL, headers) response, err := RequestUnparsedResponse(context.Background(), backend.URL, headers)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
defer response.Body.Close() defer response.Body.Close()

View File

@ -2,6 +2,7 @@ package providers
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -22,6 +23,8 @@ type AzureProvider struct {
Tenant string Tenant string
} }
var _ Provider = (*AzureProvider)(nil)
// NewAzureProvider initiates a new AzureProvider // NewAzureProvider initiates a new AzureProvider
func NewAzureProvider(p *ProviderData) *AzureProvider { func NewAzureProvider(p *ProviderData) *AzureProvider {
p.ProviderName = "Azure" p.ProviderName = "Azure"
@ -68,7 +71,7 @@ func (p *AzureProvider) Configure(tenant string) {
} }
} }
func (p *AzureProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) { func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
if code == "" { if code == "" {
err = errors.New("missing code") err = errors.New("missing code")
return return
@ -89,7 +92,7 @@ func (p *AzureProvider) Redeem(redirectURL, code string) (s *sessions.SessionSta
} }
var req *http.Request var req *http.Request
req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
if err != nil { if err != nil {
return return
} }
@ -157,14 +160,14 @@ func getEmailFromJSON(json *simplejson.Json) (string, error) {
} }
// GetEmailAddress returns the Account email address // GetEmailAddress returns the Account email address
func (p *AzureProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { func (p *AzureProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
var email string var email string
var err error var err error
if s.AccessToken == "" { if s.AccessToken == "" {
return "", errors.New("missing access token") return "", errors.New("missing access token")
} }
req, err := http.NewRequest("GET", p.ProfileURL.String(), nil) req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String(), nil)
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@ -1,6 +1,7 @@
package providers package providers
import ( import (
"context"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@ -133,7 +134,7 @@ func TestAzureProviderGetEmailAddress(t *testing.T) {
p := testAzureProvider(bURL.Host) p := testAzureProvider(bURL.Host)
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(context.Background(), session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "user@windows.net", email) assert.Equal(t, "user@windows.net", email)
} }
@ -146,7 +147,7 @@ func TestAzureProviderGetEmailAddressMailNull(t *testing.T) {
p := testAzureProvider(bURL.Host) p := testAzureProvider(bURL.Host)
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(context.Background(), session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "user@windows.net", email) assert.Equal(t, "user@windows.net", email)
} }
@ -159,7 +160,7 @@ func TestAzureProviderGetEmailAddressGetUserPrincipalName(t *testing.T) {
p := testAzureProvider(bURL.Host) p := testAzureProvider(bURL.Host)
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(context.Background(), session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "user@windows.net", email) assert.Equal(t, "user@windows.net", email)
} }
@ -172,7 +173,7 @@ func TestAzureProviderGetEmailAddressFailToGetEmailAddress(t *testing.T) {
p := testAzureProvider(bURL.Host) p := testAzureProvider(bURL.Host)
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(context.Background(), session)
assert.Equal(t, "type assertion to string failed", err.Error()) assert.Equal(t, "type assertion to string failed", err.Error())
assert.Equal(t, "", email) assert.Equal(t, "", email)
} }
@ -185,7 +186,7 @@ func TestAzureProviderGetEmailAddressEmptyUserPrincipalName(t *testing.T) {
p := testAzureProvider(bURL.Host) p := testAzureProvider(bURL.Host)
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(context.Background(), session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)
} }
@ -198,7 +199,7 @@ func TestAzureProviderGetEmailAddressIncorrectOtherMails(t *testing.T) {
p := testAzureProvider(bURL.Host) p := testAzureProvider(bURL.Host)
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(context.Background(), session)
assert.Equal(t, "type assertion to string failed", err.Error()) assert.Equal(t, "type assertion to string failed", err.Error())
assert.Equal(t, "", email) assert.Equal(t, "", email)
} }
@ -212,7 +213,7 @@ func TestAzureProviderRedeemReturnsIdToken(t *testing.T) {
bURL, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testAzureProvider(bURL.Host) p := testAzureProvider(bURL.Host)
p.Data().RedeemURL.Path = "/common/oauth2/token" p.Data().RedeemURL.Path = "/common/oauth2/token"
s, err := p.Redeem("https://localhost", "1234") s, err := p.Redeem(context.Background(), "https://localhost", "1234")
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "testtoken1234", s.IDToken) assert.Equal(t, "testtoken1234", s.IDToken)
assert.Equal(t, timestamp, s.ExpiresOn.UTC()) assert.Equal(t, timestamp, s.ExpiresOn.UTC())

View File

@ -1,6 +1,7 @@
package providers package providers
import ( import (
"context"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@ -17,6 +18,8 @@ type BitbucketProvider struct {
Repository string Repository string
} }
var _ Provider = (*BitbucketProvider)(nil)
// NewBitbucketProvider initiates a new BitbucketProvider // NewBitbucketProvider initiates a new BitbucketProvider
func NewBitbucketProvider(p *ProviderData) *BitbucketProvider { func NewBitbucketProvider(p *ProviderData) *BitbucketProvider {
p.ProviderName = "Bitbucket" p.ProviderName = "Bitbucket"
@ -64,7 +67,7 @@ func (p *BitbucketProvider) SetRepository(repository string) {
} }
// GetEmailAddress returns the email of the authenticated user // GetEmailAddress returns the email of the authenticated user
func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
var emails struct { var emails struct {
Values []struct { Values []struct {
@ -82,7 +85,7 @@ func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, e
FullName string `json:"full_name"` FullName string `json:"full_name"`
} }
} }
req, err := http.NewRequest("GET", req, err := http.NewRequestWithContext(ctx, "GET",
p.ValidateURL.String()+"?access_token="+s.AccessToken, nil) p.ValidateURL.String()+"?access_token="+s.AccessToken, nil)
if err != nil { if err != nil {
logger.Printf("failed building request %s", err) logger.Printf("failed building request %s", err)
@ -98,7 +101,7 @@ func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, e
teamURL := &url.URL{} teamURL := &url.URL{}
*teamURL = *p.ValidateURL *teamURL = *p.ValidateURL
teamURL.Path = "/2.0/teams" teamURL.Path = "/2.0/teams"
req, err = http.NewRequest("GET", req, err = http.NewRequestWithContext(ctx, "GET",
teamURL.String()+"?role=member&access_token="+s.AccessToken, nil) teamURL.String()+"?role=member&access_token="+s.AccessToken, nil)
if err != nil { if err != nil {
logger.Printf("failed building request %s", err) logger.Printf("failed building request %s", err)
@ -126,7 +129,7 @@ func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, e
repositoriesURL := &url.URL{} repositoriesURL := &url.URL{}
*repositoriesURL = *p.ValidateURL *repositoriesURL = *p.ValidateURL
repositoriesURL.Path = "/2.0/repositories/" + strings.Split(p.Repository, "/")[0] repositoriesURL.Path = "/2.0/repositories/" + strings.Split(p.Repository, "/")[0]
req, err = http.NewRequest("GET", req, err = http.NewRequestWithContext(ctx, "GET",
repositoriesURL.String()+"?role=contributor"+ repositoriesURL.String()+"?role=contributor"+
"&q=full_name="+url.QueryEscape("\""+p.Repository+"\"")+ "&q=full_name="+url.QueryEscape("\""+p.Repository+"\"")+
"&access_token="+s.AccessToken, "&access_token="+s.AccessToken,

View File

@ -1,6 +1,7 @@
package providers package providers
import ( import (
"context"
"log" "log"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -120,7 +121,7 @@ func TestBitbucketProviderGetEmailAddress(t *testing.T) {
p := testBitbucketProvider(bURL.Host, "", "") p := testBitbucketProvider(bURL.Host, "", "")
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(context.Background(), session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email) assert.Equal(t, "michael.bland@gsa.gov", email)
} }
@ -133,7 +134,7 @@ func TestBitbucketProviderGetEmailAddressAndGroup(t *testing.T) {
p := testBitbucketProvider(bURL.Host, "bioinformatics", "") p := testBitbucketProvider(bURL.Host, "bioinformatics", "")
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(context.Background(), session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email) assert.Equal(t, "michael.bland@gsa.gov", email)
} }
@ -151,7 +152,7 @@ func TestBitbucketProviderGetEmailAddressFailedRequest(t *testing.T) {
// token. Alternatively, we could allow the parsing of the payload as // token. Alternatively, we could allow the parsing of the payload as
// JSON to fail. // JSON to fail.
session := &sessions.SessionState{AccessToken: "unexpected_access_token"} session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(context.Background(), session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)
} }
@ -164,7 +165,7 @@ func TestBitbucketProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T)
p := testBitbucketProvider(bURL.Host, "", "") p := testBitbucketProvider(bURL.Host, "", "")
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(context.Background(), session)
assert.Equal(t, "", email) assert.Equal(t, "", email)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
} }

View File

@ -1,6 +1,7 @@
package providers package providers
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@ -15,6 +16,8 @@ type DigitalOceanProvider struct {
*ProviderData *ProviderData
} }
var _ Provider = (*DigitalOceanProvider)(nil)
// NewDigitalOceanProvider initiates a new DigitalOceanProvider // NewDigitalOceanProvider initiates a new DigitalOceanProvider
func NewDigitalOceanProvider(p *ProviderData) *DigitalOceanProvider { func NewDigitalOceanProvider(p *ProviderData) *DigitalOceanProvider {
p.ProviderName = "DigitalOcean" p.ProviderName = "DigitalOcean"
@ -53,11 +56,11 @@ func getDigitalOceanHeader(accessToken string) http.Header {
} }
// GetEmailAddress returns the Account email address // GetEmailAddress returns the Account email address
func (p *DigitalOceanProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { func (p *DigitalOceanProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
if s.AccessToken == "" { if s.AccessToken == "" {
return "", errors.New("missing access token") return "", errors.New("missing access token")
} }
req, err := http.NewRequest("GET", p.ProfileURL.String(), nil) req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String(), nil)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -76,6 +79,6 @@ func (p *DigitalOceanProvider) GetEmailAddress(s *sessions.SessionState) (string
} }
// ValidateSessionState validates the AccessToken // ValidateSessionState validates the AccessToken
func (p *DigitalOceanProvider) ValidateSessionState(s *sessions.SessionState) bool { func (p *DigitalOceanProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool {
return validateToken(p, s.AccessToken, getDigitalOceanHeader(s.AccessToken)) return validateToken(ctx, p, s.AccessToken, getDigitalOceanHeader(s.AccessToken))
} }

View File

@ -1,6 +1,7 @@
package providers package providers
import ( import (
"context"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@ -99,7 +100,7 @@ func TestDigitalOceanProviderGetEmailAddress(t *testing.T) {
p := testDigitalOceanProvider(bURL.Host) p := testDigitalOceanProvider(bURL.Host)
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(context.Background(), session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "user@example.com", email) assert.Equal(t, "user@example.com", email)
} }
@ -115,7 +116,7 @@ func TestDigitalOceanProviderGetEmailAddressFailedRequest(t *testing.T) {
// token. Alternatively, we could allow the parsing of the payload as // token. Alternatively, we could allow the parsing of the payload as
// JSON to fail. // JSON to fail.
session := &sessions.SessionState{AccessToken: "unexpected_access_token"} session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(context.Background(), session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)
} }
@ -128,7 +129,7 @@ func TestDigitalOceanProviderGetEmailAddressEmailNotPresentInPayload(t *testing.
p := testDigitalOceanProvider(bURL.Host) p := testDigitalOceanProvider(bURL.Host)
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(context.Background(), session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)
} }

View File

@ -1,6 +1,7 @@
package providers package providers
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@ -15,6 +16,8 @@ type FacebookProvider struct {
*ProviderData *ProviderData
} }
var _ Provider = (*FacebookProvider)(nil)
// NewFacebookProvider initiates a new FacebookProvider // NewFacebookProvider initiates a new FacebookProvider
func NewFacebookProvider(p *ProviderData) *FacebookProvider { func NewFacebookProvider(p *ProviderData) *FacebookProvider {
p.ProviderName = "Facebook" p.ProviderName = "Facebook"
@ -55,11 +58,11 @@ func getFacebookHeader(accessToken string) http.Header {
} }
// GetEmailAddress returns the Account email address // GetEmailAddress returns the Account email address
func (p *FacebookProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
if s.AccessToken == "" { if s.AccessToken == "" {
return "", errors.New("missing access token") return "", errors.New("missing access token")
} }
req, err := http.NewRequest("GET", p.ProfileURL.String()+"?fields=name,email", nil) req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String()+"?fields=name,email", nil)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -80,6 +83,6 @@ func (p *FacebookProvider) GetEmailAddress(s *sessions.SessionState) (string, er
} }
// ValidateSessionState validates the AccessToken // ValidateSessionState validates the AccessToken
func (p *FacebookProvider) ValidateSessionState(s *sessions.SessionState) bool { func (p *FacebookProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool {
return validateToken(p, s.AccessToken, getFacebookHeader(s.AccessToken)) return validateToken(ctx, p, s.AccessToken, getFacebookHeader(s.AccessToken))
} }

View File

@ -1,6 +1,7 @@
package providers package providers
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
@ -22,6 +23,8 @@ type GitHubProvider struct {
Team string Team string
} }
var _ Provider = (*GitHubProvider)(nil)
// NewGitHubProvider initiates a new GitHubProvider // NewGitHubProvider initiates a new GitHubProvider
func NewGitHubProvider(p *ProviderData) *GitHubProvider { func NewGitHubProvider(p *ProviderData) *GitHubProvider {
p.ProviderName = "GitHub" p.ProviderName = "GitHub"
@ -69,7 +72,7 @@ func (p *GitHubProvider) SetOrgTeam(org, team string) {
} }
} }
func (p *GitHubProvider) hasOrg(accessToken string) (bool, error) { func (p *GitHubProvider) hasOrg(ctx context.Context, accessToken string) (bool, error) {
// https://developer.github.com/v3/orgs/#list-your-organizations // https://developer.github.com/v3/orgs/#list-your-organizations
var orgs []struct { var orgs []struct {
@ -93,7 +96,7 @@ func (p *GitHubProvider) hasOrg(accessToken string) (bool, error) {
Path: path.Join(p.ValidateURL.Path, "/user/orgs"), Path: path.Join(p.ValidateURL.Path, "/user/orgs"),
RawQuery: params.Encode(), RawQuery: params.Encode(),
} }
req, _ := http.NewRequest("GET", endpoint.String(), nil) req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
req.Header = getGitHubHeader(accessToken) req.Header = getGitHubHeader(accessToken)
resp, err := http.DefaultClient.Do(req) resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
@ -135,7 +138,7 @@ func (p *GitHubProvider) hasOrg(accessToken string) (bool, error) {
return false, nil return false, nil
} }
func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) { func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string) (bool, error) {
// https://developer.github.com/v3/orgs/teams/#list-user-teams // https://developer.github.com/v3/orgs/teams/#list-user-teams
var teams []struct { var teams []struct {
@ -169,7 +172,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) {
RawQuery: params.Encode(), RawQuery: params.Encode(),
} }
req, _ := http.NewRequest("GET", endpoint.String(), nil) req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
req.Header = getGitHubHeader(accessToken) req.Header = getGitHubHeader(accessToken)
resp, err := http.DefaultClient.Do(req) resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
@ -261,7 +264,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) {
} }
// GetEmailAddress returns the Account email address // GetEmailAddress returns the Account email address
func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { func (p *GitHubProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
var emails []struct { var emails []struct {
Email string `json:"email"` Email string `json:"email"`
@ -272,11 +275,11 @@ func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, erro
// if we require an Org or Team, check that first // if we require an Org or Team, check that first
if p.Org != "" { if p.Org != "" {
if p.Team != "" { if p.Team != "" {
if ok, err := p.hasOrgAndTeam(s.AccessToken); err != nil || !ok { if ok, err := p.hasOrgAndTeam(ctx, s.AccessToken); err != nil || !ok {
return "", err return "", err
} }
} else { } else {
if ok, err := p.hasOrg(s.AccessToken); err != nil || !ok { if ok, err := p.hasOrg(ctx, s.AccessToken); err != nil || !ok {
return "", err return "", err
} }
} }
@ -287,7 +290,7 @@ func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, erro
Host: p.ValidateURL.Host, Host: p.ValidateURL.Host,
Path: path.Join(p.ValidateURL.Path, "/user/emails"), Path: path.Join(p.ValidateURL.Path, "/user/emails"),
} }
req, _ := http.NewRequest("GET", endpoint.String(), nil) req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
req.Header = getGitHubHeader(s.AccessToken) req.Header = getGitHubHeader(s.AccessToken)
resp, err := http.DefaultClient.Do(req) resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
@ -324,7 +327,7 @@ func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, erro
} }
// GetUserName returns the Account user name // GetUserName returns the Account user name
func (p *GitHubProvider) GetUserName(s *sessions.SessionState) (string, error) { func (p *GitHubProvider) GetUserName(ctx context.Context, s *sessions.SessionState) (string, error) {
var user struct { var user struct {
Login string `json:"login"` Login string `json:"login"`
Email string `json:"email"` Email string `json:"email"`
@ -336,7 +339,7 @@ func (p *GitHubProvider) GetUserName(s *sessions.SessionState) (string, error) {
Path: path.Join(p.ValidateURL.Path, "/user"), Path: path.Join(p.ValidateURL.Path, "/user"),
} }
req, err := http.NewRequest("GET", endpoint.String(), nil) req, err := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
if err != nil { if err != nil {
return "", fmt.Errorf("could not create new GET request: %v", err) return "", fmt.Errorf("could not create new GET request: %v", err)
} }
@ -368,6 +371,6 @@ func (p *GitHubProvider) GetUserName(s *sessions.SessionState) (string, error) {
} }
// ValidateSessionState validates the AccessToken // ValidateSessionState validates the AccessToken
func (p *GitHubProvider) ValidateSessionState(s *sessions.SessionState) bool { func (p *GitHubProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool {
return validateToken(p, s.AccessToken, getGitHubHeader(s.AccessToken)) return validateToken(ctx, p, s.AccessToken, getGitHubHeader(s.AccessToken))
} }

View File

@ -1,6 +1,7 @@
package providers package providers
import ( import (
"context"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@ -105,7 +106,7 @@ func TestGitHubProviderGetEmailAddress(t *testing.T) {
p := testGitHubProvider(bURL.Host) p := testGitHubProvider(bURL.Host)
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(context.Background(), session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email) assert.Equal(t, "michael.bland@gsa.gov", email)
} }
@ -118,7 +119,7 @@ func TestGitHubProviderGetEmailAddressNotVerified(t *testing.T) {
p := testGitHubProvider(bURL.Host) p := testGitHubProvider(bURL.Host)
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(context.Background(), session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Empty(t, "", email) assert.Empty(t, "", email)
} }
@ -136,7 +137,7 @@ func TestGitHubProviderGetEmailAddressWithOrg(t *testing.T) {
p.Org = "testorg1" p.Org = "testorg1"
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(context.Background(), session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email) assert.Equal(t, "michael.bland@gsa.gov", email)
} }
@ -154,7 +155,7 @@ func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) {
// token. Alternatively, we could allow the parsing of the payload as // token. Alternatively, we could allow the parsing of the payload as
// JSON to fail. // JSON to fail.
session := &sessions.SessionState{AccessToken: "unexpected_access_token"} session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(context.Background(), session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)
} }
@ -167,7 +168,7 @@ func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
p := testGitHubProvider(bURL.Host) p := testGitHubProvider(bURL.Host)
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(context.Background(), session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)
} }
@ -180,7 +181,7 @@ func TestGitHubProviderGetUserName(t *testing.T) {
p := testGitHubProvider(bURL.Host) p := testGitHubProvider(bURL.Host)
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetUserName(session) email, err := p.GetUserName(context.Background(), session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "mbland", email) assert.Equal(t, "mbland", email)
} }

View File

@ -25,6 +25,8 @@ type GitLabProvider struct {
AllowUnverifiedEmail bool AllowUnverifiedEmail bool
} }
var _ Provider = (*GitLabProvider)(nil)
// NewGitLabProvider initiates a new GitLabProvider // NewGitLabProvider initiates a new GitLabProvider
func NewGitLabProvider(p *ProviderData) *GitLabProvider { func NewGitLabProvider(p *ProviderData) *GitLabProvider {
p.ProviderName = "GitLab" p.ProviderName = "GitLab"
@ -37,13 +39,12 @@ func NewGitLabProvider(p *ProviderData) *GitLabProvider {
} }
// Redeem exchanges the OAuth2 authentication token for an ID token // Redeem exchanges the OAuth2 authentication token for an ID token
func (p *GitLabProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) { func (p *GitLabProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
clientSecret, err := p.GetClientSecret() clientSecret, err := p.GetClientSecret()
if err != nil { if err != nil {
return return
} }
ctx := context.Background()
c := oauth2.Config{ c := oauth2.Config{
ClientID: p.ClientID, ClientID: p.ClientID,
ClientSecret: clientSecret, ClientSecret: clientSecret,
@ -65,14 +66,14 @@ func (p *GitLabProvider) Redeem(redirectURL, code string) (s *sessions.SessionSt
// RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new ID token if required // RefreshToken to fetch a new ID token if required
func (p *GitLabProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
return false, nil return false, nil
} }
origExpiration := s.ExpiresOn origExpiration := s.ExpiresOn
err := p.redeemRefreshToken(s) err := p.redeemRefreshToken(ctx, s)
if err != nil { if err != nil {
return false, fmt.Errorf("unable to redeem refresh token: %v", err) return false, fmt.Errorf("unable to redeem refresh token: %v", err)
} }
@ -81,7 +82,7 @@ func (p *GitLabProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool,
return true, nil return true, nil
} }
func (p *GitLabProvider) redeemRefreshToken(s *sessions.SessionState) (err error) { func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) {
clientSecret, err := p.GetClientSecret() clientSecret, err := p.GetClientSecret()
if err != nil { if err != nil {
return return
@ -94,7 +95,6 @@ func (p *GitLabProvider) redeemRefreshToken(s *sessions.SessionState) (err error
TokenURL: p.RedeemURL.String(), TokenURL: p.RedeemURL.String(),
}, },
} }
ctx := context.Background()
t := &oauth2.Token{ t := &oauth2.Token{
RefreshToken: s.RefreshToken, RefreshToken: s.RefreshToken,
Expiry: time.Now().Add(-time.Hour), Expiry: time.Now().Add(-time.Hour),
@ -123,7 +123,7 @@ type gitlabUserInfo struct {
Groups []string `json:"groups"` Groups []string `json:"groups"`
} }
func (p *GitLabProvider) getUserInfo(s *sessions.SessionState) (*gitlabUserInfo, error) { func (p *GitLabProvider) getUserInfo(ctx context.Context, s *sessions.SessionState) (*gitlabUserInfo, error) {
// Retrieve user info JSON // Retrieve user info JSON
// https://docs.gitlab.com/ee/integration/openid_connect_provider.html#shared-information // https://docs.gitlab.com/ee/integration/openid_connect_provider.html#shared-information
@ -131,7 +131,7 @@ func (p *GitLabProvider) getUserInfo(s *sessions.SessionState) (*gitlabUserInfo,
userInfoURL := *p.LoginURL userInfoURL := *p.LoginURL
userInfoURL.Path = "/oauth/userinfo" userInfoURL.Path = "/oauth/userinfo"
req, err := http.NewRequest("GET", userInfoURL.String(), nil) req, err := http.NewRequestWithContext(ctx, "GET", userInfoURL.String(), nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create user info request: %v", err) return nil, fmt.Errorf("failed to create user info request: %v", err)
} }
@ -219,16 +219,15 @@ func (p *GitLabProvider) createSessionState(ctx context.Context, token *oauth2.T
} }
// ValidateSessionState checks that the session's IDToken is still valid // ValidateSessionState checks that the session's IDToken is still valid
func (p *GitLabProvider) ValidateSessionState(s *sessions.SessionState) bool { func (p *GitLabProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool {
ctx := context.Background()
_, err := p.Verifier.Verify(ctx, s.IDToken) _, err := p.Verifier.Verify(ctx, s.IDToken)
return err == nil return err == nil
} }
// GetEmailAddress returns the Account email address // GetEmailAddress returns the Account email address
func (p *GitLabProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { func (p *GitLabProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
// Retrieve user info // Retrieve user info
userInfo, err := p.getUserInfo(s) userInfo, err := p.getUserInfo(ctx, s)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to retrieve user info: %v", err) return "", fmt.Errorf("failed to retrieve user info: %v", err)
} }
@ -254,8 +253,8 @@ func (p *GitLabProvider) GetEmailAddress(s *sessions.SessionState) (string, erro
} }
// GetUserName returns the Account user name // GetUserName returns the Account user name
func (p *GitLabProvider) GetUserName(s *sessions.SessionState) (string, error) { func (p *GitLabProvider) GetUserName(ctx context.Context, s *sessions.SessionState) (string, error) {
userInfo, err := p.getUserInfo(s) userInfo, err := p.getUserInfo(ctx, s)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to retrieve user info: %v", err) return "", fmt.Errorf("failed to retrieve user info: %v", err)
} }

View File

@ -1,6 +1,7 @@
package providers package providers
import ( import (
"context"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@ -63,7 +64,7 @@ func TestGitLabProviderBadToken(t *testing.T) {
p := testGitLabProvider(bURL.Host) p := testGitLabProvider(bURL.Host)
session := &sessions.SessionState{AccessToken: "unexpected_gitlab_access_token"} session := &sessions.SessionState{AccessToken: "unexpected_gitlab_access_token"}
_, err := p.GetEmailAddress(session) _, err := p.GetEmailAddress(context.Background(), session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
} }
@ -75,7 +76,7 @@ func TestGitLabProviderUnverifiedEmailDenied(t *testing.T) {
p := testGitLabProvider(bURL.Host) p := testGitLabProvider(bURL.Host)
session := &sessions.SessionState{AccessToken: "gitlab_access_token"} session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
_, err := p.GetEmailAddress(session) _, err := p.GetEmailAddress(context.Background(), session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
} }
@ -88,7 +89,7 @@ func TestGitLabProviderUnverifiedEmailAllowed(t *testing.T) {
p.AllowUnverifiedEmail = true p.AllowUnverifiedEmail = true
session := &sessions.SessionState{AccessToken: "gitlab_access_token"} session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(context.Background(), session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "foo@bar.com", email) assert.Equal(t, "foo@bar.com", email)
} }
@ -102,7 +103,7 @@ func TestGitLabProviderUsername(t *testing.T) {
p.AllowUnverifiedEmail = true p.AllowUnverifiedEmail = true
session := &sessions.SessionState{AccessToken: "gitlab_access_token"} session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
username, err := p.GetUserName(session) username, err := p.GetUserName(context.Background(), session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "FooBar", username) assert.Equal(t, "FooBar", username)
} }
@ -117,7 +118,7 @@ func TestGitLabProviderGroupMembershipValid(t *testing.T) {
p.Group = "foo" p.Group = "foo"
session := &sessions.SessionState{AccessToken: "gitlab_access_token"} session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(context.Background(), session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "foo@bar.com", email) assert.Equal(t, "foo@bar.com", email)
} }
@ -132,7 +133,7 @@ func TestGitLabProviderGroupMembershipMissing(t *testing.T) {
p.Group = "baz" p.Group = "baz"
session := &sessions.SessionState{AccessToken: "gitlab_access_token"} session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
_, err := p.GetEmailAddress(session) _, err := p.GetEmailAddress(context.Background(), session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
} }
@ -146,7 +147,7 @@ func TestGitLabProviderEmailDomainValid(t *testing.T) {
p.EmailDomains = []string{"bar.com"} p.EmailDomains = []string{"bar.com"}
session := &sessions.SessionState{AccessToken: "gitlab_access_token"} session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(context.Background(), session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "foo@bar.com", email) assert.Equal(t, "foo@bar.com", email)
} }
@ -161,6 +162,6 @@ func TestGitLabProviderEmailDomainInvalid(t *testing.T) {
p.EmailDomains = []string{"baz.com"} p.EmailDomains = []string{"baz.com"}
session := &sessions.SessionState{AccessToken: "gitlab_access_token"} session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
_, err := p.GetEmailAddress(session) _, err := p.GetEmailAddress(context.Background(), session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
} }

View File

@ -31,6 +31,8 @@ type GoogleProvider struct {
GroupValidator func(string) bool GroupValidator func(string) bool
} }
var _ Provider = (*GoogleProvider)(nil)
type claims struct { type claims struct {
Subject string `json:"sub"` Subject string `json:"sub"`
Email string `json:"email"` Email string `json:"email"`
@ -98,7 +100,7 @@ func claimsFromIDToken(idToken string) (*claims, error) {
} }
// Redeem exchanges the OAuth2 authentication token for an ID token // Redeem exchanges the OAuth2 authentication token for an ID token
func (p *GoogleProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) { func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
if code == "" { if code == "" {
err = errors.New("missing code") err = errors.New("missing code")
return return
@ -115,7 +117,7 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *sessions.SessionSt
params.Add("code", code) params.Add("code", code)
params.Add("grant_type", "authorization_code") params.Add("grant_type", "authorization_code")
var req *http.Request var req *http.Request
req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
if err != nil { if err != nil {
return return
} }
@ -242,12 +244,12 @@ func (p *GoogleProvider) ValidateGroup(email string) bool {
// RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new ID token if required // RefreshToken to fetch a new ID token if required
func (p *GoogleProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
return false, nil return false, nil
} }
newToken, newIDToken, duration, err := p.redeemRefreshToken(s.RefreshToken) newToken, newIDToken, duration, err := p.redeemRefreshToken(ctx, s.RefreshToken)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -265,7 +267,7 @@ func (p *GoogleProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool,
return true, nil return true, nil
} }
func (p *GoogleProvider) redeemRefreshToken(refreshToken string) (token string, idToken string, expires time.Duration, err error) { func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken string) (token string, idToken string, expires time.Duration, err error) {
// https://developers.google.com/identity/protocols/OAuth2WebServer#refresh // https://developers.google.com/identity/protocols/OAuth2WebServer#refresh
clientSecret, err := p.GetClientSecret() clientSecret, err := p.GetClientSecret()
if err != nil { if err != nil {
@ -278,7 +280,7 @@ func (p *GoogleProvider) redeemRefreshToken(refreshToken string) (token string,
params.Add("refresh_token", refreshToken) params.Add("refresh_token", refreshToken)
params.Add("grant_type", "refresh_token") params.Add("grant_type", "refresh_token")
var req *http.Request var req *http.Request
req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
if err != nil { if err != nil {
return return
} }

View File

@ -102,7 +102,7 @@ func TestGoogleProviderGetEmailAddress(t *testing.T) {
p.RedeemURL, server = newRedeemServer(body) p.RedeemURL, server = newRedeemServer(body)
defer server.Close() defer server.Close()
session, err := p.Redeem("http://redirect/", "code1234") session, err := p.Redeem(context.Background(), "http://redirect/", "code1234")
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.NotEqual(t, session, nil) assert.NotEqual(t, session, nil)
assert.Equal(t, "michael.bland@gsa.gov", session.Email) assert.Equal(t, "michael.bland@gsa.gov", session.Email)
@ -139,7 +139,7 @@ func TestGoogleProviderGetEmailAddressInvalidEncoding(t *testing.T) {
p.RedeemURL, server = newRedeemServer(body) p.RedeemURL, server = newRedeemServer(body)
defer server.Close() defer server.Close()
session, err := p.Redeem("http://redirect/", "code1234") session, err := p.Redeem(context.Background(), "http://redirect/", "code1234")
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
if session != nil { if session != nil {
t.Errorf("expect nill session %#v", session) t.Errorf("expect nill session %#v", session)
@ -150,7 +150,7 @@ func TestGoogleProviderRedeemFailsNoCLientSecret(t *testing.T) {
p := newGoogleProvider() p := newGoogleProvider()
p.ProviderData.ClientSecretFile = "srvnoerre" p.ProviderData.ClientSecretFile = "srvnoerre"
session, err := p.Redeem("http://redirect/", "code1234") session, err := p.Redeem(context.Background(), "http://redirect/", "code1234")
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
if session != nil { if session != nil {
t.Errorf("expect nill session %#v", session) t.Errorf("expect nill session %#v", session)
@ -170,7 +170,7 @@ func TestGoogleProviderGetEmailAddressInvalidJson(t *testing.T) {
p.RedeemURL, server = newRedeemServer(body) p.RedeemURL, server = newRedeemServer(body)
defer server.Close() defer server.Close()
session, err := p.Redeem("http://redirect/", "code1234") session, err := p.Redeem(context.Background(), "http://redirect/", "code1234")
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
if session != nil { if session != nil {
t.Errorf("expect nill session %#v", session) t.Errorf("expect nill session %#v", session)
@ -189,7 +189,7 @@ func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) {
p.RedeemURL, server = newRedeemServer(body) p.RedeemURL, server = newRedeemServer(body)
defer server.Close() defer server.Close()
session, err := p.Redeem("http://redirect/", "code1234") session, err := p.Redeem(context.Background(), "http://redirect/", "code1234")
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
if session != nil { if session != nil {
t.Errorf("expect nill session %#v", session) t.Errorf("expect nill session %#v", session)

View File

@ -1,6 +1,7 @@
package providers package providers
import ( import (
"context"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
@ -46,7 +47,7 @@ func stripParam(param, endpoint string) string {
} }
// validateToken returns true if token is valid // validateToken returns true if token is valid
func validateToken(p Provider, accessToken string, header http.Header) bool { func validateToken(ctx context.Context, p Provider, accessToken string, header http.Header) bool {
if accessToken == "" || p.Data().ValidateURL == nil || p.Data().ValidateURL.String() == "" { if accessToken == "" || p.Data().ValidateURL == nil || p.Data().ValidateURL.String() == "" {
return false return false
} }
@ -55,7 +56,7 @@ func validateToken(p Provider, accessToken string, header http.Header) bool {
params := url.Values{"access_token": {accessToken}} params := url.Values{"access_token": {accessToken}}
endpoint = endpoint + "?" + params.Encode() endpoint = endpoint + "?" + params.Encode()
} }
resp, err := requests.RequestUnparsedResponse(endpoint, header) resp, err := requests.RequestUnparsedResponse(ctx, endpoint, header)
if err != nil { if err != nil {
logger.Printf("GET %s", stripToken(endpoint)) logger.Printf("GET %s", stripToken(endpoint))
logger.Printf("token validation request failed: %s", err) logger.Printf("token validation request failed: %s", err)

View File

@ -1,6 +1,7 @@
package providers package providers
import ( import (
"context"
"errors" "errors"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -20,13 +21,15 @@ type ValidateSessionStateTestProvider struct {
*ProviderData *ProviderData
} }
func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { var _ Provider = (*ValidateSessionStateTestProvider)(nil)
func (tp *ValidateSessionStateTestProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
return "", errors.New("not implemented") return "", errors.New("not implemented")
} }
// Note that we're testing the internal validateToken() used to implement // Note that we're testing the internal validateToken() used to implement
// several Provider's ValidateSessionState() implementations // several Provider's ValidateSessionState() implementations
func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *sessions.SessionState) bool { func (tp *ValidateSessionStateTestProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool {
return false return false
} }
@ -87,7 +90,7 @@ func (vtTest *ValidateSessionStateTest) Close() {
func TestValidateSessionStateValidToken(t *testing.T) { func TestValidateSessionStateValidToken(t *testing.T) {
vtTest := NewValidateSessionStateTest() vtTest := NewValidateSessionStateTest()
defer vtTest.Close() defer vtTest.Close()
assert.Equal(t, true, validateToken(vtTest.provider, "foobar", nil)) assert.Equal(t, true, validateToken(context.Background(), vtTest.provider, "foobar", nil))
} }
func TestValidateSessionStateValidTokenWithHeaders(t *testing.T) { func TestValidateSessionStateValidTokenWithHeaders(t *testing.T) {
@ -96,34 +99,34 @@ func TestValidateSessionStateValidTokenWithHeaders(t *testing.T) {
vtTest.header = make(http.Header) vtTest.header = make(http.Header)
vtTest.header.Set("Authorization", "Bearer foobar") vtTest.header.Set("Authorization", "Bearer foobar")
assert.Equal(t, true, assert.Equal(t, true,
validateToken(vtTest.provider, "foobar", vtTest.header)) validateToken(context.Background(), vtTest.provider, "foobar", vtTest.header))
} }
func TestValidateSessionStateEmptyToken(t *testing.T) { func TestValidateSessionStateEmptyToken(t *testing.T) {
vtTest := NewValidateSessionStateTest() vtTest := NewValidateSessionStateTest()
defer vtTest.Close() defer vtTest.Close()
assert.Equal(t, false, validateToken(vtTest.provider, "", nil)) assert.Equal(t, false, validateToken(context.Background(), vtTest.provider, "", nil))
} }
func TestValidateSessionStateEmptyValidateURL(t *testing.T) { func TestValidateSessionStateEmptyValidateURL(t *testing.T) {
vtTest := NewValidateSessionStateTest() vtTest := NewValidateSessionStateTest()
defer vtTest.Close() defer vtTest.Close()
vtTest.provider.Data().ValidateURL = nil vtTest.provider.Data().ValidateURL = nil
assert.Equal(t, false, validateToken(vtTest.provider, "foobar", nil)) assert.Equal(t, false, validateToken(context.Background(), vtTest.provider, "foobar", nil))
} }
func TestValidateSessionStateRequestNetworkFailure(t *testing.T) { func TestValidateSessionStateRequestNetworkFailure(t *testing.T) {
vtTest := NewValidateSessionStateTest() vtTest := NewValidateSessionStateTest()
// Close immediately to simulate a network failure // Close immediately to simulate a network failure
vtTest.Close() vtTest.Close()
assert.Equal(t, false, validateToken(vtTest.provider, "foobar", nil)) assert.Equal(t, false, validateToken(context.Background(), vtTest.provider, "foobar", nil))
} }
func TestValidateSessionStateExpiredToken(t *testing.T) { func TestValidateSessionStateExpiredToken(t *testing.T) {
vtTest := NewValidateSessionStateTest() vtTest := NewValidateSessionStateTest()
defer vtTest.Close() defer vtTest.Close()
vtTest.responseCode = 401 vtTest.responseCode = 401
assert.Equal(t, false, validateToken(vtTest.provider, "foobar", nil)) assert.Equal(t, false, validateToken(context.Background(), vtTest.provider, "foobar", nil))
} }
func TestStripTokenNotPresent(t *testing.T) { func TestStripTokenNotPresent(t *testing.T) {

View File

@ -1,6 +1,7 @@
package providers package providers
import ( import (
"context"
"net/http" "net/http"
"net/url" "net/url"
@ -14,6 +15,8 @@ type KeycloakProvider struct {
Group string Group string
} }
var _ Provider = (*KeycloakProvider)(nil)
func NewKeycloakProvider(p *ProviderData) *KeycloakProvider { func NewKeycloakProvider(p *ProviderData) *KeycloakProvider {
p.ProviderName = "Keycloak" p.ProviderName = "Keycloak"
if p.LoginURL == nil || p.LoginURL.String() == "" { if p.LoginURL == nil || p.LoginURL.String() == "" {
@ -47,9 +50,9 @@ func (p *KeycloakProvider) SetGroup(group string) {
p.Group = group p.Group = group
} }
func (p *KeycloakProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { func (p *KeycloakProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
req, err := http.NewRequest("GET", p.ValidateURL.String(), nil) req, err := http.NewRequestWithContext(ctx, "GET", p.ValidateURL.String(), nil)
req.Header.Set("Authorization", "Bearer "+s.AccessToken) req.Header.Set("Authorization", "Bearer "+s.AccessToken)
if err != nil { if err != nil {
logger.Printf("failed building request %s", err) logger.Printf("failed building request %s", err)

View File

@ -1,6 +1,7 @@
package providers package providers
import ( import (
"context"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@ -99,7 +100,7 @@ func TestKeycloakProviderGetEmailAddress(t *testing.T) {
p := testKeycloakProvider(bURL.Host, "") p := testKeycloakProvider(bURL.Host, "")
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(context.Background(), session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email) assert.Equal(t, "michael.bland@gsa.gov", email)
} }
@ -112,7 +113,7 @@ func TestKeycloakProviderGetEmailAddressAndGroup(t *testing.T) {
p := testKeycloakProvider(bURL.Host, "test-grp1") p := testKeycloakProvider(bURL.Host, "test-grp1")
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(context.Background(), session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email) assert.Equal(t, "michael.bland@gsa.gov", email)
} }
@ -130,7 +131,7 @@ func TestKeycloakProviderGetEmailAddressFailedRequest(t *testing.T) {
// token. Alternatively, we could allow the parsing of the payload as // token. Alternatively, we could allow the parsing of the payload as
// JSON to fail. // JSON to fail.
session := &sessions.SessionState{AccessToken: "unexpected_access_token"} session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(context.Background(), session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)
} }
@ -143,7 +144,7 @@ func TestKeycloakProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
p := testKeycloakProvider(bURL.Host, "") p := testKeycloakProvider(bURL.Host, "")
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(context.Background(), session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)
} }

View File

@ -1,6 +1,7 @@
package providers package providers
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@ -15,6 +16,8 @@ type LinkedInProvider struct {
*ProviderData *ProviderData
} }
var _ Provider = (*LinkedInProvider)(nil)
// NewLinkedInProvider initiates a new LinkedInProvider // NewLinkedInProvider initiates a new LinkedInProvider
func NewLinkedInProvider(p *ProviderData) *LinkedInProvider { func NewLinkedInProvider(p *ProviderData) *LinkedInProvider {
p.ProviderName = "LinkedIn" p.ProviderName = "LinkedIn"
@ -51,11 +54,11 @@ func getLinkedInHeader(accessToken string) http.Header {
} }
// GetEmailAddress returns the Account email address // GetEmailAddress returns the Account email address
func (p *LinkedInProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
if s.AccessToken == "" { if s.AccessToken == "" {
return "", errors.New("missing access token") return "", errors.New("missing access token")
} }
req, err := http.NewRequest("GET", p.ProfileURL.String()+"?format=json", nil) req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String()+"?format=json", nil)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -74,6 +77,6 @@ func (p *LinkedInProvider) GetEmailAddress(s *sessions.SessionState) (string, er
} }
// ValidateSessionState validates the AccessToken // ValidateSessionState validates the AccessToken
func (p *LinkedInProvider) ValidateSessionState(s *sessions.SessionState) bool { func (p *LinkedInProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool {
return validateToken(p, s.AccessToken, getLinkedInHeader(s.AccessToken)) return validateToken(ctx, p, s.AccessToken, getLinkedInHeader(s.AccessToken))
} }

View File

@ -1,6 +1,7 @@
package providers package providers
import ( import (
"context"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@ -99,7 +100,7 @@ func TestLinkedInProviderGetEmailAddress(t *testing.T) {
p := testLinkedInProvider(bURL.Host) p := testLinkedInProvider(bURL.Host)
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(context.Background(), session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "user@linkedin.com", email) assert.Equal(t, "user@linkedin.com", email)
} }
@ -115,7 +116,7 @@ func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) {
// token. Alternatively, we could allow the parsing of the payload as // token. Alternatively, we could allow the parsing of the payload as
// JSON to fail. // JSON to fail.
session := &sessions.SessionState{AccessToken: "unexpected_access_token"} session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(context.Background(), session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)
} }
@ -128,7 +129,7 @@ func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
p := testLinkedInProvider(bURL.Host) p := testLinkedInProvider(bURL.Host)
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(context.Background(), session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)
} }

View File

@ -2,6 +2,7 @@ package providers
import ( import (
"bytes" "bytes"
"context"
"crypto/rsa" "crypto/rsa"
"encoding/json" "encoding/json"
"errors" "errors"
@ -28,6 +29,8 @@ type LoginGovProvider struct {
PubJWKURL *url.URL PubJWKURL *url.URL
} }
var _ Provider = (*LoginGovProvider)(nil)
// For generating a nonce // For generating a nonce
var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
@ -125,10 +128,10 @@ func checkNonce(idToken string, p *LoginGovProvider) (err error) {
return return
} }
func emailFromUserInfo(accessToken string, userInfoEndpoint string) (email string, err error) { func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint string) (email string, err error) {
// query the user info endpoint for user attributes // query the user info endpoint for user attributes
var req *http.Request var req *http.Request
req, err = http.NewRequest("GET", userInfoEndpoint, nil) req, err = http.NewRequestWithContext(ctx, "GET", userInfoEndpoint, nil)
if err != nil { if err != nil {
return return
} }
@ -173,7 +176,7 @@ func emailFromUserInfo(accessToken string, userInfoEndpoint string) (email strin
} }
// Redeem exchanges the OAuth2 authentication token for an ID token // Redeem exchanges the OAuth2 authentication token for an ID token
func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) { func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
if code == "" { if code == "" {
err = errors.New("missing code") err = errors.New("missing code")
return return
@ -199,7 +202,7 @@ func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *sessions.Session
params.Add("grant_type", "authorization_code") params.Add("grant_type", "authorization_code")
var req *http.Request var req *http.Request
req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
if err != nil { if err != nil {
return return
} }
@ -242,7 +245,7 @@ func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *sessions.Session
// Get the email address // Get the email address
var email string var email string
email, err = emailFromUserInfo(jsonResponse.AccessToken, p.ProfileURL.String()) email, err = emailFromUserInfo(ctx, jsonResponse.AccessToken, p.ProfileURL.String())
if err != nil { if err != nil {
return return
} }

View File

@ -1,6 +1,7 @@
package providers package providers
import ( import (
"context"
"crypto" "crypto"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
@ -189,7 +190,7 @@ func TestLoginGovProviderSessionData(t *testing.T) {
p.PubJWKURL, pubjwkserver = newLoginGovServer(pubjwkbody) p.PubJWKURL, pubjwkserver = newLoginGovServer(pubjwkbody)
defer pubjwkserver.Close() defer pubjwkserver.Close()
session, err := p.Redeem("http://redirect/", "code1234") session, err := p.Redeem(context.Background(), "http://redirect/", "code1234")
assert.NoError(t, err) assert.NoError(t, err)
assert.NotEqual(t, session, nil) assert.NotEqual(t, session, nil)
assert.Equal(t, "timothy.spencer@gsa.gov", session.Email) assert.Equal(t, "timothy.spencer@gsa.gov", session.Email)
@ -283,7 +284,7 @@ func TestLoginGovProviderBadNonce(t *testing.T) {
p.PubJWKURL, pubjwkserver = newLoginGovServer(pubjwkbody) p.PubJWKURL, pubjwkserver = newLoginGovServer(pubjwkbody)
defer pubjwkserver.Close() defer pubjwkserver.Close()
_, err = p.Redeem("http://redirect/", "code1234") _, err = p.Redeem(context.Background(), "http://redirect/", "code1234")
// The "badfakenonce" in the idtoken above should cause this to error out // The "badfakenonce" in the idtoken above should cause this to error out
assert.Error(t, err) assert.Error(t, err)

View File

@ -1,6 +1,7 @@
package providers package providers
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
@ -14,6 +15,8 @@ type NextcloudProvider struct {
*ProviderData *ProviderData
} }
var _ Provider = (*NextcloudProvider)(nil)
// NewNextcloudProvider initiates a new NextcloudProvider // NewNextcloudProvider initiates a new NextcloudProvider
func NewNextcloudProvider(p *ProviderData) *NextcloudProvider { func NewNextcloudProvider(p *ProviderData) *NextcloudProvider {
p.ProviderName = "Nextcloud" p.ProviderName = "Nextcloud"
@ -27,8 +30,8 @@ func getNextcloudHeader(accessToken string) http.Header {
} }
// GetEmailAddress returns the Account email address // GetEmailAddress returns the Account email address
func (p *NextcloudProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { func (p *NextcloudProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
req, err := http.NewRequest("GET", req, err := http.NewRequestWithContext(ctx, "GET",
p.ValidateURL.String(), nil) p.ValidateURL.String(), nil)
if err != nil { if err != nil {
logger.Printf("failed building request %s", err) logger.Printf("failed building request %s", err)

View File

@ -1,6 +1,7 @@
package providers package providers
import ( import (
"context"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@ -97,7 +98,7 @@ func TestNextcloudProviderGetEmailAddress(t *testing.T) {
p.ValidateURL.RawQuery = formatJSON p.ValidateURL.RawQuery = formatJSON
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(context.Background(), session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email) assert.Equal(t, "michael.bland@gsa.gov", email)
} }
@ -117,7 +118,7 @@ func TestNextcloudProviderGetEmailAddressFailedRequest(t *testing.T) {
// token. Alternatively, we could allow the parsing of the payload as // token. Alternatively, we could allow the parsing of the payload as
// JSON to fail. // JSON to fail.
session := &sessions.SessionState{AccessToken: "unexpected_access_token"} session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(context.Background(), session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)
} }
@ -132,7 +133,7 @@ func TestNextcloudProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T)
p.ValidateURL.RawQuery = formatJSON p.ValidateURL.RawQuery = formatJSON
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(context.Background(), session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)
} }

View File

@ -31,14 +31,15 @@ func NewOIDCProvider(p *ProviderData) *OIDCProvider {
return &OIDCProvider{ProviderData: p} return &OIDCProvider{ProviderData: p}
} }
var _ Provider = (*OIDCProvider)(nil)
// Redeem exchanges the OAuth2 authentication token for an ID token // Redeem exchanges the OAuth2 authentication token for an ID token
func (p *OIDCProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) { func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
clientSecret, err := p.GetClientSecret() clientSecret, err := p.GetClientSecret()
if err != nil { if err != nil {
return return
} }
ctx := context.Background()
c := oauth2.Config{ c := oauth2.Config{
ClientID: p.ClientID, ClientID: p.ClientID,
ClientSecret: clientSecret, ClientSecret: clientSecret,
@ -60,7 +61,7 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *sessions.SessionStat
return nil, fmt.Errorf("token response did not contain an id_token") return nil, fmt.Errorf("token response did not contain an id_token")
} }
s, err = p.createSessionState(token, idToken) s, err = p.createSessionState(ctx, token, idToken)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to update session: %v", err) return nil, fmt.Errorf("unable to update session: %v", err)
} }
@ -70,12 +71,12 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *sessions.SessionStat
// RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new Access Token (and optional ID token) if required // RefreshToken to fetch a new Access Token (and optional ID token) if required
func (p *OIDCProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
return false, nil return false, nil
} }
err := p.redeemRefreshToken(s) err := p.redeemRefreshToken(ctx, s)
if err != nil { if err != nil {
return false, fmt.Errorf("unable to redeem refresh token: %v", err) return false, fmt.Errorf("unable to redeem refresh token: %v", err)
} }
@ -84,7 +85,7 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, e
return true, nil return true, nil
} }
func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error) { func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) {
clientSecret, err := p.GetClientSecret() clientSecret, err := p.GetClientSecret()
if err != nil { if err != nil {
return return
@ -97,7 +98,6 @@ func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error)
TokenURL: p.RedeemURL.String(), TokenURL: p.RedeemURL.String(),
}, },
} }
ctx := context.Background()
t := &oauth2.Token{ t := &oauth2.Token{
RefreshToken: s.RefreshToken, RefreshToken: s.RefreshToken,
Expiry: time.Now().Add(-time.Hour), Expiry: time.Now().Add(-time.Hour),
@ -113,7 +113,7 @@ func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error)
return fmt.Errorf("unable to extract id_token from response: %v", err) return fmt.Errorf("unable to extract id_token from response: %v", err)
} }
newSession, err := p.createSessionState(token, idToken) newSession, err := p.createSessionState(ctx, token, idToken)
if err != nil { if err != nil {
return fmt.Errorf("unable create new session state from response: %v", err) return fmt.Errorf("unable create new session state from response: %v", err)
} }
@ -149,7 +149,7 @@ func (p *OIDCProvider) findVerifiedIDToken(ctx context.Context, token *oauth2.To
return nil, nil return nil, nil
} }
func (p *OIDCProvider) createSessionState(token *oauth2.Token, idToken *oidc.IDToken) (*sessions.SessionState, error) { func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token, idToken *oidc.IDToken) (*sessions.SessionState, error) {
var newSession *sessions.SessionState var newSession *sessions.SessionState
@ -157,7 +157,7 @@ func (p *OIDCProvider) createSessionState(token *oauth2.Token, idToken *oidc.IDT
newSession = &sessions.SessionState{} newSession = &sessions.SessionState{}
} else { } else {
var err error var err error
newSession, err = p.createSessionStateInternal(token.Extra("id_token").(string), idToken, token) newSession, err = p.createSessionStateInternal(ctx, token.Extra("id_token").(string), idToken, token)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -170,8 +170,8 @@ func (p *OIDCProvider) createSessionState(token *oauth2.Token, idToken *oidc.IDT
return newSession, nil return newSession, nil
} }
func (p *OIDCProvider) CreateSessionStateFromBearerToken(rawIDToken string, idToken *oidc.IDToken) (*sessions.SessionState, error) { func (p *OIDCProvider) CreateSessionStateFromBearerToken(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessions.SessionState, error) {
newSession, err := p.createSessionStateInternal(rawIDToken, idToken, nil) newSession, err := p.createSessionStateInternal(ctx, rawIDToken, idToken, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -184,7 +184,7 @@ func (p *OIDCProvider) CreateSessionStateFromBearerToken(rawIDToken string, idTo
return newSession, nil return newSession, nil
} }
func (p *OIDCProvider) createSessionStateInternal(rawIDToken string, idToken *oidc.IDToken, token *oauth2.Token) (*sessions.SessionState, error) { func (p *OIDCProvider) createSessionStateInternal(ctx context.Context, rawIDToken string, idToken *oidc.IDToken, token *oauth2.Token) (*sessions.SessionState, error) {
newSession := &sessions.SessionState{} newSession := &sessions.SessionState{}
@ -196,7 +196,7 @@ func (p *OIDCProvider) createSessionStateInternal(rawIDToken string, idToken *oi
accessToken = token.AccessToken accessToken = token.AccessToken
} }
claims, err := p.findClaimsFromIDToken(idToken, accessToken, p.ProfileURL.String()) claims, err := p.findClaimsFromIDToken(ctx, idToken, accessToken, p.ProfileURL.String())
if err != nil { if err != nil {
return nil, fmt.Errorf("couldn't extract claims from id_token (%e)", err) return nil, fmt.Errorf("couldn't extract claims from id_token (%e)", err)
} }
@ -217,8 +217,7 @@ func (p *OIDCProvider) createSessionStateInternal(rawIDToken string, idToken *oi
} }
// ValidateSessionState checks that the session's IDToken is still valid // ValidateSessionState checks that the session's IDToken is still valid
func (p *OIDCProvider) ValidateSessionState(s *sessions.SessionState) bool { func (p *OIDCProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool {
ctx := context.Background()
_, err := p.Verifier.Verify(ctx, s.IDToken) _, err := p.Verifier.Verify(ctx, s.IDToken)
return err == nil return err == nil
} }
@ -230,7 +229,7 @@ func getOIDCHeader(accessToken string) http.Header {
return header return header
} }
func (p *OIDCProvider) findClaimsFromIDToken(idToken *oidc.IDToken, accessToken string, profileURL string) (*OIDCClaims, error) { func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc.IDToken, accessToken string, profileURL string) (*OIDCClaims, error) {
claims := &OIDCClaims{} claims := &OIDCClaims{}
// Extract default claims. // Extract default claims.
@ -257,7 +256,7 @@ func (p *OIDCProvider) findClaimsFromIDToken(idToken *oidc.IDToken, accessToken
// contents at the profileURL contains the email. // contents at the profileURL contains the email.
// Make a query to the userinfo endpoint, and attempt to locate the email from there. // Make a query to the userinfo endpoint, and attempt to locate the email from there.
req, err := http.NewRequest("GET", profileURL, nil) req, err := http.NewRequestWithContext(ctx, "GET", profileURL, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -159,7 +159,7 @@ func TestOIDCProviderRedeem(t *testing.T) {
server, provider := newTestSetup(body) server, provider := newTestSetup(body)
defer server.Close() defer server.Close()
session, err := provider.Redeem(provider.RedeemURL.String(), "code1234") session, err := provider.Redeem(context.Background(), provider.RedeemURL.String(), "code1234")
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, defaultIDToken.Email, session.Email) assert.Equal(t, defaultIDToken.Email, session.Email)
assert.Equal(t, accessToken, session.AccessToken) assert.Equal(t, accessToken, session.AccessToken)
@ -183,7 +183,7 @@ func TestOIDCProviderRedeem_custom_userid(t *testing.T) {
provider.UserIDClaim = "phone_number" provider.UserIDClaim = "phone_number"
defer server.Close() defer server.Close()
session, err := provider.Redeem(provider.RedeemURL.String(), "code1234") session, err := provider.Redeem(context.Background(), provider.RedeemURL.String(), "code1234")
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, defaultIDToken.Phone, session.Email) assert.Equal(t, defaultIDToken.Phone, session.Email)
} }
@ -211,7 +211,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) {
User: "11223344", User: "11223344",
} }
refreshed, err := provider.RefreshSessionIfNeeded(existingSession) refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, refreshed, true) assert.Equal(t, refreshed, true)
assert.Equal(t, "janedoe@example.com", existingSession.Email) assert.Equal(t, "janedoe@example.com", existingSession.Email)
@ -244,7 +244,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) {
Email: "changeit", Email: "changeit",
User: "changeit", User: "changeit",
} }
refreshed, err := provider.RefreshSessionIfNeeded(existingSession) refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, refreshed, true) assert.Equal(t, refreshed, true)
assert.Equal(t, defaultIDToken.Email, existingSession.Email) assert.Equal(t, defaultIDToken.Email, existingSession.Email)

View File

@ -2,6 +2,7 @@ package providers
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -16,8 +17,10 @@ import (
"github.com/oauth2-proxy/oauth2-proxy/pkg/encryption" "github.com/oauth2-proxy/oauth2-proxy/pkg/encryption"
) )
var _ Provider = (*ProviderData)(nil)
// Redeem provides a default implementation of the OAuth2 token redemption process // Redeem provides a default implementation of the OAuth2 token redemption process
func (p *ProviderData) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) { func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
if code == "" { if code == "" {
err = errors.New("missing code") err = errors.New("missing code")
return return
@ -38,7 +41,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *sessions.SessionStat
} }
var req *http.Request var req *http.Request
req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
if err != nil { if err != nil {
return return
} }
@ -116,17 +119,17 @@ func (p *ProviderData) SessionFromCookie(v string, c *encryption.Cipher) (s *ses
} }
// GetEmailAddress returns the Account email address // GetEmailAddress returns the Account email address
func (p *ProviderData) GetEmailAddress(s *sessions.SessionState) (string, error) { func (p *ProviderData) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
return "", errors.New("not implemented") return "", errors.New("not implemented")
} }
// GetUserName returns the Account username // GetUserName returns the Account username
func (p *ProviderData) GetUserName(s *sessions.SessionState) (string, error) { func (p *ProviderData) GetUserName(ctx context.Context, s *sessions.SessionState) (string, error) {
return "", errors.New("not implemented") return "", errors.New("not implemented")
} }
// GetPreferredUsername returns the Account preferred username // GetPreferredUsername returns the Account preferred username
func (p *ProviderData) GetPreferredUsername(s *sessions.SessionState) (string, error) { func (p *ProviderData) GetPreferredUsername(ctx context.Context, s *sessions.SessionState) (string, error) {
return "", errors.New("not implemented") return "", errors.New("not implemented")
} }
@ -137,17 +140,17 @@ func (p *ProviderData) ValidateGroup(email string) bool {
} }
// ValidateSessionState validates the AccessToken // ValidateSessionState validates the AccessToken
func (p *ProviderData) ValidateSessionState(s *sessions.SessionState) bool { func (p *ProviderData) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool {
return validateToken(p, s.AccessToken, nil) return validateToken(ctx, p, s.AccessToken, nil)
} }
// RefreshSessionIfNeeded should refresh the user's session if required and // RefreshSessionIfNeeded should refresh the user's session if required and
// do nothing if a refresh is not required // do nothing if a refresh is not required
func (p *ProviderData) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { func (p *ProviderData) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
return false, nil return false, nil
} }
func (p *ProviderData) CreateSessionStateFromBearerToken(rawIDToken string, idToken *oidc.IDToken) (*sessions.SessionState, error) { func (p *ProviderData) CreateSessionStateFromBearerToken(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessions.SessionState, error) {
var claims struct { var claims struct {
Subject string `json:"sub"` Subject string `json:"sub"`
Email string `json:"email"` Email string `json:"email"`

View File

@ -1,6 +1,7 @@
package providers package providers
import ( import (
"context"
"testing" "testing"
"time" "time"
@ -10,7 +11,7 @@ import (
func TestRefresh(t *testing.T) { func TestRefresh(t *testing.T) {
p := &ProviderData{} p := &ProviderData{}
refreshed, err := p.RefreshSessionIfNeeded(&sessions.SessionState{ refreshed, err := p.RefreshSessionIfNeeded(context.Background(), &sessions.SessionState{
ExpiresOn: time.Now().Add(time.Duration(-11) * time.Minute), ExpiresOn: time.Now().Add(time.Duration(-11) * time.Minute),
}) })
assert.Equal(t, false, refreshed) assert.Equal(t, false, refreshed)

View File

@ -1,6 +1,8 @@
package providers package providers
import ( import (
"context"
"github.com/coreos/go-oidc" "github.com/coreos/go-oidc"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/pkg/encryption" "github.com/oauth2-proxy/oauth2-proxy/pkg/encryption"
@ -9,17 +11,17 @@ import (
// Provider represents an upstream identity provider implementation // Provider represents an upstream identity provider implementation
type Provider interface { type Provider interface {
Data() *ProviderData Data() *ProviderData
GetEmailAddress(*sessions.SessionState) (string, error) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error)
GetUserName(*sessions.SessionState) (string, error) GetUserName(ctx context.Context, s *sessions.SessionState) (string, error)
GetPreferredUsername(*sessions.SessionState) (string, error) GetPreferredUsername(ctx context.Context, s *sessions.SessionState) (string, error)
Redeem(string, string) (*sessions.SessionState, error) Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error)
ValidateGroup(string) bool ValidateGroup(string) bool
ValidateSessionState(*sessions.SessionState) bool ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool
GetLoginURL(redirectURI, finalRedirect string) string GetLoginURL(redirectURI, finalRedirect string) string
RefreshSessionIfNeeded(*sessions.SessionState) (bool, error) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error)
SessionFromCookie(string, *encryption.Cipher) (*sessions.SessionState, error) SessionFromCookie(string, *encryption.Cipher) (*sessions.SessionState, error)
CookieForSession(*sessions.SessionState, *encryption.Cipher) (string, error) CookieForSession(*sessions.SessionState, *encryption.Cipher) (string, error)
CreateSessionStateFromBearerToken(rawIDToken string, idToken *oidc.IDToken) (*sessions.SessionState, error) CreateSessionStateFromBearerToken(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessions.SessionState, error)
} }
// New provides a new Provider based on the configured provider string // New provides a new Provider based on the configured provider string