diff --git a/CHANGELOG.md b/CHANGELOG.md index b2446269..55841190 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,10 +4,16 @@ ## Important Notes +- [#1086](https://github.com/oauth2-proxy/oauth2-proxy/pull/1086) The extra validation to protect invalid session + deserialization from v6.0.0 (only) has been removed to improve performance. If you are on v6.0.0, either upgrade + to a version before this first and allow legacy sessions to expire gracefully or change your `cookie-secret` + value and force all sessions to reauthenticate. + ## Breaking Changes ## Changes since v7.1.3 +- [#1086](https://github.com/oauth2-proxy/oauth2-proxy/pull/1086) Refresh sessions before token expiration if configured (@NickMeves) - [#1226](https://github.com/oauth2-proxy/oauth2-proxy/pull/1226) Move app redirection logic to its own package (@JoelSpeed) - [#1128](https://github.com/oauth2-proxy/oauth2-proxy/pull/1128) Use gorilla mux for OAuth Proxy routing (@JoelSpeed) - [#1238](https://github.com/oauth2-proxy/oauth2-proxy/pull/1238) Added ADFS provider (@samirachoadi) diff --git a/oauthproxy.go b/oauthproxy.go index d6479609..e2d20ed6 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -361,10 +361,10 @@ func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionSt } chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{ - SessionStore: sessionStore, - RefreshPeriod: opts.Cookie.Refresh, - RefreshSessionIfNeeded: opts.GetProvider().RefreshSessionIfNeeded, - ValidateSessionState: opts.GetProvider().ValidateSession, + SessionStore: sessionStore, + RefreshPeriod: opts.Cookie.Refresh, + RefreshSession: opts.GetProvider().RefreshSession, + ValidateSession: opts.GetProvider().ValidateSession, })) return chain @@ -786,6 +786,15 @@ func (p *OAuthProxy) redeemCode(req *http.Request) (*sessionsapi.SessionState, e if err != nil { return nil, err } + + // Force setting these in case the Provider didn't + if s.CreatedAt == nil { + s.CreatedAtNow() + } + if s.ExpiresOn == nil { + s.ExpiresIn(p.CookieOptions.Expire) + } + return s, nil } diff --git a/pkg/apis/sessions/session_state.go b/pkg/apis/sessions/session_state.go index e1ee4a6c..08538dae 100644 --- a/pkg/apis/sessions/session_state.go +++ b/pkg/apis/sessions/session_state.go @@ -3,14 +3,12 @@ package sessions import ( "bytes" "context" - "errors" "fmt" "io" "io/ioutil" - "reflect" "time" - "unicode/utf8" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption" "github.com/pierrec/lz4" "github.com/vmihailenco/msgpack/v4" @@ -32,7 +30,9 @@ type SessionState struct { Groups []string `msgpack:"g,omitempty"` PreferredUsername string `msgpack:"pu,omitempty"` - Lock Lock `msgpack:"-"` + // Internal helpers, not serialized + Clock clock.Clock `msgpack:"-"` + Lock Lock `msgpack:"-"` } func (s *SessionState) ObtainLock(ctx context.Context, expiration time.Duration) error { @@ -63,9 +63,30 @@ func (s *SessionState) PeekLock(ctx context.Context) (bool, error) { return s.Lock.Peek(ctx) } +// CreatedAtNow sets a SessionState's CreatedAt to now +func (s *SessionState) CreatedAtNow() { + now := s.Clock.Now() + s.CreatedAt = &now +} + +// SetExpiresOn sets an expiration +func (s *SessionState) SetExpiresOn(exp time.Time) { + s.ExpiresOn = &exp +} + +// ExpiresIn sets an expiration a certain duration from CreatedAt. +// CreatedAt will be set to time.Now if it is unset. +func (s *SessionState) ExpiresIn(d time.Duration) { + if s.CreatedAt == nil { + s.CreatedAtNow() + } + exp := s.CreatedAt.Add(d) + s.ExpiresOn = &exp +} + // IsExpired checks whether the session has expired func (s *SessionState) IsExpired() bool { - if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) { + if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(s.Clock.Now()) { return true } return false @@ -74,7 +95,7 @@ func (s *SessionState) IsExpired() bool { // Age returns the age of a session func (s *SessionState) Age() time.Duration { if s.CreatedAt != nil && !s.CreatedAt.IsZero() { - return time.Now().Truncate(time.Second).Sub(*s.CreatedAt) + return s.Clock.Now().Truncate(time.Second).Sub(*s.CreatedAt) } return 0 } @@ -177,11 +198,6 @@ func DecodeSessionState(data []byte, c encryption.Cipher, compressed bool) (*Ses return nil, fmt.Errorf("error unmarshalling data to session state: %w", err) } - err = ss.validate() - if err != nil { - return nil, err - } - return &ss, nil } @@ -235,35 +251,3 @@ func lz4Decompress(compressed []byte) ([]byte, error) { return payload, nil } - -// validate ensures the decoded session is non-empty and contains valid data -// -// Non-empty check is needed due to ensure the non-authenticated AES-CFB -// decryption doesn't result in garbage data that collides with a valid -// MessagePack header bytes (which MessagePack will unmarshal to an empty -// default SessionState). <1% chance, but observed with random test data. -// -// UTF-8 check ensures the strings are valid and not raw bytes overloaded -// into Latin-1 encoding. The occurs when legacy unencrypted fields are -// decrypted with AES-CFB which results in random bytes. -func (s *SessionState) validate() error { - for _, field := range []string{ - s.User, - s.Email, - s.PreferredUsername, - s.AccessToken, - s.IDToken, - s.RefreshToken, - } { - if !utf8.ValidString(field) { - return errors.New("invalid non-UTF8 field in session") - } - } - - empty := new(SessionState) - if reflect.DeepEqual(*s, *empty) { - return errors.New("invalid empty session unmarshalled") - } - - return nil -} diff --git a/pkg/apis/sessions/session_state_test.go b/pkg/apis/sessions/session_state_test.go index 0121b4b6..e6b9ff39 100644 --- a/pkg/apis/sessions/session_state_test.go +++ b/pkg/apis/sessions/session_state_test.go @@ -16,6 +16,30 @@ func timePtr(t time.Time) *time.Time { return &t } +func TestCreatedAtNow(t *testing.T) { + g := NewWithT(t) + ss := &SessionState{} + + now := time.Unix(1234567890, 0) + ss.Clock.Set(now) + + ss.CreatedAtNow() + g.Expect(*ss.CreatedAt).To(Equal(now)) +} + +func TestExpiresIn(t *testing.T) { + g := NewWithT(t) + ss := &SessionState{} + + now := time.Unix(1234567890, 0) + ss.Clock.Set(now) + + ttl := time.Duration(743) * time.Second + ss.ExpiresIn(ttl) + + g.Expect(*ss.ExpiresOn).To(Equal(ss.CreatedAt.Add(ttl))) +} + func TestString(t *testing.T) { g := NewWithT(t) created, err := time.Parse(time.RFC3339, "2000-01-01T00:00:00Z") diff --git a/pkg/clock/clock.go b/pkg/clock/clock.go index 34b7bf23..887bf0aa 100644 --- a/pkg/clock/clock.go +++ b/pkg/clock/clock.go @@ -63,13 +63,10 @@ func Reset() *clockapi.Mock { // package. type Clock struct { mock *clockapi.Mock - sync.Mutex } // Set sets the Clock to a clock.Mock at the given time.Time func (c *Clock) Set(t time.Time) { - c.Lock() - defer c.Unlock() if c.mock == nil { c.mock = clockapi.NewMock() } @@ -79,8 +76,6 @@ func (c *Clock) Set(t time.Time) { // Add moves clock forward time.Duration if it is mocked. It will error // if the clock is not mocked. func (c *Clock) Add(d time.Duration) error { - c.Lock() - defer c.Unlock() if c.mock == nil { return errors.New("clock not mocked") } @@ -91,8 +86,6 @@ func (c *Clock) Add(d time.Duration) error { // Reset removes local clock.Mock. Returns any existing Mock if set in case // lingering time operations are attached to it. func (c *Clock) Reset() *clockapi.Mock { - c.Lock() - defer c.Unlock() existing := c.mock c.mock = nil return existing diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go index 1bd0a9a4..6748816f 100644 --- a/pkg/middleware/stored_session.go +++ b/pkg/middleware/stored_session.go @@ -11,25 +11,26 @@ import ( middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" + "github.com/oauth2-proxy/oauth2-proxy/v7/providers" ) -// StoredSessionLoaderOptions cotnains all of the requirements to construct +// StoredSessionLoaderOptions contains all of the requirements to construct // a stored session loader. // All options must be provided. type StoredSessionLoaderOptions struct { - // Session storage basckend + // Session storage backend SessionStore sessionsapi.SessionStore // How often should sessions be refreshed RefreshPeriod time.Duration - // Provider based sesssion refreshing - RefreshSessionIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error) + // Provider based session refreshing + RefreshSession func(context.Context, *sessionsapi.SessionState) (bool, error) // Provider based session validation. // If the sesssion is older than `RefreshPeriod` but the provider doesn't // refresh it, we must re-validate using this validation. - ValidateSessionState func(context.Context, *sessionsapi.SessionState) bool + ValidateSession func(context.Context, *sessionsapi.SessionState) bool } // NewStoredSessionLoader creates a new storedSessionLoader which loads @@ -38,10 +39,10 @@ type StoredSessionLoaderOptions struct { // If a session was loader by a previous handler, it will not be replaced. func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor { ss := &storedSessionLoader{ - store: opts.SessionStore, - refreshPeriod: opts.RefreshPeriod, - refreshSessionWithProviderIfNeeded: opts.RefreshSessionIfNeeded, - validateSessionState: opts.ValidateSessionState, + store: opts.SessionStore, + refreshPeriod: opts.RefreshPeriod, + sessionRefresher: opts.RefreshSession, + sessionValidator: opts.ValidateSession, } return ss.loadSession } @@ -49,10 +50,10 @@ func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor // storedSessionLoader is responsible for loading sessions from cookie // identified sessions in the session store. type storedSessionLoader struct { - store sessionsapi.SessionStore - refreshPeriod time.Duration - refreshSessionWithProviderIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error) - validateSessionState func(context.Context, *sessionsapi.SessionState) bool + store sessionsapi.SessionStore + refreshPeriod time.Duration + sessionRefresher func(context.Context, *sessionsapi.SessionState) (bool, error) + sessionValidator func(context.Context, *sessionsapi.SessionState) bool } // loadSession attempts to load a session as identified by the request cookies. @@ -108,49 +109,59 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h // refreshSessionIfNeeded will attempt to refresh a session if the session // is older than the refresh period. -// It is assumed that if the provider refreshes the session, the session is now -// valid. -// If the session requires refreshing but the provider does not refresh it, -// we must validate the session to ensure that the returned session is still -// valid. +// Success or fail, we will then validate the session. func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { if s.refreshPeriod <= time.Duration(0) || session.Age() < s.refreshPeriod { // Refresh is disabled or the session is not old enough, do nothing return nil } - logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod) - refreshed, err := s.refreshSessionWithProvider(rw, req, session) + logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age()) + err := s.refreshSession(rw, req, session) if err != nil { - return err + // If a preemptive refresh fails, we still keep the session + // if validateSession succeeds. + logger.Errorf("Unable to refresh session: %v", err) } - if !refreshed { - // Session wasn't refreshed, so make sure it's still valid - return s.validateSession(req.Context(), session) - } - return nil + // Validate all sessions after any Redeem/Refresh operation (fail or success) + return s.validateSession(req.Context(), session) } -// refreshSessionWithProvider attempts to refresh the sessinon with the provider +// refreshSession attempts to refresh the session with the provider // and will save the session if it was updated. -func (s *storedSessionLoader) refreshSessionWithProvider(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) (bool, error) { - refreshed, err := s.refreshSessionWithProviderIfNeeded(req.Context(), session) - if err != nil { - return false, fmt.Errorf("error refreshing access token: %v", err) +func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { + refreshed, err := s.sessionRefresher(req.Context(), session) + if err != nil && !errors.Is(err, providers.ErrNotImplemented) { + return fmt.Errorf("error refreshing tokens: %v", err) } - if !refreshed { - return false, nil + // HACK: + // Providers that don't implement `RefreshSession` use the default + // implementation which returns `ErrNotImplemented`. + // Pretend it refreshed to reset the refresh timer so that `ValidateSession` + // isn't triggered every subsequent request and is only called once during + // this request. + if errors.Is(err, providers.ErrNotImplemented) { + refreshed = true } + // Session not refreshed, nothing to persist. + if !refreshed { + return nil + } + + // If we refreshed, update the `CreatedAt` time to reset the refresh timer + // (In case underlying provider implementations forget) + session.CreatedAtNow() + // Because the session was refreshed, make sure to save it err = s.store.Save(rw, req, session) if err != nil { logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err) - return false, fmt.Errorf("error saving session: %v", err) + return fmt.Errorf("error saving session: %v", err) } - return true, nil + return nil } // validateSession checks whether the session has expired and performs @@ -161,7 +172,7 @@ func (s *storedSessionLoader) validateSession(ctx context.Context, session *sess return errors.New("session is expired") } - if !s.validateSessionState(ctx, session) { + if !s.sessionValidator(ctx, session) { return errors.New("session is invalid") } diff --git a/pkg/middleware/stored_session_test.go b/pkg/middleware/stored_session_test.go index 3d8dd087..782390b6 100644 --- a/pkg/middleware/stored_session_test.go +++ b/pkg/middleware/stored_session_test.go @@ -10,6 +10,8 @@ import ( middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock" + "github.com/oauth2-proxy/oauth2-proxy/v7/providers" . "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo/extensions/table" . "github.com/onsi/gomega" @@ -17,15 +19,17 @@ import ( var _ = Describe("Stored Session Suite", func() { const ( - refresh = "Refresh" - noRefresh = "NoRefresh" + refresh = "Refresh" + noRefresh = "NoRefresh" + notImplemented = "NotImplemented" ) var ctx = context.Background() Context("StoredSessionLoader", func() { - createdPast := time.Now().Add(-5 * time.Minute) - createdFuture := time.Now().Add(5 * time.Minute) + now := time.Now() + createdPast := now.Add(-5 * time.Minute) + createdFuture := now.Add(5 * time.Minute) var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { switch ss.RefreshToken { @@ -85,6 +89,14 @@ var _ = Describe("Stored Session Suite", func() { }, } + BeforeEach(func() { + clock.Set(now) + }) + + AfterEach(func() { + clock.Reset() + }) + type storedSessionLoaderTableInput struct { requestHeaders http.Header existingSession *sessionsapi.SessionState @@ -109,10 +121,10 @@ var _ = Describe("Stored Session Suite", func() { rw := httptest.NewRecorder() opts := &StoredSessionLoaderOptions{ - SessionStore: in.store, - RefreshPeriod: in.refreshPeriod, - RefreshSessionIfNeeded: in.refreshSession, - ValidateSessionState: in.validateSession, + SessionStore: in.store, + RefreshPeriod: in.refreshPeriod, + RefreshSession: in.refreshSession, + ValidateSession: in.validateSession, } // Create the handler with a next handler that will capture the session @@ -208,6 +220,21 @@ var _ = Describe("Stored Session Suite", func() { existingSession: nil, expectedSession: &sessionsapi.SessionState{ RefreshToken: "Refreshed", + CreatedAt: &now, + ExpiresOn: &createdFuture, + }, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("when the provider refresh fails but validation succeeds", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=RefreshError"}, + }, + existingSession: nil, + expectedSession: &sessionsapi.SessionState{ + RefreshToken: "RefreshError", CreatedAt: &createdPast, ExpiresOn: &createdFuture, }, @@ -216,7 +243,7 @@ var _ = Describe("Stored Session Suite", func() { refreshSession: defaultRefreshFunc, validateSession: defaultValidateFunc, }), - Entry("when the provider refresh fails", storedSessionLoaderTableInput{ + Entry("when the provider refresh fails and validation fails", storedSessionLoaderTableInput{ requestHeaders: http.Header{ "Cookie": []string{"_oauth2_proxy=RefreshError"}, }, @@ -225,7 +252,7 @@ var _ = Describe("Stored Session Suite", func() { store: defaultSessionStore, refreshPeriod: 1 * time.Minute, refreshSession: defaultRefreshFunc, - validateSession: defaultValidateFunc, + validateSession: func(context.Context, *sessionsapi.SessionState) bool { return false }, }), Entry("when the session is not refreshed and is no longer valid", storedSessionLoaderTableInput{ requestHeaders: http.Header{ @@ -261,18 +288,20 @@ var _ = Describe("Stored Session Suite", func() { s := &storedSessionLoader{ refreshPeriod: in.refreshPeriod, store: &fakeSessionStore{}, - refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { + sessionRefresher: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { refreshed = true switch ss.RefreshToken { case refresh: return true, nil case noRefresh: return false, nil + case notImplemented: + return false, providers.ErrNotImplemented default: return false, errors.New("error refreshing session") } }, - validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool { + sessionValidator: func(_ context.Context, ss *sessionsapi.SessionState) bool { validated = true return ss.AccessToken != "Invalid" }, @@ -326,7 +355,7 @@ var _ = Describe("Stored Session Suite", func() { }, expectedErr: nil, expectRefreshed: true, - expectValidated: false, + expectValidated: true, }), Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, @@ -339,15 +368,25 @@ var _ = Describe("Stored Session Suite", func() { expectRefreshed: true, expectValidated: true, }), - Entry("when the provider refresh fails", refreshSessionIfNeededTableInput{ + Entry("when the provider doesn't implement refresh but validation succeeds", refreshSessionIfNeededTableInput{ + refreshPeriod: 1 * time.Minute, + session: &sessionsapi.SessionState{ + RefreshToken: notImplemented, + CreatedAt: &createdPast, + }, + expectedErr: nil, + expectRefreshed: true, + expectValidated: true, + }), + Entry("when the provider refresh fails but validation succeeds", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, session: &sessionsapi.SessionState{ RefreshToken: "RefreshError", CreatedAt: &createdPast, }, - expectedErr: errors.New("error refreshing access token: error refreshing session"), + expectedErr: nil, expectRefreshed: true, - expectValidated: false, + expectValidated: true, }), Entry("when the session is not refreshed by the provider and validation fails", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, @@ -364,12 +403,11 @@ var _ = Describe("Stored Session Suite", func() { ) }) - Context("refreshSessionWithProvider", func() { + Context("refreshSession", func() { type refreshSessionWithProviderTableInput struct { - session *sessionsapi.SessionState - expectedErr error - expectRefreshed bool - expectSaved bool + session *sessionsapi.SessionState + expectedErr error + expectSaved bool } now := time.Now() @@ -388,12 +426,14 @@ var _ = Describe("Stored Session Suite", func() { return nil }, }, - refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { + sessionRefresher: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { switch ss.RefreshToken { case refresh: return true, nil case noRefresh: return false, nil + case notImplemented: + return false, providers.ErrNotImplemented default: return false, errors.New("error refreshing session") } @@ -402,30 +442,34 @@ var _ = Describe("Stored Session Suite", func() { req := httptest.NewRequest("", "/", nil) req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{}) - refreshed, err := s.refreshSessionWithProvider(nil, req, in.session) + err := s.refreshSession(nil, req, in.session) if in.expectedErr != nil { Expect(err).To(MatchError(in.expectedErr)) } else { Expect(err).ToNot(HaveOccurred()) } - Expect(refreshed).To(Equal(in.expectRefreshed)) Expect(saved).To(Equal(in.expectSaved)) }, Entry("when the provider does not refresh the session", refreshSessionWithProviderTableInput{ session: &sessionsapi.SessionState{ RefreshToken: noRefresh, }, - expectedErr: nil, - expectRefreshed: false, - expectSaved: false, + expectedErr: nil, + expectSaved: false, }), Entry("when the provider refreshes the session", refreshSessionWithProviderTableInput{ session: &sessionsapi.SessionState{ RefreshToken: refresh, }, - expectedErr: nil, - expectRefreshed: true, - expectSaved: true, + expectedErr: nil, + expectSaved: true, + }), + Entry("when the provider doesn't implement refresh", refreshSessionWithProviderTableInput{ + session: &sessionsapi.SessionState{ + RefreshToken: notImplemented, + }, + expectedErr: nil, + expectSaved: true, }), Entry("when the provider returns an error", refreshSessionWithProviderTableInput{ session: &sessionsapi.SessionState{ @@ -433,18 +477,16 @@ var _ = Describe("Stored Session Suite", func() { CreatedAt: &now, ExpiresOn: &now, }, - expectedErr: errors.New("error refreshing access token: error refreshing session"), - expectRefreshed: false, - expectSaved: false, + expectedErr: errors.New("error refreshing tokens: error refreshing session"), + expectSaved: false, }), Entry("when the saving the session returns an error", refreshSessionWithProviderTableInput{ session: &sessionsapi.SessionState{ RefreshToken: refresh, AccessToken: "NoSave", }, - expectedErr: errors.New("error saving session: unable to save session"), - expectRefreshed: false, - expectSaved: true, + expectedErr: errors.New("error saving session: unable to save session"), + expectSaved: true, }), ) }) @@ -454,7 +496,7 @@ var _ = Describe("Stored Session Suite", func() { BeforeEach(func() { s = &storedSessionLoader{ - validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool { + sessionValidator: func(_ context.Context, ss *sessionsapi.SessionState) bool { return ss.AccessToken == "Valid" }, } diff --git a/pkg/sessions/cookie/session_store.go b/pkg/sessions/cookie/session_store.go index ce51ed07..1b3c12de 100644 --- a/pkg/sessions/cookie/session_store.go +++ b/pkg/sessions/cookie/session_store.go @@ -36,8 +36,7 @@ type SessionStore struct { // within Cookies set on the HTTP response writer func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error { if ss.CreatedAt == nil || ss.CreatedAt.IsZero() { - now := time.Now() - ss.CreatedAt = &now + ss.CreatedAtNow() } value, err := s.cookieForSession(ss) if err != nil { diff --git a/pkg/sessions/persistence/manager.go b/pkg/sessions/persistence/manager.go index 49225171..3215b257 100644 --- a/pkg/sessions/persistence/manager.go +++ b/pkg/sessions/persistence/manager.go @@ -30,8 +30,7 @@ func NewManager(store Store, cookieOpts *options.Cookie) *Manager { // from the persistent data store. func (m *Manager) Save(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error { if s.CreatedAt == nil || s.CreatedAt.IsZero() { - now := time.Now() - s.CreatedAt = &now + s.CreatedAtNow() } tckt, err := decodeTicketFromRequest(req, m.Options) diff --git a/providers/azure.go b/providers/azure.go index f66d3764..39beb836 100644 --- a/providers/azure.go +++ b/providers/azure.go @@ -142,16 +142,13 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (* return nil, err } - created := time.Now() - expires := time.Unix(jsonResponse.ExpiresOn, 0) - session := &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, IDToken: jsonResponse.IDToken, - CreatedAt: &created, - ExpiresOn: &expires, RefreshToken: jsonResponse.RefreshToken, } + session.CreatedAtNow() + session.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0)) email, err := p.verifyTokenAndExtractEmail(ctx, session.IDToken) @@ -239,28 +236,29 @@ func (p *AzureProvider) verifyTokenAndExtractEmail(ctx context.Context, token st return email, nil } -// RefreshSessionIfNeeded checks if the session has expired and uses the -// RefreshToken to fetch a new ID token if required -func (p *AzureProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { - if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { +// RefreshSession uses the RefreshToken to fetch new Access and ID Tokens +func (p *AzureProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { + if s == nil || s.RefreshToken == "" { return false, nil } - origExpiration := s.ExpiresOn - err := p.redeemRefreshToken(ctx, s) if err != nil { return false, fmt.Errorf("unable to redeem refresh token: %v", err) } - logger.Printf("refreshed id token %s (expired on %s)\n", s, origExpiration) return true, nil } -func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) { +func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error { + clientSecret, err := p.GetClientSecret() + if err != nil { + return err + } + params := url.Values{} params.Add("client_id", p.ClientID) - params.Add("client_secret", p.ClientSecret) + params.Add("client_secret", clientSecret) params.Add("refresh_token", s.RefreshToken) params.Add("grant_type", "refresh_token") @@ -278,18 +276,16 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess SetHeader("Content-Type", "application/x-www-form-urlencoded"). Do(). UnmarshalInto(&jsonResponse) - if err != nil { - return + return err } - now := time.Now() - expires := time.Unix(jsonResponse.ExpiresOn, 0) s.AccessToken = jsonResponse.AccessToken s.IDToken = jsonResponse.IDToken s.RefreshToken = jsonResponse.RefreshToken - s.CreatedAt = &now - s.ExpiresOn = &expires + + s.CreatedAtNow() + s.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0)) email, err := p.verifyTokenAndExtractEmail(ctx, s.IDToken) @@ -312,7 +308,7 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess } } - return + return nil } func makeAzureHeader(accessToken string) http.Header { diff --git a/providers/azure_test.go b/providers/azure_test.go index bb44f20a..78592df7 100644 --- a/providers/azure_test.go +++ b/providers/azure_test.go @@ -340,17 +340,7 @@ func TestAzureProviderProtectedResourceConfigured(t *testing.T) { assert.Contains(t, result, "resource="+url.QueryEscape("http://my.resource.test")) } -func TestAzureProviderNotRefreshWhenNotExpired(t *testing.T) { - p := testAzureProvider("") - - expires := time.Now().Add(time.Duration(1) * time.Hour) - session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires} - refreshNeeded, err := p.RefreshSessionIfNeeded(context.Background(), session) - assert.Equal(t, nil, err) - assert.False(t, refreshNeeded) -} - -func TestAzureProviderRefreshWhenExpired(t *testing.T) { +func TestAzureProviderRefresh(t *testing.T) { email := "foo@example.com" idToken := idTokenClaims{Email: email} idTokenString, err := newSignedTestIDToken(idToken) @@ -373,9 +363,10 @@ func TestAzureProviderRefreshWhenExpired(t *testing.T) { expires := time.Now().Add(time.Duration(-1) * time.Hour) session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires} - refreshNeeded, err := p.RefreshSessionIfNeeded(context.Background(), session) + + refreshed, err := p.RefreshSession(context.Background(), session) assert.Equal(t, nil, err) - assert.True(t, refreshNeeded) + assert.True(t, refreshed) assert.NotEqual(t, session, nil) assert.Equal(t, "new_some_access_token", session.AccessToken) assert.Equal(t, "new_some_refresh_token", session.RefreshToken) diff --git a/providers/facebook.go b/providers/facebook.go index e3babc0d..6db9c38d 100644 --- a/providers/facebook.go +++ b/providers/facebook.go @@ -88,7 +88,7 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess return r.Email, nil } -// ValidateSessionState validates the AccessToken +// ValidateSession validates the AccessToken func (p *FacebookProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { return validateToken(ctx, p, s.AccessToken, makeOIDCHeader(s.AccessToken)) } diff --git a/providers/gitlab.go b/providers/gitlab.go index 18f77fe7..a2b11df7 100644 --- a/providers/gitlab.go +++ b/providers/gitlab.go @@ -121,10 +121,9 @@ func (p *GitLabProvider) SetProjectScope() { } } -// RefreshSessionIfNeeded checks if the session has expired and uses the -// RefreshToken to fetch a new ID token if required -func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { - if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" { +// RefreshSession uses the RefreshToken to fetch new Access and ID Tokens +func (p *GitLabProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { + if s == nil || s.RefreshToken == "" { return false, nil } @@ -139,10 +138,10 @@ func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions return true, nil } -func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) { +func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error { clientSecret, err := p.GetClientSecret() if err != nil { - return + return err } c := oauth2.Config{ @@ -164,13 +163,9 @@ func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.Ses if err != nil { return fmt.Errorf("unable to update session: %v", err) } - s.AccessToken = newSession.AccessToken - s.IDToken = newSession.IDToken - s.RefreshToken = newSession.RefreshToken - s.CreatedAt = newSession.CreatedAt - s.ExpiresOn = newSession.ExpiresOn - s.Email = newSession.Email - return + *s = *newSession + + return nil } type gitlabUserInfo struct { @@ -264,14 +259,16 @@ func (p *GitLabProvider) createSession(ctx context.Context, token *oauth2.Token) } } - created := time.Now() - return &sessions.SessionState{ + ss := &sessions.SessionState{ AccessToken: token.AccessToken, IDToken: getIDToken(token), RefreshToken: token.RefreshToken, - CreatedAt: &created, - ExpiresOn: &idToken.Expiry, - }, nil + } + + ss.CreatedAtNow() + ss.SetExpiresOn(idToken.Expiry) + + return ss, nil } // ValidateSession checks that the session's IDToken is still valid diff --git a/providers/google.go b/providers/google.go index b669156d..a467c50f 100644 --- a/providers/google.go +++ b/providers/google.go @@ -163,23 +163,22 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) ( return nil, err } - created := time.Now() - expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second) - - return &sessions.SessionState{ + ss := &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, IDToken: jsonResponse.IDToken, - CreatedAt: &created, - ExpiresOn: &expires, RefreshToken: jsonResponse.RefreshToken, Email: c.Email, User: c.Subject, - }, nil + } + ss.CreatedAtNow() + ss.ExpiresIn(time.Duration(jsonResponse.ExpiresIn) * time.Second) + + return ss, nil } // EnrichSession checks the listed Google Groups configured and adds any // that the user is a member of to session.Groups. -func (p *GoogleProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error { +func (p *GoogleProvider) EnrichSession(_ context.Context, s *sessions.SessionState) error { // TODO (@NickMeves) - Move to pure EnrichSession logic and stop // reusing legacy `groupValidator`. // @@ -266,14 +265,13 @@ func userInGroup(service *admin.Service, group string, email string) bool { return false } -// RefreshSessionIfNeeded checks if the session has expired and uses the -// RefreshToken to fetch a new ID token if required -func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { - if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" { +// RefreshSession uses the RefreshToken to fetch new Access and ID Tokens +func (p *GoogleProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { + if s == nil || s.RefreshToken == "" { return false, nil } - newToken, newIDToken, duration, err := p.redeemRefreshToken(ctx, s.RefreshToken) + err := p.redeemRefreshToken(ctx, s) if err != nil { return false, err } @@ -286,26 +284,20 @@ func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions return false, fmt.Errorf("%s is no longer in the group(s)", s.Email) } - origExpiration := s.ExpiresOn - expires := time.Now().Add(duration).Truncate(time.Second) - s.AccessToken = newToken - s.IDToken = newIDToken - s.ExpiresOn = &expires - logger.Printf("refreshed access token %s (expired on %s)", s, origExpiration) return true, nil } -func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken string) (token string, idToken string, expires time.Duration, err error) { +func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error { // https://developers.google.com/identity/protocols/OAuth2WebServer#refresh clientSecret, err := p.GetClientSecret() if err != nil { - return + return err } params := url.Values{} params.Add("client_id", p.ClientID) params.Add("client_secret", clientSecret) - params.Add("refresh_token", refreshToken) + params.Add("refresh_token", s.RefreshToken) params.Add("grant_type", "refresh_token") var data struct { @@ -322,11 +314,14 @@ func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken st Do(). UnmarshalInto(&data) if err != nil { - return "", "", 0, err + return err } - token = data.AccessToken - idToken = data.IDToken - expires = time.Duration(data.ExpiresIn) * time.Second - return + s.AccessToken = data.AccessToken + s.IDToken = data.IDToken + + s.CreatedAtNow() + s.ExpiresIn(time.Duration(data.ExpiresIn) * time.Second) + + return nil } diff --git a/providers/linkedin.go b/providers/linkedin.go index 58217952..115d4c99 100644 --- a/providers/linkedin.go +++ b/providers/linkedin.go @@ -93,7 +93,7 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess return email, nil } -// ValidateSessionState validates the AccessToken +// ValidateSession validates the AccessToken func (p *LinkedInProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { return validateToken(ctx, p, s.AccessToken, makeLinkedInHeader(s.AccessToken)) } diff --git a/providers/logingov.go b/providers/logingov.go index 0f625208..43f361f3 100644 --- a/providers/logingov.go +++ b/providers/logingov.go @@ -159,7 +159,7 @@ func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint } // Redeem exchanges the OAuth2 authentication token for an ID token -func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) { +func (p *LoginGovProvider) Redeem(ctx context.Context, _, code string) (*sessions.SessionState, error) { if code == "" { return nil, ErrMissingCode } @@ -214,17 +214,16 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) return nil, err } - created := time.Now() - expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second) - - // Store the data that we found in the session state - return &sessions.SessionState{ + session := &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, IDToken: jsonResponse.IDToken, - CreatedAt: &created, - ExpiresOn: &expires, Email: email, - }, nil + } + + session.CreatedAtNow() + session.ExpiresIn(time.Duration(jsonResponse.ExpiresIn) * time.Second) + + return session, nil } // GetLoginURL overrides GetLoginURL to add login.gov parameters diff --git a/providers/oidc.go b/providers/oidc.go index 9e7bf56f..b1711d54 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -143,10 +143,9 @@ func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionS return true } -// RefreshSessionIfNeeded checks if the session has expired and uses the -// RefreshToken to fetch a new Access Token (and optional ID token) if required -func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { - if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" { +// RefreshSession uses the RefreshToken to fetch new Access and ID Tokens +func (p *OIDCProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { + if s == nil || s.RefreshToken == "" { return false, nil } @@ -155,7 +154,6 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.S return false, fmt.Errorf("unable to redeem refresh token: %v", err) } - logger.Printf("refreshed session: %s", s) return true, nil } @@ -227,7 +225,9 @@ func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string) ss.AccessToken = token ss.IDToken = token ss.RefreshToken = "" - ss.ExpiresOn = &idToken.Expiry + + ss.CreatedAtNow() + ss.SetExpiresOn(idToken.Expiry) return ss, nil } @@ -257,9 +257,8 @@ func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, r ss.RefreshToken = token.RefreshToken ss.IDToken = getIDToken(token) - created := time.Now() - ss.CreatedAt = &created - ss.ExpiresOn = &token.Expiry + ss.CreatedAtNow() + ss.SetExpiresOn(token.Expiry) return ss, nil } diff --git a/providers/oidc_test.go b/providers/oidc_test.go index 7fae3368..9879678d 100644 --- a/providers/oidc_test.go +++ b/providers/oidc_test.go @@ -487,7 +487,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) { User: "11223344", } - refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession) + refreshed, err := provider.RefreshSession(context.Background(), existingSession) assert.Equal(t, nil, err) assert.Equal(t, refreshed, true) assert.Equal(t, "janedoe@example.com", existingSession.Email) @@ -520,7 +520,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) { Email: "changeit", User: "changeit", } - refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession) + refreshed, err := provider.RefreshSession(context.Background(), existingSession) assert.Equal(t, nil, err) assert.Equal(t, refreshed, true) assert.Equal(t, defaultIDToken.Email, existingSession.Email) diff --git a/providers/provider_default.go b/providers/provider_default.go index 1bde54b7..7a641b1e 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "net/url" - "time" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" @@ -85,9 +84,13 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (*s if err != nil { return nil, err } + // TODO (@NickMeves): Uses OAuth `expires_in` to set an expiration if token := values.Get("access_token"); token != "" { - created := time.Now() - return &sessions.SessionState{AccessToken: token, CreatedAt: &created}, nil + ss := &sessions.SessionState{ + AccessToken: token, + } + ss.CreatedAtNow() + return ss, nil } return nil, fmt.Errorf("no access token found %s", result.Body()) @@ -126,10 +129,9 @@ func (p *ProviderData) ValidateSession(ctx context.Context, s *sessions.SessionS return validateToken(ctx, p, s.AccessToken, nil) } -// RefreshSessionIfNeeded should refresh the user's session if required and -// do nothing if a refresh is not required -func (p *ProviderData) RefreshSessionIfNeeded(_ context.Context, _ *sessions.SessionState) (bool, error) { - return false, nil +// RefreshSession refreshes the user's session +func (p *ProviderData) RefreshSession(_ context.Context, _ *sessions.SessionState) (bool, error) { + return false, ErrNotImplemented } // CreateSessionFromToken converts Bearer IDTokens into sessions diff --git a/providers/provider_default_test.go b/providers/provider_default_test.go index 0bd2f4f0..5d4ed1af 100644 --- a/providers/provider_default_test.go +++ b/providers/provider_default_test.go @@ -14,12 +14,20 @@ import ( func TestRefresh(t *testing.T) { p := &ProviderData{} - expires := time.Now().Add(time.Duration(-11) * time.Minute) - refreshed, err := p.RefreshSessionIfNeeded(context.Background(), &sessions.SessionState{ - ExpiresOn: &expires, - }) - assert.Equal(t, false, refreshed) - assert.Equal(t, nil, err) + now := time.Unix(1234567890, 10) + expires := time.Unix(1234567890, 0) + + ss := &sessions.SessionState{} + ss.Clock.Set(now) + ss.SetExpiresOn(expires) + + refreshed, err := p.RefreshSession(context.Background(), ss) + assert.False(t, refreshed) + assert.Equal(t, ErrNotImplemented, err) + + refreshed, err = p.RefreshSession(context.Background(), nil) + assert.False(t, refreshed) + assert.Equal(t, ErrNotImplemented, err) } func TestAcrValuesNotConfigured(t *testing.T) { diff --git a/providers/providers.go b/providers/providers.go index 0340c420..d21409c2 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -9,14 +9,14 @@ import ( // Provider represents an upstream identity provider implementation type Provider interface { Data() *ProviderData + GetLoginURL(redirectURI, finalRedirect string, nonce string) string + Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error) // Deprecated: Migrate to EnrichSession GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) - GetLoginURL(redirectURI, state, nonce string) string - Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error) EnrichSession(ctx context.Context, s *sessions.SessionState) error Authorize(ctx context.Context, s *sessions.SessionState) (bool, error) ValidateSession(ctx context.Context, s *sessions.SessionState) bool - RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) + RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) }