1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-06-15 00:15:00 +02:00

RefreshSessions immediately when called

This commit is contained in:
Nick Meves
2021-03-06 15:33:13 -08:00
parent 5f4ac25b1e
commit 7e80e5596b
13 changed files with 74 additions and 79 deletions

View File

@ -361,10 +361,10 @@ func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionSt
} }
chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{ chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{
SessionStore: sessionStore, SessionStore: sessionStore,
RefreshPeriod: opts.Cookie.Refresh, RefreshPeriod: opts.Cookie.Refresh,
RefreshSessionIfNeeded: opts.GetProvider().RefreshSessionIfNeeded, RefreshSession: opts.GetProvider().RefreshSession,
ValidateSessionState: opts.GetProvider().ValidateSession, ValidateSession: opts.GetProvider().ValidateSession,
})) }))
return chain return chain

View File

@ -24,12 +24,12 @@ type StoredSessionLoaderOptions struct {
RefreshPeriod time.Duration RefreshPeriod time.Duration
// Provider based sesssion refreshing // Provider based sesssion refreshing
RefreshSessionIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error) RefreshSession func(context.Context, *sessionsapi.SessionState) (bool, error)
// Provider based session validation. // Provider based session validation.
// If the sesssion is older than `RefreshPeriod` but the provider doesn't // If the sesssion is older than `RefreshPeriod` but the provider doesn't
// refresh it, we must re-validate using this validation. // refresh it, we must re-validate using this validation.
ValidateSessionState func(context.Context, *sessionsapi.SessionState) bool ValidateSession func(context.Context, *sessionsapi.SessionState) bool
} }
// NewStoredSessionLoader creates a new storedSessionLoader which loads // NewStoredSessionLoader creates a new storedSessionLoader which loads
@ -38,10 +38,10 @@ type StoredSessionLoaderOptions struct {
// If a session was loader by a previous handler, it will not be replaced. // If a session was loader by a previous handler, it will not be replaced.
func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor { func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor {
ss := &storedSessionLoader{ ss := &storedSessionLoader{
store: opts.SessionStore, store: opts.SessionStore,
refreshPeriod: opts.RefreshPeriod, refreshPeriod: opts.RefreshPeriod,
refreshSessionWithProviderIfNeeded: opts.RefreshSessionIfNeeded, sessionRefresher: opts.RefreshSession,
validateSessionState: opts.ValidateSessionState, sessionValidator: opts.ValidateSession,
} }
return ss.loadSession return ss.loadSession
} }
@ -49,10 +49,10 @@ func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor
// storedSessionLoader is responsible for loading sessions from cookie // storedSessionLoader is responsible for loading sessions from cookie
// identified sessions in the session store. // identified sessions in the session store.
type storedSessionLoader struct { type storedSessionLoader struct {
store sessionsapi.SessionStore store sessionsapi.SessionStore
refreshPeriod time.Duration refreshPeriod time.Duration
refreshSessionWithProviderIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error) sessionRefresher func(context.Context, *sessionsapi.SessionState) (bool, error)
validateSessionState func(context.Context, *sessionsapi.SessionState) bool sessionValidator func(context.Context, *sessionsapi.SessionState) bool
} }
// loadSession attempts to load a session as identified by the request cookies. // loadSession attempts to load a session as identified by the request cookies.
@ -120,37 +120,38 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req
} }
logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod) logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod)
refreshed, err := s.refreshSessionWithProvider(rw, req, session) err := s.refreshSession(rw, req, session)
if err != nil { if err != nil {
return err return err
} }
if !refreshed { // Validate all sessions after any Redeem/Refresh operation
// Session wasn't refreshed, so make sure it's still valid return s.validateSession(req.Context(), session)
return s.validateSession(req.Context(), session)
}
return nil
} }
// refreshSessionWithProvider attempts to refresh the sessinon with the provider // refreshSession attempts to refresh the session with the provider
// and will save the session if it was updated. // and will save the session if it was updated.
func (s *storedSessionLoader) refreshSessionWithProvider(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) (bool, error) { func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error {
refreshed, err := s.refreshSessionWithProviderIfNeeded(req.Context(), session) refreshed, err := s.sessionRefresher(req.Context(), session)
if err != nil { if err != nil {
return false, fmt.Errorf("error refreshing access token: %v", err) return fmt.Errorf("error refreshing access token: %v", err)
} }
if !refreshed { if !refreshed {
return false, nil return nil
} }
// If we refreshed, update the `CreatedAt` time to reset the refresh timer
// TODO: Implement
// session.CreatedAtNow()
// Because the session was refreshed, make sure to save it // Because the session was refreshed, make sure to save it
err = s.store.Save(rw, req, session) err = s.store.Save(rw, req, session)
if err != nil { if err != nil {
logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err) logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err)
return false, fmt.Errorf("error saving session: %v", err) return fmt.Errorf("error saving session: %v", err)
} }
return true, nil return nil
} }
// validateSession checks whether the session has expired and performs // validateSession checks whether the session has expired and performs
@ -161,7 +162,7 @@ func (s *storedSessionLoader) validateSession(ctx context.Context, session *sess
return errors.New("session is expired") return errors.New("session is expired")
} }
if !s.validateSessionState(ctx, session) { if !s.sessionValidator(ctx, session) {
return errors.New("session is invalid") return errors.New("session is invalid")
} }

View File

@ -109,10 +109,10 @@ var _ = Describe("Stored Session Suite", func() {
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
opts := &StoredSessionLoaderOptions{ opts := &StoredSessionLoaderOptions{
SessionStore: in.store, SessionStore: in.store,
RefreshPeriod: in.refreshPeriod, RefreshPeriod: in.refreshPeriod,
RefreshSessionIfNeeded: in.refreshSession, RefreshSession: in.refreshSession,
ValidateSessionState: in.validateSession, ValidateSession: in.validateSession,
} }
// Create the handler with a next handler that will capture the session // Create the handler with a next handler that will capture the session
@ -261,7 +261,7 @@ var _ = Describe("Stored Session Suite", func() {
s := &storedSessionLoader{ s := &storedSessionLoader{
refreshPeriod: in.refreshPeriod, refreshPeriod: in.refreshPeriod,
store: &fakeSessionStore{}, store: &fakeSessionStore{},
refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { sessionRefresher: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
refreshed = true refreshed = true
switch ss.RefreshToken { switch ss.RefreshToken {
case refresh: case refresh:
@ -272,7 +272,7 @@ var _ = Describe("Stored Session Suite", func() {
return false, errors.New("error refreshing session") return false, errors.New("error refreshing session")
} }
}, },
validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool { sessionValidator: func(_ context.Context, ss *sessionsapi.SessionState) bool {
validated = true validated = true
return ss.AccessToken != "Invalid" return ss.AccessToken != "Invalid"
}, },
@ -364,7 +364,7 @@ var _ = Describe("Stored Session Suite", func() {
) )
}) })
Context("refreshSessionWithProvider", func() { Context("refreshSession", func() {
type refreshSessionWithProviderTableInput struct { type refreshSessionWithProviderTableInput struct {
session *sessionsapi.SessionState session *sessionsapi.SessionState
expectedErr error expectedErr error
@ -388,7 +388,7 @@ var _ = Describe("Stored Session Suite", func() {
return nil return nil
}, },
}, },
refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { sessionRefresher: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
switch ss.RefreshToken { switch ss.RefreshToken {
case refresh: case refresh:
return true, nil return true, nil
@ -402,13 +402,12 @@ var _ = Describe("Stored Session Suite", func() {
req := httptest.NewRequest("", "/", nil) req := httptest.NewRequest("", "/", nil)
req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{}) req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{})
refreshed, err := s.refreshSessionWithProvider(nil, req, in.session) err := s.refreshSession(nil, req, in.session)
if in.expectedErr != nil { if in.expectedErr != nil {
Expect(err).To(MatchError(in.expectedErr)) Expect(err).To(MatchError(in.expectedErr))
} else { } else {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
} }
Expect(refreshed).To(Equal(in.expectRefreshed))
Expect(saved).To(Equal(in.expectSaved)) Expect(saved).To(Equal(in.expectSaved))
}, },
Entry("when the provider does not refresh the session", refreshSessionWithProviderTableInput{ Entry("when the provider does not refresh the session", refreshSessionWithProviderTableInput{
@ -416,7 +415,6 @@ var _ = Describe("Stored Session Suite", func() {
RefreshToken: noRefresh, RefreshToken: noRefresh,
}, },
expectedErr: nil, expectedErr: nil,
expectRefreshed: false,
expectSaved: false, expectSaved: false,
}), }),
Entry("when the provider refreshes the session", refreshSessionWithProviderTableInput{ Entry("when the provider refreshes the session", refreshSessionWithProviderTableInput{
@ -424,7 +422,6 @@ var _ = Describe("Stored Session Suite", func() {
RefreshToken: refresh, RefreshToken: refresh,
}, },
expectedErr: nil, expectedErr: nil,
expectRefreshed: true,
expectSaved: true, expectSaved: true,
}), }),
Entry("when the provider returns an error", refreshSessionWithProviderTableInput{ Entry("when the provider returns an error", refreshSessionWithProviderTableInput{
@ -434,7 +431,6 @@ var _ = Describe("Stored Session Suite", func() {
ExpiresOn: &now, ExpiresOn: &now,
}, },
expectedErr: errors.New("error refreshing access token: error refreshing session"), expectedErr: errors.New("error refreshing access token: error refreshing session"),
expectRefreshed: false,
expectSaved: false, expectSaved: false,
}), }),
Entry("when the saving the session returns an error", refreshSessionWithProviderTableInput{ Entry("when the saving the session returns an error", refreshSessionWithProviderTableInput{
@ -443,7 +439,6 @@ var _ = Describe("Stored Session Suite", func() {
AccessToken: "NoSave", AccessToken: "NoSave",
}, },
expectedErr: errors.New("error saving session: unable to save session"), expectedErr: errors.New("error saving session: unable to save session"),
expectRefreshed: false,
expectSaved: true, expectSaved: true,
}), }),
) )
@ -454,7 +449,7 @@ var _ = Describe("Stored Session Suite", func() {
BeforeEach(func() { BeforeEach(func() {
s = &storedSessionLoader{ s = &storedSessionLoader{
validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool { sessionValidator: func(_ context.Context, ss *sessionsapi.SessionState) bool {
return ss.AccessToken == "Valid" return ss.AccessToken == "Valid"
}, },
} }

View File

@ -345,7 +345,7 @@ func TestAzureProviderNotRefreshWhenNotExpired(t *testing.T) {
expires := time.Now().Add(time.Duration(1) * time.Hour) expires := time.Now().Add(time.Duration(1) * time.Hour)
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires} session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
refreshNeeded, err := p.RefreshSessionIfNeeded(context.Background(), session) refreshNeeded, err := p.RefreshSession(context.Background(), session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.False(t, refreshNeeded) assert.False(t, refreshNeeded)
} }
@ -373,9 +373,10 @@ func TestAzureProviderRefreshWhenExpired(t *testing.T) {
expires := time.Now().Add(time.Duration(-1) * time.Hour) expires := time.Now().Add(time.Duration(-1) * time.Hour)
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires} session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
refreshNeeded, err := p.RefreshSessionIfNeeded(context.Background(), session)
refreshed, err := p.RefreshSession(context.Background(), session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.True(t, refreshNeeded) assert.True(t, refreshed)
assert.NotEqual(t, session, nil) assert.NotEqual(t, session, nil)
assert.Equal(t, "new_some_access_token", session.AccessToken) assert.Equal(t, "new_some_access_token", session.AccessToken)
assert.Equal(t, "new_some_refresh_token", session.RefreshToken) assert.Equal(t, "new_some_refresh_token", session.RefreshToken)

View File

@ -88,7 +88,7 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess
return r.Email, nil return r.Email, nil
} }
// ValidateSessionState validates the AccessToken // ValidateSession validates the AccessToken
func (p *FacebookProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { func (p *FacebookProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
return validateToken(ctx, p, s.AccessToken, makeOIDCHeader(s.AccessToken)) return validateToken(ctx, p, s.AccessToken, makeOIDCHeader(s.AccessToken))
} }

View File

@ -121,10 +121,9 @@ func (p *GitLabProvider) SetProjectScope() {
} }
} }
// RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
// RefreshToken to fetch a new ID token if required func (p *GitLabProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { if s == nil || s.RefreshToken == "" {
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
return false, nil return false, nil
} }
@ -139,10 +138,10 @@ func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions
return true, nil return true, nil
} }
func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) { func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error {
clientSecret, err := p.GetClientSecret() clientSecret, err := p.GetClientSecret()
if err != nil { if err != nil {
return return err
} }
c := oauth2.Config{ c := oauth2.Config{
@ -164,13 +163,9 @@ func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.Ses
if err != nil { if err != nil {
return fmt.Errorf("unable to update session: %v", err) return fmt.Errorf("unable to update session: %v", err)
} }
s.AccessToken = newSession.AccessToken *s = *newSession
s.IDToken = newSession.IDToken
s.RefreshToken = newSession.RefreshToken return nil
s.CreatedAt = newSession.CreatedAt
s.ExpiresOn = newSession.ExpiresOn
s.Email = newSession.Email
return
} }
type gitlabUserInfo struct { type gitlabUserInfo struct {

View File

@ -266,10 +266,9 @@ func userInGroup(service *admin.Service, group string, email string) bool {
return false return false
} }
// RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
// RefreshToken to fetch a new ID token if required func (p *GoogleProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { if s == nil || s.RefreshToken == "" {
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
return false, nil return false, nil
} }

View File

@ -93,7 +93,7 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess
return email, nil return email, nil
} }
// ValidateSessionState validates the AccessToken // ValidateSession validates the AccessToken
func (p *LinkedInProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { func (p *LinkedInProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
return validateToken(ctx, p, s.AccessToken, makeLinkedInHeader(s.AccessToken)) return validateToken(ctx, p, s.AccessToken, makeLinkedInHeader(s.AccessToken))
} }

View File

@ -143,10 +143,9 @@ func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionS
return true return true
} }
// RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
// RefreshToken to fetch a new Access Token (and optional ID token) if required func (p *OIDCProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { if s == nil || s.RefreshToken == "" {
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
return false, nil return false, nil
} }

View File

@ -487,7 +487,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) {
User: "11223344", User: "11223344",
} }
refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession) refreshed, err := provider.RefreshSession(context.Background(), existingSession)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, refreshed, true) assert.Equal(t, refreshed, true)
assert.Equal(t, "janedoe@example.com", existingSession.Email) assert.Equal(t, "janedoe@example.com", existingSession.Email)
@ -520,7 +520,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) {
Email: "changeit", Email: "changeit",
User: "changeit", User: "changeit",
} }
refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession) refreshed, err := provider.RefreshSession(context.Background(), existingSession)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, refreshed, true) assert.Equal(t, refreshed, true)
assert.Equal(t, defaultIDToken.Email, existingSession.Email) assert.Equal(t, defaultIDToken.Email, existingSession.Email)

View File

@ -126,10 +126,15 @@ func (p *ProviderData) ValidateSession(ctx context.Context, s *sessions.SessionS
return validateToken(ctx, p, s.AccessToken, nil) return validateToken(ctx, p, s.AccessToken, nil)
} }
// RefreshSessionIfNeeded should refresh the user's session if required and // RefreshSession refreshes the user's session
// do nothing if a refresh is not required func (p *ProviderData) RefreshSession(_ context.Context, s *sessions.SessionState) (bool, error) {
func (p *ProviderData) RefreshSessionIfNeeded(_ context.Context, _ *sessions.SessionState) (bool, error) { if s == nil {
return false, nil return false, nil
}
// Pretend `RefreshSession` occured so `ValidateSession` isn't called
// on every request after any potential set refresh period elapses.
return true, nil
} }
// CreateSessionFromToken converts Bearer IDTokens into sessions // CreateSessionFromToken converts Bearer IDTokens into sessions

View File

@ -15,7 +15,7 @@ func TestRefresh(t *testing.T) {
p := &ProviderData{} p := &ProviderData{}
expires := time.Now().Add(time.Duration(-11) * time.Minute) expires := time.Now().Add(time.Duration(-11) * time.Minute)
refreshed, err := p.RefreshSessionIfNeeded(context.Background(), &sessions.SessionState{ refreshed, err := p.RefreshSession(context.Background(), &sessions.SessionState{
ExpiresOn: &expires, ExpiresOn: &expires,
}) })
assert.Equal(t, false, refreshed) assert.Equal(t, false, refreshed)

View File

@ -9,14 +9,14 @@ import (
// Provider represents an upstream identity provider implementation // Provider represents an upstream identity provider implementation
type Provider interface { type Provider interface {
Data() *ProviderData Data() *ProviderData
GetLoginURL(redirectURI, finalRedirect string, nonce string) string
Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error)
// Deprecated: Migrate to EnrichSession // Deprecated: Migrate to EnrichSession
GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error)
GetLoginURL(redirectURI, state, nonce string) string
Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error)
EnrichSession(ctx context.Context, s *sessions.SessionState) error EnrichSession(ctx context.Context, s *sessions.SessionState) error
Authorize(ctx context.Context, s *sessions.SessionState) (bool, error) Authorize(ctx context.Context, s *sessions.SessionState) (bool, error)
ValidateSession(ctx context.Context, s *sessions.SessionState) bool ValidateSession(ctx context.Context, s *sessions.SessionState) bool
RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error)
CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error)
} }