You've already forked oauth2-proxy
mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2026-05-22 10:15:21 +02:00
Standarize provider refresh implemention & logging
This commit is contained in:
+7
-5
@@ -242,21 +242,23 @@ func (p *AzureProvider) RefreshSession(ctx context.Context, s *sessions.SessionS
|
||||
return false, nil
|
||||
}
|
||||
|
||||
origExpiration := s.ExpiresOn
|
||||
|
||||
err := p.redeemRefreshToken(ctx, s)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("unable to redeem refresh token: %v", err)
|
||||
}
|
||||
|
||||
logger.Printf("refreshed id token %s (expired on %s)\n", s, origExpiration)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error {
|
||||
clientSecret, err := p.GetClientSecret()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Add("client_id", p.ClientID)
|
||||
params.Add("client_secret", p.ClientSecret)
|
||||
params.Add("client_secret", clientSecret)
|
||||
params.Add("refresh_token", s.RefreshToken)
|
||||
params.Add("grant_type", "refresh_token")
|
||||
|
||||
@@ -267,7 +269,7 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess
|
||||
IDToken string `json:"id_token"`
|
||||
}
|
||||
|
||||
err := requests.New(p.RedeemURL.String()).
|
||||
err = requests.New(p.RedeemURL.String()).
|
||||
WithContext(ctx).
|
||||
WithMethod("POST").
|
||||
WithBody(bytes.NewBufferString(params.Encode())).
|
||||
|
||||
+1
-11
@@ -340,17 +340,7 @@ func TestAzureProviderProtectedResourceConfigured(t *testing.T) {
|
||||
assert.Contains(t, result, "resource="+url.QueryEscape("http://my.resource.test"))
|
||||
}
|
||||
|
||||
func TestAzureProviderNotRefreshWhenNotExpired(t *testing.T) {
|
||||
p := testAzureProvider("")
|
||||
|
||||
expires := time.Now().Add(time.Duration(1) * time.Hour)
|
||||
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
|
||||
refreshNeeded, err := p.RefreshSession(context.Background(), session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.False(t, refreshNeeded)
|
||||
}
|
||||
|
||||
func TestAzureProviderRefreshWhenExpired(t *testing.T) {
|
||||
func TestAzureProviderRefresh(t *testing.T) {
|
||||
email := "foo@example.com"
|
||||
idToken := idTokenClaims{Email: email}
|
||||
idTokenString, err := newSignedTestIDToken(idToken)
|
||||
|
||||
+12
-15
@@ -271,7 +271,7 @@ func (p *GoogleProvider) RefreshSession(ctx context.Context, s *sessions.Session
|
||||
return false, nil
|
||||
}
|
||||
|
||||
newToken, newIDToken, ttl, err := p.redeemRefreshToken(ctx, s.RefreshToken)
|
||||
err := p.redeemRefreshToken(ctx, s)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -284,26 +284,20 @@ func (p *GoogleProvider) RefreshSession(ctx context.Context, s *sessions.Session
|
||||
return false, fmt.Errorf("%s is no longer in the group(s)", s.Email)
|
||||
}
|
||||
|
||||
s.AccessToken = newToken
|
||||
s.IDToken = newIDToken
|
||||
|
||||
s.CreatedAtNow()
|
||||
s.ExpiresIn(ttl)
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken string) (token string, idToken string, expires time.Duration, err error) {
|
||||
func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error {
|
||||
// https://developers.google.com/identity/protocols/OAuth2WebServer#refresh
|
||||
clientSecret, err := p.GetClientSecret()
|
||||
if err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Add("client_id", p.ClientID)
|
||||
params.Add("client_secret", clientSecret)
|
||||
params.Add("refresh_token", refreshToken)
|
||||
params.Add("refresh_token", s.RefreshToken)
|
||||
params.Add("grant_type", "refresh_token")
|
||||
|
||||
var data struct {
|
||||
@@ -320,11 +314,14 @@ func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken st
|
||||
Do().
|
||||
UnmarshalInto(&data)
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
return err
|
||||
}
|
||||
|
||||
token = data.AccessToken
|
||||
idToken = data.IDToken
|
||||
expires = time.Duration(data.ExpiresIn) * time.Second
|
||||
return
|
||||
s.AccessToken = data.AccessToken
|
||||
s.IDToken = data.IDToken
|
||||
|
||||
s.CreatedAtNow()
|
||||
s.ExpiresIn(time.Duration(data.ExpiresIn) * time.Second)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -154,7 +154,6 @@ func (p *OIDCProvider) RefreshSession(ctx context.Context, s *sessions.SessionSt
|
||||
return false, fmt.Errorf("unable to redeem refresh token: %v", err)
|
||||
}
|
||||
|
||||
logger.Printf("refreshed session: %s", s)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -135,8 +135,13 @@ func (p *ProviderData) RefreshSession(_ context.Context, s *sessions.SessionStat
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Pretend `RefreshSession` occured so `ValidateSession` isn't called
|
||||
// HACK:
|
||||
// Pretend `RefreshSession` occurred so `ValidateSession` isn't called
|
||||
// on every request after any potential set refresh period elapses.
|
||||
// See `middleware.refreshSession` for detailed logic & explanation.
|
||||
//
|
||||
// Intentionally doesn't use `ErrNotImplemented` since all providers will
|
||||
// call this and we don't want to force them to implement this dummy logic.
|
||||
return true, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -14,12 +14,20 @@ import (
|
||||
func TestRefresh(t *testing.T) {
|
||||
p := &ProviderData{}
|
||||
|
||||
expires := time.Now().Add(time.Duration(-11) * time.Minute)
|
||||
refreshed, err := p.RefreshSession(context.Background(), &sessions.SessionState{
|
||||
ExpiresOn: &expires,
|
||||
})
|
||||
assert.Equal(t, false, refreshed)
|
||||
assert.Equal(t, nil, err)
|
||||
now := time.Unix(1234567890, 10)
|
||||
expires := time.Unix(1234567890, 0)
|
||||
|
||||
ss := &sessions.SessionState{}
|
||||
ss.Clock.Set(now)
|
||||
ss.SetExpiresOn(expires)
|
||||
|
||||
refreshed, err := p.RefreshSession(context.Background(), ss)
|
||||
assert.True(t, refreshed)
|
||||
assert.NoError(t, err)
|
||||
|
||||
refreshed, err = p.RefreshSession(context.Background(), nil)
|
||||
assert.False(t, refreshed)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestAcrValuesNotConfigured(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user