You've already forked oauth2-proxy
mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2026-05-22 10:15:21 +02:00
RefreshSessions immediately when called
This commit is contained in:
@@ -345,7 +345,7 @@ func TestAzureProviderNotRefreshWhenNotExpired(t *testing.T) {
|
||||
|
||||
expires := time.Now().Add(time.Duration(1) * time.Hour)
|
||||
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
|
||||
refreshNeeded, err := p.RefreshSessionIfNeeded(context.Background(), session)
|
||||
refreshNeeded, err := p.RefreshSession(context.Background(), session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.False(t, refreshNeeded)
|
||||
}
|
||||
@@ -373,9 +373,10 @@ func TestAzureProviderRefreshWhenExpired(t *testing.T) {
|
||||
|
||||
expires := time.Now().Add(time.Duration(-1) * time.Hour)
|
||||
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
|
||||
refreshNeeded, err := p.RefreshSessionIfNeeded(context.Background(), session)
|
||||
|
||||
refreshed, err := p.RefreshSession(context.Background(), session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.True(t, refreshNeeded)
|
||||
assert.True(t, refreshed)
|
||||
assert.NotEqual(t, session, nil)
|
||||
assert.Equal(t, "new_some_access_token", session.AccessToken)
|
||||
assert.Equal(t, "new_some_refresh_token", session.RefreshToken)
|
||||
|
||||
@@ -88,7 +88,7 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess
|
||||
return r.Email, nil
|
||||
}
|
||||
|
||||
// ValidateSessionState validates the AccessToken
|
||||
// ValidateSession validates the AccessToken
|
||||
func (p *FacebookProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
|
||||
return validateToken(ctx, p, s.AccessToken, makeOIDCHeader(s.AccessToken))
|
||||
}
|
||||
|
||||
+8
-13
@@ -121,10 +121,9 @@ func (p *GitLabProvider) SetProjectScope() {
|
||||
}
|
||||
}
|
||||
|
||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||
// RefreshToken to fetch a new ID token if required
|
||||
func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
|
||||
// RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
|
||||
func (p *GitLabProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||
if s == nil || s.RefreshToken == "" {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@@ -139,10 +138,10 @@ func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) {
|
||||
func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error {
|
||||
clientSecret, err := p.GetClientSecret()
|
||||
if err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
c := oauth2.Config{
|
||||
@@ -164,13 +163,9 @@ func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.Ses
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to update session: %v", err)
|
||||
}
|
||||
s.AccessToken = newSession.AccessToken
|
||||
s.IDToken = newSession.IDToken
|
||||
s.RefreshToken = newSession.RefreshToken
|
||||
s.CreatedAt = newSession.CreatedAt
|
||||
s.ExpiresOn = newSession.ExpiresOn
|
||||
s.Email = newSession.Email
|
||||
return
|
||||
*s = *newSession
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type gitlabUserInfo struct {
|
||||
|
||||
+3
-4
@@ -266,10 +266,9 @@ func userInGroup(service *admin.Service, group string, email string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||
// RefreshToken to fetch a new ID token if required
|
||||
func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
|
||||
// RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
|
||||
func (p *GoogleProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||
if s == nil || s.RefreshToken == "" {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -93,7 +93,7 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess
|
||||
return email, nil
|
||||
}
|
||||
|
||||
// ValidateSessionState validates the AccessToken
|
||||
// ValidateSession validates the AccessToken
|
||||
func (p *LinkedInProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
|
||||
return validateToken(ctx, p, s.AccessToken, makeLinkedInHeader(s.AccessToken))
|
||||
}
|
||||
|
||||
+3
-4
@@ -143,10 +143,9 @@ func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionS
|
||||
return true
|
||||
}
|
||||
|
||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||
// RefreshToken to fetch a new Access Token (and optional ID token) if required
|
||||
func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
|
||||
// RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
|
||||
func (p *OIDCProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||
if s == nil || s.RefreshToken == "" {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -487,7 +487,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) {
|
||||
User: "11223344",
|
||||
}
|
||||
|
||||
refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession)
|
||||
refreshed, err := provider.RefreshSession(context.Background(), existingSession)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, refreshed, true)
|
||||
assert.Equal(t, "janedoe@example.com", existingSession.Email)
|
||||
@@ -520,7 +520,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) {
|
||||
Email: "changeit",
|
||||
User: "changeit",
|
||||
}
|
||||
refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession)
|
||||
refreshed, err := provider.RefreshSession(context.Background(), existingSession)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, refreshed, true)
|
||||
assert.Equal(t, defaultIDToken.Email, existingSession.Email)
|
||||
|
||||
@@ -126,10 +126,15 @@ func (p *ProviderData) ValidateSession(ctx context.Context, s *sessions.SessionS
|
||||
return validateToken(ctx, p, s.AccessToken, nil)
|
||||
}
|
||||
|
||||
// RefreshSessionIfNeeded should refresh the user's session if required and
|
||||
// do nothing if a refresh is not required
|
||||
func (p *ProviderData) RefreshSessionIfNeeded(_ context.Context, _ *sessions.SessionState) (bool, error) {
|
||||
return false, nil
|
||||
// RefreshSession refreshes the user's session
|
||||
func (p *ProviderData) RefreshSession(_ context.Context, s *sessions.SessionState) (bool, error) {
|
||||
if s == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Pretend `RefreshSession` occured so `ValidateSession` isn't called
|
||||
// on every request after any potential set refresh period elapses.
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// CreateSessionFromToken converts Bearer IDTokens into sessions
|
||||
|
||||
@@ -15,7 +15,7 @@ func TestRefresh(t *testing.T) {
|
||||
p := &ProviderData{}
|
||||
|
||||
expires := time.Now().Add(time.Duration(-11) * time.Minute)
|
||||
refreshed, err := p.RefreshSessionIfNeeded(context.Background(), &sessions.SessionState{
|
||||
refreshed, err := p.RefreshSession(context.Background(), &sessions.SessionState{
|
||||
ExpiresOn: &expires,
|
||||
})
|
||||
assert.Equal(t, false, refreshed)
|
||||
|
||||
@@ -9,14 +9,14 @@ import (
|
||||
// Provider represents an upstream identity provider implementation
|
||||
type Provider interface {
|
||||
Data() *ProviderData
|
||||
GetLoginURL(redirectURI, finalRedirect string, nonce string) string
|
||||
Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error)
|
||||
// Deprecated: Migrate to EnrichSession
|
||||
GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error)
|
||||
GetLoginURL(redirectURI, state, nonce string) string
|
||||
Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error)
|
||||
EnrichSession(ctx context.Context, s *sessions.SessionState) error
|
||||
Authorize(ctx context.Context, s *sessions.SessionState) (bool, error)
|
||||
ValidateSession(ctx context.Context, s *sessions.SessionState) bool
|
||||
RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error)
|
||||
RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error)
|
||||
CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user