You've already forked oauth2-proxy
mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-06-15 00:15:00 +02:00
Support context in providers (#519)
Co-authored-by: Henry Jenkins <henry@henryjenkins.name>
This commit is contained in:
committed by
Henry Jenkins
parent
53d8e99f05
commit
e642daef4e
@ -31,14 +31,15 @@ func NewOIDCProvider(p *ProviderData) *OIDCProvider {
|
||||
return &OIDCProvider{ProviderData: p}
|
||||
}
|
||||
|
||||
var _ Provider = (*OIDCProvider)(nil)
|
||||
|
||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||
func (p *OIDCProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
|
||||
func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
|
||||
clientSecret, err := p.GetClientSecret()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
c := oauth2.Config{
|
||||
ClientID: p.ClientID,
|
||||
ClientSecret: clientSecret,
|
||||
@ -60,7 +61,7 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *sessions.SessionStat
|
||||
return nil, fmt.Errorf("token response did not contain an id_token")
|
||||
}
|
||||
|
||||
s, err = p.createSessionState(token, idToken)
|
||||
s, err = p.createSessionState(ctx, token, idToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to update session: %v", err)
|
||||
}
|
||||
@ -70,12 +71,12 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *sessions.SessionStat
|
||||
|
||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||
// RefreshToken to fetch a new Access Token (and optional ID token) if required
|
||||
func (p *OIDCProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
|
||||
func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
err := p.redeemRefreshToken(s)
|
||||
err := p.redeemRefreshToken(ctx, s)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("unable to redeem refresh token: %v", err)
|
||||
}
|
||||
@ -84,7 +85,7 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, e
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error) {
|
||||
func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) {
|
||||
clientSecret, err := p.GetClientSecret()
|
||||
if err != nil {
|
||||
return
|
||||
@ -97,7 +98,6 @@ func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error)
|
||||
TokenURL: p.RedeemURL.String(),
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
t := &oauth2.Token{
|
||||
RefreshToken: s.RefreshToken,
|
||||
Expiry: time.Now().Add(-time.Hour),
|
||||
@ -113,7 +113,7 @@ func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error)
|
||||
return fmt.Errorf("unable to extract id_token from response: %v", err)
|
||||
}
|
||||
|
||||
newSession, err := p.createSessionState(token, idToken)
|
||||
newSession, err := p.createSessionState(ctx, token, idToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable create new session state from response: %v", err)
|
||||
}
|
||||
@ -149,7 +149,7 @@ func (p *OIDCProvider) findVerifiedIDToken(ctx context.Context, token *oauth2.To
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) createSessionState(token *oauth2.Token, idToken *oidc.IDToken) (*sessions.SessionState, error) {
|
||||
func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token, idToken *oidc.IDToken) (*sessions.SessionState, error) {
|
||||
|
||||
var newSession *sessions.SessionState
|
||||
|
||||
@ -157,7 +157,7 @@ func (p *OIDCProvider) createSessionState(token *oauth2.Token, idToken *oidc.IDT
|
||||
newSession = &sessions.SessionState{}
|
||||
} else {
|
||||
var err error
|
||||
newSession, err = p.createSessionStateInternal(token.Extra("id_token").(string), idToken, token)
|
||||
newSession, err = p.createSessionStateInternal(ctx, token.Extra("id_token").(string), idToken, token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -170,8 +170,8 @@ func (p *OIDCProvider) createSessionState(token *oauth2.Token, idToken *oidc.IDT
|
||||
return newSession, nil
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) CreateSessionStateFromBearerToken(rawIDToken string, idToken *oidc.IDToken) (*sessions.SessionState, error) {
|
||||
newSession, err := p.createSessionStateInternal(rawIDToken, idToken, nil)
|
||||
func (p *OIDCProvider) CreateSessionStateFromBearerToken(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessions.SessionState, error) {
|
||||
newSession, err := p.createSessionStateInternal(ctx, rawIDToken, idToken, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -184,7 +184,7 @@ func (p *OIDCProvider) CreateSessionStateFromBearerToken(rawIDToken string, idTo
|
||||
return newSession, nil
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) createSessionStateInternal(rawIDToken string, idToken *oidc.IDToken, token *oauth2.Token) (*sessions.SessionState, error) {
|
||||
func (p *OIDCProvider) createSessionStateInternal(ctx context.Context, rawIDToken string, idToken *oidc.IDToken, token *oauth2.Token) (*sessions.SessionState, error) {
|
||||
|
||||
newSession := &sessions.SessionState{}
|
||||
|
||||
@ -196,7 +196,7 @@ func (p *OIDCProvider) createSessionStateInternal(rawIDToken string, idToken *oi
|
||||
accessToken = token.AccessToken
|
||||
}
|
||||
|
||||
claims, err := p.findClaimsFromIDToken(idToken, accessToken, p.ProfileURL.String())
|
||||
claims, err := p.findClaimsFromIDToken(ctx, idToken, accessToken, p.ProfileURL.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("couldn't extract claims from id_token (%e)", err)
|
||||
}
|
||||
@ -217,8 +217,7 @@ func (p *OIDCProvider) createSessionStateInternal(rawIDToken string, idToken *oi
|
||||
}
|
||||
|
||||
// ValidateSessionState checks that the session's IDToken is still valid
|
||||
func (p *OIDCProvider) ValidateSessionState(s *sessions.SessionState) bool {
|
||||
ctx := context.Background()
|
||||
func (p *OIDCProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool {
|
||||
_, err := p.Verifier.Verify(ctx, s.IDToken)
|
||||
return err == nil
|
||||
}
|
||||
@ -230,7 +229,7 @@ func getOIDCHeader(accessToken string) http.Header {
|
||||
return header
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) findClaimsFromIDToken(idToken *oidc.IDToken, accessToken string, profileURL string) (*OIDCClaims, error) {
|
||||
func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc.IDToken, accessToken string, profileURL string) (*OIDCClaims, error) {
|
||||
|
||||
claims := &OIDCClaims{}
|
||||
// Extract default claims.
|
||||
@ -257,7 +256,7 @@ func (p *OIDCProvider) findClaimsFromIDToken(idToken *oidc.IDToken, accessToken
|
||||
// contents at the profileURL contains the email.
|
||||
// Make a query to the userinfo endpoint, and attempt to locate the email from there.
|
||||
|
||||
req, err := http.NewRequest("GET", profileURL, nil)
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", profileURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
Reference in New Issue
Block a user