1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-06-15 00:15:00 +02:00

Support context in providers (#519)

Co-authored-by: Henry Jenkins <henry@henryjenkins.name>
This commit is contained in:
Mitsuo Heijo
2020-05-06 00:53:33 +09:00
committed by Henry Jenkins
parent 53d8e99f05
commit e642daef4e
33 changed files with 223 additions and 173 deletions

View File

@ -347,29 +347,29 @@ func (p *OAuthProxy) displayCustomLoginForm() bool {
return p.HtpasswdFile != nil && p.DisplayHtpasswdForm
}
func (p *OAuthProxy) redeemCode(host, code string) (s *sessionsapi.SessionState, err error) {
func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (s *sessionsapi.SessionState, err error) {
if code == "" {
return nil, errors.New("missing code")
}
redirectURI := p.GetRedirectURI(host)
s, err = p.provider.Redeem(redirectURI, code)
s, err = p.provider.Redeem(ctx, redirectURI, code)
if err != nil {
return
}
if s.Email == "" {
s.Email, err = p.provider.GetEmailAddress(s)
s.Email, err = p.provider.GetEmailAddress(ctx, s)
}
if s.PreferredUsername == "" {
s.PreferredUsername, err = p.provider.GetPreferredUsername(s)
s.PreferredUsername, err = p.provider.GetPreferredUsername(ctx, s)
if err != nil && err.Error() == "not implemented" {
err = nil
}
}
if s.User == "" {
s.User, err = p.provider.GetUserName(s)
s.User, err = p.provider.GetUserName(ctx, s)
if err != nil && err.Error() == "not implemented" {
err = nil
}
@ -782,7 +782,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
return
}
session, err := p.redeemCode(req.Host, req.Form.Get("code"))
session, err := p.redeemCode(req.Context(), req.Host, req.Form.Get("code"))
if err != nil {
logger.Printf("Error redeeming code during OAuth2 callback: %s ", err.Error())
p.ErrorPage(rw, 500, "Internal Error", "Internal Error")
@ -907,7 +907,7 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R
saveSession = true
}
if ok, err := p.provider.RefreshSessionIfNeeded(session); err != nil {
if ok, err := p.provider.RefreshSessionIfNeeded(req.Context(), session); err != nil {
logger.Printf("%s removing session. error refreshing access token %s %s", remoteAddr, err, session)
clearSession = true
session = nil
@ -926,7 +926,7 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R
}
if saveSession && !revalidated && session != nil && session.AccessToken != "" {
if !p.provider.ValidateSessionState(session) {
if !p.provider.ValidateSessionState(req.Context(), session) {
logger.Printf("Removing session: error validating %s", session)
saveSession = false
session = nil
@ -1126,16 +1126,15 @@ func (p *OAuthProxy) GetJwtSession(req *http.Request) (*sessionsapi.SessionState
return nil, err
}
ctx := context.Background()
for _, verifier := range p.jwtBearerVerifiers {
bearerToken, err := verifier.Verify(ctx, rawBearerToken)
bearerToken, err := verifier.Verify(req.Context(), rawBearerToken)
if err != nil {
logger.Printf("failed to verify bearer token: %v", err)
continue
}
return p.provider.CreateSessionStateFromBearerToken(rawBearerToken, bearerToken)
return p.provider.CreateSessionStateFromBearerToken(req.Context(), rawBearerToken, bearerToken)
}
return nil, fmt.Errorf("unable to verify jwt token %s", req.Header.Get("Authorization"))
}