1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2026-05-22 10:15:21 +02:00

Standarize provider refresh implemention & logging

This commit is contained in:
Nick Meves
2021-03-06 15:48:31 -08:00
parent 7fa6d2d024
commit 593125152d
10 changed files with 123 additions and 70 deletions
+7 -5
View File
@@ -242,21 +242,23 @@ func (p *AzureProvider) RefreshSession(ctx context.Context, s *sessions.SessionS
return false, nil
}
origExpiration := s.ExpiresOn
err := p.redeemRefreshToken(ctx, s)
if err != nil {
return false, fmt.Errorf("unable to redeem refresh token: %v", err)
}
logger.Printf("refreshed id token %s (expired on %s)\n", s, origExpiration)
return true, nil
}
func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error {
clientSecret, err := p.GetClientSecret()
if err != nil {
return err
}
params := url.Values{}
params.Add("client_id", p.ClientID)
params.Add("client_secret", p.ClientSecret)
params.Add("client_secret", clientSecret)
params.Add("refresh_token", s.RefreshToken)
params.Add("grant_type", "refresh_token")
@@ -267,7 +269,7 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess
IDToken string `json:"id_token"`
}
err := requests.New(p.RedeemURL.String()).
err = requests.New(p.RedeemURL.String()).
WithContext(ctx).
WithMethod("POST").
WithBody(bytes.NewBufferString(params.Encode())).
+1 -11
View File
@@ -340,17 +340,7 @@ func TestAzureProviderProtectedResourceConfigured(t *testing.T) {
assert.Contains(t, result, "resource="+url.QueryEscape("http://my.resource.test"))
}
func TestAzureProviderNotRefreshWhenNotExpired(t *testing.T) {
p := testAzureProvider("")
expires := time.Now().Add(time.Duration(1) * time.Hour)
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
refreshNeeded, err := p.RefreshSession(context.Background(), session)
assert.Equal(t, nil, err)
assert.False(t, refreshNeeded)
}
func TestAzureProviderRefreshWhenExpired(t *testing.T) {
func TestAzureProviderRefresh(t *testing.T) {
email := "foo@example.com"
idToken := idTokenClaims{Email: email}
idTokenString, err := newSignedTestIDToken(idToken)
+12 -15
View File
@@ -271,7 +271,7 @@ func (p *GoogleProvider) RefreshSession(ctx context.Context, s *sessions.Session
return false, nil
}
newToken, newIDToken, ttl, err := p.redeemRefreshToken(ctx, s.RefreshToken)
err := p.redeemRefreshToken(ctx, s)
if err != nil {
return false, err
}
@@ -284,26 +284,20 @@ func (p *GoogleProvider) RefreshSession(ctx context.Context, s *sessions.Session
return false, fmt.Errorf("%s is no longer in the group(s)", s.Email)
}
s.AccessToken = newToken
s.IDToken = newIDToken
s.CreatedAtNow()
s.ExpiresIn(ttl)
return true, nil
}
func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken string) (token string, idToken string, expires time.Duration, err error) {
func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error {
// https://developers.google.com/identity/protocols/OAuth2WebServer#refresh
clientSecret, err := p.GetClientSecret()
if err != nil {
return
return err
}
params := url.Values{}
params.Add("client_id", p.ClientID)
params.Add("client_secret", clientSecret)
params.Add("refresh_token", refreshToken)
params.Add("refresh_token", s.RefreshToken)
params.Add("grant_type", "refresh_token")
var data struct {
@@ -320,11 +314,14 @@ func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken st
Do().
UnmarshalInto(&data)
if err != nil {
return "", "", 0, err
return err
}
token = data.AccessToken
idToken = data.IDToken
expires = time.Duration(data.ExpiresIn) * time.Second
return
s.AccessToken = data.AccessToken
s.IDToken = data.IDToken
s.CreatedAtNow()
s.ExpiresIn(time.Duration(data.ExpiresIn) * time.Second)
return nil
}
-1
View File
@@ -154,7 +154,6 @@ func (p *OIDCProvider) RefreshSession(ctx context.Context, s *sessions.SessionSt
return false, fmt.Errorf("unable to redeem refresh token: %v", err)
}
logger.Printf("refreshed session: %s", s)
return true, nil
}
+6 -1
View File
@@ -135,8 +135,13 @@ func (p *ProviderData) RefreshSession(_ context.Context, s *sessions.SessionStat
return false, nil
}
// Pretend `RefreshSession` occured so `ValidateSession` isn't called
// HACK:
// Pretend `RefreshSession` occurred so `ValidateSession` isn't called
// on every request after any potential set refresh period elapses.
// See `middleware.refreshSession` for detailed logic & explanation.
//
// Intentionally doesn't use `ErrNotImplemented` since all providers will
// call this and we don't want to force them to implement this dummy logic.
return true, nil
}
+14 -6
View File
@@ -14,12 +14,20 @@ import (
func TestRefresh(t *testing.T) {
p := &ProviderData{}
expires := time.Now().Add(time.Duration(-11) * time.Minute)
refreshed, err := p.RefreshSession(context.Background(), &sessions.SessionState{
ExpiresOn: &expires,
})
assert.Equal(t, false, refreshed)
assert.Equal(t, nil, err)
now := time.Unix(1234567890, 10)
expires := time.Unix(1234567890, 0)
ss := &sessions.SessionState{}
ss.Clock.Set(now)
ss.SetExpiresOn(expires)
refreshed, err := p.RefreshSession(context.Background(), ss)
assert.True(t, refreshed)
assert.NoError(t, err)
refreshed, err = p.RefreshSession(context.Background(), nil)
assert.False(t, refreshed)
assert.NoError(t, err)
}
func TestAcrValuesNotConfigured(t *testing.T) {