You've already forked oauth2-proxy
mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-06-15 00:15:00 +02:00
RefreshSessions immediately when called
This commit is contained in:
@ -361,10 +361,10 @@ func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionSt
|
|||||||
}
|
}
|
||||||
|
|
||||||
chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{
|
chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{
|
||||||
SessionStore: sessionStore,
|
SessionStore: sessionStore,
|
||||||
RefreshPeriod: opts.Cookie.Refresh,
|
RefreshPeriod: opts.Cookie.Refresh,
|
||||||
RefreshSessionIfNeeded: opts.GetProvider().RefreshSessionIfNeeded,
|
RefreshSession: opts.GetProvider().RefreshSession,
|
||||||
ValidateSessionState: opts.GetProvider().ValidateSession,
|
ValidateSession: opts.GetProvider().ValidateSession,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
return chain
|
return chain
|
||||||
|
@ -24,12 +24,12 @@ type StoredSessionLoaderOptions struct {
|
|||||||
RefreshPeriod time.Duration
|
RefreshPeriod time.Duration
|
||||||
|
|
||||||
// Provider based sesssion refreshing
|
// Provider based sesssion refreshing
|
||||||
RefreshSessionIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error)
|
RefreshSession func(context.Context, *sessionsapi.SessionState) (bool, error)
|
||||||
|
|
||||||
// Provider based session validation.
|
// Provider based session validation.
|
||||||
// If the sesssion is older than `RefreshPeriod` but the provider doesn't
|
// If the sesssion is older than `RefreshPeriod` but the provider doesn't
|
||||||
// refresh it, we must re-validate using this validation.
|
// refresh it, we must re-validate using this validation.
|
||||||
ValidateSessionState func(context.Context, *sessionsapi.SessionState) bool
|
ValidateSession func(context.Context, *sessionsapi.SessionState) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewStoredSessionLoader creates a new storedSessionLoader which loads
|
// NewStoredSessionLoader creates a new storedSessionLoader which loads
|
||||||
@ -38,10 +38,10 @@ type StoredSessionLoaderOptions struct {
|
|||||||
// If a session was loader by a previous handler, it will not be replaced.
|
// If a session was loader by a previous handler, it will not be replaced.
|
||||||
func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor {
|
func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor {
|
||||||
ss := &storedSessionLoader{
|
ss := &storedSessionLoader{
|
||||||
store: opts.SessionStore,
|
store: opts.SessionStore,
|
||||||
refreshPeriod: opts.RefreshPeriod,
|
refreshPeriod: opts.RefreshPeriod,
|
||||||
refreshSessionWithProviderIfNeeded: opts.RefreshSessionIfNeeded,
|
sessionRefresher: opts.RefreshSession,
|
||||||
validateSessionState: opts.ValidateSessionState,
|
sessionValidator: opts.ValidateSession,
|
||||||
}
|
}
|
||||||
return ss.loadSession
|
return ss.loadSession
|
||||||
}
|
}
|
||||||
@ -49,10 +49,10 @@ func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor
|
|||||||
// storedSessionLoader is responsible for loading sessions from cookie
|
// storedSessionLoader is responsible for loading sessions from cookie
|
||||||
// identified sessions in the session store.
|
// identified sessions in the session store.
|
||||||
type storedSessionLoader struct {
|
type storedSessionLoader struct {
|
||||||
store sessionsapi.SessionStore
|
store sessionsapi.SessionStore
|
||||||
refreshPeriod time.Duration
|
refreshPeriod time.Duration
|
||||||
refreshSessionWithProviderIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error)
|
sessionRefresher func(context.Context, *sessionsapi.SessionState) (bool, error)
|
||||||
validateSessionState func(context.Context, *sessionsapi.SessionState) bool
|
sessionValidator func(context.Context, *sessionsapi.SessionState) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// loadSession attempts to load a session as identified by the request cookies.
|
// loadSession attempts to load a session as identified by the request cookies.
|
||||||
@ -120,37 +120,38 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req
|
|||||||
}
|
}
|
||||||
|
|
||||||
logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod)
|
logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod)
|
||||||
refreshed, err := s.refreshSessionWithProvider(rw, req, session)
|
err := s.refreshSession(rw, req, session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !refreshed {
|
// Validate all sessions after any Redeem/Refresh operation
|
||||||
// Session wasn't refreshed, so make sure it's still valid
|
return s.validateSession(req.Context(), session)
|
||||||
return s.validateSession(req.Context(), session)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// refreshSessionWithProvider attempts to refresh the sessinon with the provider
|
// refreshSession attempts to refresh the session with the provider
|
||||||
// and will save the session if it was updated.
|
// and will save the session if it was updated.
|
||||||
func (s *storedSessionLoader) refreshSessionWithProvider(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) (bool, error) {
|
func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error {
|
||||||
refreshed, err := s.refreshSessionWithProviderIfNeeded(req.Context(), session)
|
refreshed, err := s.sessionRefresher(req.Context(), session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("error refreshing access token: %v", err)
|
return fmt.Errorf("error refreshing access token: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !refreshed {
|
if !refreshed {
|
||||||
return false, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If we refreshed, update the `CreatedAt` time to reset the refresh timer
|
||||||
|
// TODO: Implement
|
||||||
|
// session.CreatedAtNow()
|
||||||
|
|
||||||
// Because the session was refreshed, make sure to save it
|
// Because the session was refreshed, make sure to save it
|
||||||
err = s.store.Save(rw, req, session)
|
err = s.store.Save(rw, req, session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err)
|
logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err)
|
||||||
return false, fmt.Errorf("error saving session: %v", err)
|
return fmt.Errorf("error saving session: %v", err)
|
||||||
}
|
}
|
||||||
return true, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateSession checks whether the session has expired and performs
|
// validateSession checks whether the session has expired and performs
|
||||||
@ -161,7 +162,7 @@ func (s *storedSessionLoader) validateSession(ctx context.Context, session *sess
|
|||||||
return errors.New("session is expired")
|
return errors.New("session is expired")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !s.validateSessionState(ctx, session) {
|
if !s.sessionValidator(ctx, session) {
|
||||||
return errors.New("session is invalid")
|
return errors.New("session is invalid")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -109,10 +109,10 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
opts := &StoredSessionLoaderOptions{
|
opts := &StoredSessionLoaderOptions{
|
||||||
SessionStore: in.store,
|
SessionStore: in.store,
|
||||||
RefreshPeriod: in.refreshPeriod,
|
RefreshPeriod: in.refreshPeriod,
|
||||||
RefreshSessionIfNeeded: in.refreshSession,
|
RefreshSession: in.refreshSession,
|
||||||
ValidateSessionState: in.validateSession,
|
ValidateSession: in.validateSession,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the handler with a next handler that will capture the session
|
// Create the handler with a next handler that will capture the session
|
||||||
@ -261,7 +261,7 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
s := &storedSessionLoader{
|
s := &storedSessionLoader{
|
||||||
refreshPeriod: in.refreshPeriod,
|
refreshPeriod: in.refreshPeriod,
|
||||||
store: &fakeSessionStore{},
|
store: &fakeSessionStore{},
|
||||||
refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
|
sessionRefresher: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
|
||||||
refreshed = true
|
refreshed = true
|
||||||
switch ss.RefreshToken {
|
switch ss.RefreshToken {
|
||||||
case refresh:
|
case refresh:
|
||||||
@ -272,7 +272,7 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
return false, errors.New("error refreshing session")
|
return false, errors.New("error refreshing session")
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool {
|
sessionValidator: func(_ context.Context, ss *sessionsapi.SessionState) bool {
|
||||||
validated = true
|
validated = true
|
||||||
return ss.AccessToken != "Invalid"
|
return ss.AccessToken != "Invalid"
|
||||||
},
|
},
|
||||||
@ -364,7 +364,7 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("refreshSessionWithProvider", func() {
|
Context("refreshSession", func() {
|
||||||
type refreshSessionWithProviderTableInput struct {
|
type refreshSessionWithProviderTableInput struct {
|
||||||
session *sessionsapi.SessionState
|
session *sessionsapi.SessionState
|
||||||
expectedErr error
|
expectedErr error
|
||||||
@ -388,7 +388,7 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
|
sessionRefresher: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
|
||||||
switch ss.RefreshToken {
|
switch ss.RefreshToken {
|
||||||
case refresh:
|
case refresh:
|
||||||
return true, nil
|
return true, nil
|
||||||
@ -402,13 +402,12 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
|
|
||||||
req := httptest.NewRequest("", "/", nil)
|
req := httptest.NewRequest("", "/", nil)
|
||||||
req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{})
|
req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{})
|
||||||
refreshed, err := s.refreshSessionWithProvider(nil, req, in.session)
|
err := s.refreshSession(nil, req, in.session)
|
||||||
if in.expectedErr != nil {
|
if in.expectedErr != nil {
|
||||||
Expect(err).To(MatchError(in.expectedErr))
|
Expect(err).To(MatchError(in.expectedErr))
|
||||||
} else {
|
} else {
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
}
|
}
|
||||||
Expect(refreshed).To(Equal(in.expectRefreshed))
|
|
||||||
Expect(saved).To(Equal(in.expectSaved))
|
Expect(saved).To(Equal(in.expectSaved))
|
||||||
},
|
},
|
||||||
Entry("when the provider does not refresh the session", refreshSessionWithProviderTableInput{
|
Entry("when the provider does not refresh the session", refreshSessionWithProviderTableInput{
|
||||||
@ -416,7 +415,6 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
RefreshToken: noRefresh,
|
RefreshToken: noRefresh,
|
||||||
},
|
},
|
||||||
expectedErr: nil,
|
expectedErr: nil,
|
||||||
expectRefreshed: false,
|
|
||||||
expectSaved: false,
|
expectSaved: false,
|
||||||
}),
|
}),
|
||||||
Entry("when the provider refreshes the session", refreshSessionWithProviderTableInput{
|
Entry("when the provider refreshes the session", refreshSessionWithProviderTableInput{
|
||||||
@ -424,7 +422,6 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
RefreshToken: refresh,
|
RefreshToken: refresh,
|
||||||
},
|
},
|
||||||
expectedErr: nil,
|
expectedErr: nil,
|
||||||
expectRefreshed: true,
|
|
||||||
expectSaved: true,
|
expectSaved: true,
|
||||||
}),
|
}),
|
||||||
Entry("when the provider returns an error", refreshSessionWithProviderTableInput{
|
Entry("when the provider returns an error", refreshSessionWithProviderTableInput{
|
||||||
@ -434,7 +431,6 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
ExpiresOn: &now,
|
ExpiresOn: &now,
|
||||||
},
|
},
|
||||||
expectedErr: errors.New("error refreshing access token: error refreshing session"),
|
expectedErr: errors.New("error refreshing access token: error refreshing session"),
|
||||||
expectRefreshed: false,
|
|
||||||
expectSaved: false,
|
expectSaved: false,
|
||||||
}),
|
}),
|
||||||
Entry("when the saving the session returns an error", refreshSessionWithProviderTableInput{
|
Entry("when the saving the session returns an error", refreshSessionWithProviderTableInput{
|
||||||
@ -443,7 +439,6 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
AccessToken: "NoSave",
|
AccessToken: "NoSave",
|
||||||
},
|
},
|
||||||
expectedErr: errors.New("error saving session: unable to save session"),
|
expectedErr: errors.New("error saving session: unable to save session"),
|
||||||
expectRefreshed: false,
|
|
||||||
expectSaved: true,
|
expectSaved: true,
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
@ -454,7 +449,7 @@ var _ = Describe("Stored Session Suite", func() {
|
|||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
s = &storedSessionLoader{
|
s = &storedSessionLoader{
|
||||||
validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool {
|
sessionValidator: func(_ context.Context, ss *sessionsapi.SessionState) bool {
|
||||||
return ss.AccessToken == "Valid"
|
return ss.AccessToken == "Valid"
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -345,7 +345,7 @@ func TestAzureProviderNotRefreshWhenNotExpired(t *testing.T) {
|
|||||||
|
|
||||||
expires := time.Now().Add(time.Duration(1) * time.Hour)
|
expires := time.Now().Add(time.Duration(1) * time.Hour)
|
||||||
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
|
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
|
||||||
refreshNeeded, err := p.RefreshSessionIfNeeded(context.Background(), session)
|
refreshNeeded, err := p.RefreshSession(context.Background(), session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.False(t, refreshNeeded)
|
assert.False(t, refreshNeeded)
|
||||||
}
|
}
|
||||||
@ -373,9 +373,10 @@ func TestAzureProviderRefreshWhenExpired(t *testing.T) {
|
|||||||
|
|
||||||
expires := time.Now().Add(time.Duration(-1) * time.Hour)
|
expires := time.Now().Add(time.Duration(-1) * time.Hour)
|
||||||
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
|
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
|
||||||
refreshNeeded, err := p.RefreshSessionIfNeeded(context.Background(), session)
|
|
||||||
|
refreshed, err := p.RefreshSession(context.Background(), session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.True(t, refreshNeeded)
|
assert.True(t, refreshed)
|
||||||
assert.NotEqual(t, session, nil)
|
assert.NotEqual(t, session, nil)
|
||||||
assert.Equal(t, "new_some_access_token", session.AccessToken)
|
assert.Equal(t, "new_some_access_token", session.AccessToken)
|
||||||
assert.Equal(t, "new_some_refresh_token", session.RefreshToken)
|
assert.Equal(t, "new_some_refresh_token", session.RefreshToken)
|
||||||
|
@ -88,7 +88,7 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess
|
|||||||
return r.Email, nil
|
return r.Email, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateSessionState validates the AccessToken
|
// ValidateSession validates the AccessToken
|
||||||
func (p *FacebookProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
|
func (p *FacebookProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
|
||||||
return validateToken(ctx, p, s.AccessToken, makeOIDCHeader(s.AccessToken))
|
return validateToken(ctx, p, s.AccessToken, makeOIDCHeader(s.AccessToken))
|
||||||
}
|
}
|
||||||
|
@ -121,10 +121,9 @@ func (p *GitLabProvider) SetProjectScope() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
// RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
|
||||||
// RefreshToken to fetch a new ID token if required
|
func (p *GitLabProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||||
func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
if s == nil || s.RefreshToken == "" {
|
||||||
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
|
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -139,10 +138,10 @@ func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) {
|
func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error {
|
||||||
clientSecret, err := p.GetClientSecret()
|
clientSecret, err := p.GetClientSecret()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
c := oauth2.Config{
|
c := oauth2.Config{
|
||||||
@ -164,13 +163,9 @@ func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.Ses
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to update session: %v", err)
|
return fmt.Errorf("unable to update session: %v", err)
|
||||||
}
|
}
|
||||||
s.AccessToken = newSession.AccessToken
|
*s = *newSession
|
||||||
s.IDToken = newSession.IDToken
|
|
||||||
s.RefreshToken = newSession.RefreshToken
|
return nil
|
||||||
s.CreatedAt = newSession.CreatedAt
|
|
||||||
s.ExpiresOn = newSession.ExpiresOn
|
|
||||||
s.Email = newSession.Email
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type gitlabUserInfo struct {
|
type gitlabUserInfo struct {
|
||||||
|
@ -266,10 +266,9 @@ func userInGroup(service *admin.Service, group string, email string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
// RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
|
||||||
// RefreshToken to fetch a new ID token if required
|
func (p *GoogleProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||||
func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
if s == nil || s.RefreshToken == "" {
|
||||||
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
|
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -93,7 +93,7 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess
|
|||||||
return email, nil
|
return email, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateSessionState validates the AccessToken
|
// ValidateSession validates the AccessToken
|
||||||
func (p *LinkedInProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
|
func (p *LinkedInProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
|
||||||
return validateToken(ctx, p, s.AccessToken, makeLinkedInHeader(s.AccessToken))
|
return validateToken(ctx, p, s.AccessToken, makeLinkedInHeader(s.AccessToken))
|
||||||
}
|
}
|
||||||
|
@ -143,10 +143,9 @@ func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionS
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
// RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
|
||||||
// RefreshToken to fetch a new Access Token (and optional ID token) if required
|
func (p *OIDCProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||||
func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
if s == nil || s.RefreshToken == "" {
|
||||||
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
|
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -487,7 +487,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) {
|
|||||||
User: "11223344",
|
User: "11223344",
|
||||||
}
|
}
|
||||||
|
|
||||||
refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession)
|
refreshed, err := provider.RefreshSession(context.Background(), existingSession)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, refreshed, true)
|
assert.Equal(t, refreshed, true)
|
||||||
assert.Equal(t, "janedoe@example.com", existingSession.Email)
|
assert.Equal(t, "janedoe@example.com", existingSession.Email)
|
||||||
@ -520,7 +520,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) {
|
|||||||
Email: "changeit",
|
Email: "changeit",
|
||||||
User: "changeit",
|
User: "changeit",
|
||||||
}
|
}
|
||||||
refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession)
|
refreshed, err := provider.RefreshSession(context.Background(), existingSession)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, refreshed, true)
|
assert.Equal(t, refreshed, true)
|
||||||
assert.Equal(t, defaultIDToken.Email, existingSession.Email)
|
assert.Equal(t, defaultIDToken.Email, existingSession.Email)
|
||||||
|
@ -126,10 +126,15 @@ func (p *ProviderData) ValidateSession(ctx context.Context, s *sessions.SessionS
|
|||||||
return validateToken(ctx, p, s.AccessToken, nil)
|
return validateToken(ctx, p, s.AccessToken, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshSessionIfNeeded should refresh the user's session if required and
|
// RefreshSession refreshes the user's session
|
||||||
// do nothing if a refresh is not required
|
func (p *ProviderData) RefreshSession(_ context.Context, s *sessions.SessionState) (bool, error) {
|
||||||
func (p *ProviderData) RefreshSessionIfNeeded(_ context.Context, _ *sessions.SessionState) (bool, error) {
|
if s == nil {
|
||||||
return false, nil
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pretend `RefreshSession` occured so `ValidateSession` isn't called
|
||||||
|
// on every request after any potential set refresh period elapses.
|
||||||
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateSessionFromToken converts Bearer IDTokens into sessions
|
// CreateSessionFromToken converts Bearer IDTokens into sessions
|
||||||
|
@ -15,7 +15,7 @@ func TestRefresh(t *testing.T) {
|
|||||||
p := &ProviderData{}
|
p := &ProviderData{}
|
||||||
|
|
||||||
expires := time.Now().Add(time.Duration(-11) * time.Minute)
|
expires := time.Now().Add(time.Duration(-11) * time.Minute)
|
||||||
refreshed, err := p.RefreshSessionIfNeeded(context.Background(), &sessions.SessionState{
|
refreshed, err := p.RefreshSession(context.Background(), &sessions.SessionState{
|
||||||
ExpiresOn: &expires,
|
ExpiresOn: &expires,
|
||||||
})
|
})
|
||||||
assert.Equal(t, false, refreshed)
|
assert.Equal(t, false, refreshed)
|
||||||
|
@ -9,14 +9,14 @@ import (
|
|||||||
// Provider represents an upstream identity provider implementation
|
// Provider represents an upstream identity provider implementation
|
||||||
type Provider interface {
|
type Provider interface {
|
||||||
Data() *ProviderData
|
Data() *ProviderData
|
||||||
|
GetLoginURL(redirectURI, finalRedirect string, nonce string) string
|
||||||
|
Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error)
|
||||||
// Deprecated: Migrate to EnrichSession
|
// Deprecated: Migrate to EnrichSession
|
||||||
GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error)
|
GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error)
|
||||||
GetLoginURL(redirectURI, state, nonce string) string
|
|
||||||
Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error)
|
|
||||||
EnrichSession(ctx context.Context, s *sessions.SessionState) error
|
EnrichSession(ctx context.Context, s *sessions.SessionState) error
|
||||||
Authorize(ctx context.Context, s *sessions.SessionState) (bool, error)
|
Authorize(ctx context.Context, s *sessions.SessionState) (bool, error)
|
||||||
ValidateSession(ctx context.Context, s *sessions.SessionState) bool
|
ValidateSession(ctx context.Context, s *sessions.SessionState) bool
|
||||||
RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error)
|
RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error)
|
||||||
CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error)
|
CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user