1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-06-15 00:15:00 +02:00

Support context in providers (#519)

Co-authored-by: Henry Jenkins <henry@henryjenkins.name>
This commit is contained in:
Mitsuo Heijo
2020-05-06 00:53:33 +09:00
committed by Henry Jenkins
parent 53d8e99f05
commit e642daef4e
33 changed files with 223 additions and 173 deletions

View File

@ -31,14 +31,15 @@ func NewOIDCProvider(p *ProviderData) *OIDCProvider {
return &OIDCProvider{ProviderData: p}
}
var _ Provider = (*OIDCProvider)(nil)
// Redeem exchanges the OAuth2 authentication token for an ID token
func (p *OIDCProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
clientSecret, err := p.GetClientSecret()
if err != nil {
return
}
ctx := context.Background()
c := oauth2.Config{
ClientID: p.ClientID,
ClientSecret: clientSecret,
@ -60,7 +61,7 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *sessions.SessionStat
return nil, fmt.Errorf("token response did not contain an id_token")
}
s, err = p.createSessionState(token, idToken)
s, err = p.createSessionState(ctx, token, idToken)
if err != nil {
return nil, fmt.Errorf("unable to update session: %v", err)
}
@ -70,12 +71,12 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *sessions.SessionStat
// RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new Access Token (and optional ID token) if required
func (p *OIDCProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
return false, nil
}
err := p.redeemRefreshToken(s)
err := p.redeemRefreshToken(ctx, s)
if err != nil {
return false, fmt.Errorf("unable to redeem refresh token: %v", err)
}
@ -84,7 +85,7 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, e
return true, nil
}
func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error) {
func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) {
clientSecret, err := p.GetClientSecret()
if err != nil {
return
@ -97,7 +98,6 @@ func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error)
TokenURL: p.RedeemURL.String(),
},
}
ctx := context.Background()
t := &oauth2.Token{
RefreshToken: s.RefreshToken,
Expiry: time.Now().Add(-time.Hour),
@ -113,7 +113,7 @@ func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error)
return fmt.Errorf("unable to extract id_token from response: %v", err)
}
newSession, err := p.createSessionState(token, idToken)
newSession, err := p.createSessionState(ctx, token, idToken)
if err != nil {
return fmt.Errorf("unable create new session state from response: %v", err)
}
@ -149,7 +149,7 @@ func (p *OIDCProvider) findVerifiedIDToken(ctx context.Context, token *oauth2.To
return nil, nil
}
func (p *OIDCProvider) createSessionState(token *oauth2.Token, idToken *oidc.IDToken) (*sessions.SessionState, error) {
func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token, idToken *oidc.IDToken) (*sessions.SessionState, error) {
var newSession *sessions.SessionState
@ -157,7 +157,7 @@ func (p *OIDCProvider) createSessionState(token *oauth2.Token, idToken *oidc.IDT
newSession = &sessions.SessionState{}
} else {
var err error
newSession, err = p.createSessionStateInternal(token.Extra("id_token").(string), idToken, token)
newSession, err = p.createSessionStateInternal(ctx, token.Extra("id_token").(string), idToken, token)
if err != nil {
return nil, err
}
@ -170,8 +170,8 @@ func (p *OIDCProvider) createSessionState(token *oauth2.Token, idToken *oidc.IDT
return newSession, nil
}
func (p *OIDCProvider) CreateSessionStateFromBearerToken(rawIDToken string, idToken *oidc.IDToken) (*sessions.SessionState, error) {
newSession, err := p.createSessionStateInternal(rawIDToken, idToken, nil)
func (p *OIDCProvider) CreateSessionStateFromBearerToken(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessions.SessionState, error) {
newSession, err := p.createSessionStateInternal(ctx, rawIDToken, idToken, nil)
if err != nil {
return nil, err
}
@ -184,7 +184,7 @@ func (p *OIDCProvider) CreateSessionStateFromBearerToken(rawIDToken string, idTo
return newSession, nil
}
func (p *OIDCProvider) createSessionStateInternal(rawIDToken string, idToken *oidc.IDToken, token *oauth2.Token) (*sessions.SessionState, error) {
func (p *OIDCProvider) createSessionStateInternal(ctx context.Context, rawIDToken string, idToken *oidc.IDToken, token *oauth2.Token) (*sessions.SessionState, error) {
newSession := &sessions.SessionState{}
@ -196,7 +196,7 @@ func (p *OIDCProvider) createSessionStateInternal(rawIDToken string, idToken *oi
accessToken = token.AccessToken
}
claims, err := p.findClaimsFromIDToken(idToken, accessToken, p.ProfileURL.String())
claims, err := p.findClaimsFromIDToken(ctx, idToken, accessToken, p.ProfileURL.String())
if err != nil {
return nil, fmt.Errorf("couldn't extract claims from id_token (%e)", err)
}
@ -217,8 +217,7 @@ func (p *OIDCProvider) createSessionStateInternal(rawIDToken string, idToken *oi
}
// ValidateSessionState checks that the session's IDToken is still valid
func (p *OIDCProvider) ValidateSessionState(s *sessions.SessionState) bool {
ctx := context.Background()
func (p *OIDCProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool {
_, err := p.Verifier.Verify(ctx, s.IDToken)
return err == nil
}
@ -230,7 +229,7 @@ func getOIDCHeader(accessToken string) http.Header {
return header
}
func (p *OIDCProvider) findClaimsFromIDToken(idToken *oidc.IDToken, accessToken string, profileURL string) (*OIDCClaims, error) {
func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc.IDToken, accessToken string, profileURL string) (*OIDCClaims, error) {
claims := &OIDCClaims{}
// Extract default claims.
@ -257,7 +256,7 @@ func (p *OIDCProvider) findClaimsFromIDToken(idToken *oidc.IDToken, accessToken
// contents at the profileURL contains the email.
// Make a query to the userinfo endpoint, and attempt to locate the email from there.
req, err := http.NewRequest("GET", profileURL, nil)
req, err := http.NewRequestWithContext(ctx, "GET", profileURL, nil)
if err != nil {
return nil, err
}