diff --git a/oauthproxy.go b/oauthproxy.go index c3a5693d..e2d20ed6 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -870,9 +870,9 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { // See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en var noCacheHeaders = map[string]string{ - "Expires": time.Unix(0, 0).Format(time.RFC1123), - "Cache-Control": "no-cache, no-store, must-revalidate, max-age=0", - "X-Accel-Expire": "0", // https://www.nginx.com/resources/wiki/start/topics/examples/x-accel/ + "Expires": time.Unix(0, 0).Format(time.RFC1123), + "Cache-Control": "no-cache, no-store, must-revalidate, max-age=0", + "X-Accel-Expires": "0", // https://www.nginx.com/resources/wiki/start/topics/examples/x-accel/ } // prepareNoCache prepares headers for preventing browser caching. diff --git a/pkg/apis/sessions/session_state_test.go b/pkg/apis/sessions/session_state_test.go index 0121b4b6..e6b9ff39 100644 --- a/pkg/apis/sessions/session_state_test.go +++ b/pkg/apis/sessions/session_state_test.go @@ -16,6 +16,30 @@ func timePtr(t time.Time) *time.Time { return &t } +func TestCreatedAtNow(t *testing.T) { + g := NewWithT(t) + ss := &SessionState{} + + now := time.Unix(1234567890, 0) + ss.Clock.Set(now) + + ss.CreatedAtNow() + g.Expect(*ss.CreatedAt).To(Equal(now)) +} + +func TestExpiresIn(t *testing.T) { + g := NewWithT(t) + ss := &SessionState{} + + now := time.Unix(1234567890, 0) + ss.Clock.Set(now) + + ttl := time.Duration(743) * time.Second + ss.ExpiresIn(ttl) + + g.Expect(*ss.ExpiresOn).To(Equal(ss.CreatedAt.Add(ttl))) +} + func TestString(t *testing.T) { g := NewWithT(t) created, err := time.Parse(time.RFC3339, "2000-01-01T00:00:00Z") diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go index 9f69ba64..85974867 100644 --- a/pkg/middleware/stored_session.go +++ b/pkg/middleware/stored_session.go @@ -108,11 +108,7 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h // refreshSessionIfNeeded will attempt to refresh a session if the session // is older than the refresh period. -// It is assumed that if the provider refreshes the session, the session is now -// valid. -// If the session requires refreshing but the provider does not refresh it, -// we must validate the session to ensure that the returned session is still -// valid. +// Success or fail, we will then validate the session. func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { if s.refreshPeriod <= time.Duration(0) || session.Age() < s.refreshPeriod { // Refresh is disabled or the session is not old enough, do nothing @@ -122,10 +118,12 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod) err := s.refreshSession(rw, req, session) if err != nil { - return err + // If a preemptive refresh fails, we still keep the session + // if validateSession succeeds. + logger.Errorf("Unable to refresh session: %v", err) } - // Validate all sessions after any Redeem/Refresh operation + // Validate all sessions after any Redeem/Refresh operation (fail or success) return s.validateSession(req.Context(), session) } @@ -134,7 +132,7 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { refreshed, err := s.sessionRefresher(req.Context(), session) if err != nil { - return fmt.Errorf("error refreshing access token: %v", err) + return fmt.Errorf("error refreshing tokens: %v", err) } if !refreshed { @@ -142,6 +140,12 @@ func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.R } // If we refreshed, update the `CreatedAt` time to reset the refresh timer + // + // HACK: + // Providers that don't implement `RefreshSession` use the default + // implementation. It always returns `refreshed == true`, so the + // `session.CreatedAt` is updated and doesn't trigger `ValidateSession` + // every subsequent request. session.CreatedAtNow() // Because the session was refreshed, make sure to save it diff --git a/pkg/middleware/stored_session_test.go b/pkg/middleware/stored_session_test.go index 2ec134c9..9c9a4b92 100644 --- a/pkg/middleware/stored_session_test.go +++ b/pkg/middleware/stored_session_test.go @@ -10,6 +10,7 @@ import ( middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock" . "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo/extensions/table" . "github.com/onsi/gomega" @@ -24,8 +25,9 @@ var _ = Describe("Stored Session Suite", func() { var ctx = context.Background() Context("StoredSessionLoader", func() { - createdPast := time.Now().Add(-5 * time.Minute) - createdFuture := time.Now().Add(5 * time.Minute) + now := time.Now() + createdPast := now.Add(-5 * time.Minute) + createdFuture := now.Add(5 * time.Minute) var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { switch ss.RefreshToken { @@ -85,6 +87,14 @@ var _ = Describe("Stored Session Suite", func() { }, } + BeforeEach(func() { + clock.Set(now) + }) + + AfterEach(func() { + clock.Reset() + }) + type storedSessionLoaderTableInput struct { requestHeaders http.Header existingSession *sessionsapi.SessionState @@ -208,6 +218,21 @@ var _ = Describe("Stored Session Suite", func() { existingSession: nil, expectedSession: &sessionsapi.SessionState{ RefreshToken: "Refreshed", + CreatedAt: &now, + ExpiresOn: &createdFuture, + }, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("when the provider refresh fails but validation succeeds", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=RefreshError"}, + }, + existingSession: nil, + expectedSession: &sessionsapi.SessionState{ + RefreshToken: "RefreshError", CreatedAt: &createdPast, ExpiresOn: &createdFuture, }, @@ -216,7 +241,7 @@ var _ = Describe("Stored Session Suite", func() { refreshSession: defaultRefreshFunc, validateSession: defaultValidateFunc, }), - Entry("when the provider refresh fails", storedSessionLoaderTableInput{ + Entry("when the provider refresh fails and validation fails", storedSessionLoaderTableInput{ requestHeaders: http.Header{ "Cookie": []string{"_oauth2_proxy=RefreshError"}, }, @@ -225,7 +250,7 @@ var _ = Describe("Stored Session Suite", func() { store: defaultSessionStore, refreshPeriod: 1 * time.Minute, refreshSession: defaultRefreshFunc, - validateSession: defaultValidateFunc, + validateSession: func(context.Context, *sessionsapi.SessionState) bool { return false }, }), Entry("when the session is not refreshed and is no longer valid", storedSessionLoaderTableInput{ requestHeaders: http.Header{ @@ -326,7 +351,7 @@ var _ = Describe("Stored Session Suite", func() { }, expectedErr: nil, expectRefreshed: true, - expectValidated: false, + expectValidated: true, }), Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, @@ -339,15 +364,15 @@ var _ = Describe("Stored Session Suite", func() { expectRefreshed: true, expectValidated: true, }), - Entry("when the provider refresh fails", refreshSessionIfNeededTableInput{ + Entry("when the provider refresh fails but validation succeeds", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, session: &sessionsapi.SessionState{ RefreshToken: "RefreshError", CreatedAt: &createdPast, }, - expectedErr: errors.New("error refreshing access token: error refreshing session"), + expectedErr: nil, expectRefreshed: true, - expectValidated: false, + expectValidated: true, }), Entry("when the session is not refreshed by the provider and validation fails", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, @@ -366,10 +391,9 @@ var _ = Describe("Stored Session Suite", func() { Context("refreshSession", func() { type refreshSessionWithProviderTableInput struct { - session *sessionsapi.SessionState - expectedErr error - expectRefreshed bool - expectSaved bool + session *sessionsapi.SessionState + expectedErr error + expectSaved bool } now := time.Now() @@ -414,15 +438,15 @@ var _ = Describe("Stored Session Suite", func() { session: &sessionsapi.SessionState{ RefreshToken: noRefresh, }, - expectedErr: nil, - expectSaved: false, + expectedErr: nil, + expectSaved: false, }), Entry("when the provider refreshes the session", refreshSessionWithProviderTableInput{ session: &sessionsapi.SessionState{ RefreshToken: refresh, }, - expectedErr: nil, - expectSaved: true, + expectedErr: nil, + expectSaved: true, }), Entry("when the provider returns an error", refreshSessionWithProviderTableInput{ session: &sessionsapi.SessionState{ @@ -430,16 +454,16 @@ var _ = Describe("Stored Session Suite", func() { CreatedAt: &now, ExpiresOn: &now, }, - expectedErr: errors.New("error refreshing access token: error refreshing session"), - expectSaved: false, + expectedErr: errors.New("error refreshing tokens: error refreshing session"), + expectSaved: false, }), Entry("when the saving the session returns an error", refreshSessionWithProviderTableInput{ session: &sessionsapi.SessionState{ RefreshToken: refresh, AccessToken: "NoSave", }, - expectedErr: errors.New("error saving session: unable to save session"), - expectSaved: true, + expectedErr: errors.New("error saving session: unable to save session"), + expectSaved: true, }), ) }) diff --git a/providers/azure.go b/providers/azure.go index 46d7e302..39beb836 100644 --- a/providers/azure.go +++ b/providers/azure.go @@ -242,21 +242,23 @@ func (p *AzureProvider) RefreshSession(ctx context.Context, s *sessions.SessionS return false, nil } - origExpiration := s.ExpiresOn - err := p.redeemRefreshToken(ctx, s) if err != nil { return false, fmt.Errorf("unable to redeem refresh token: %v", err) } - logger.Printf("refreshed id token %s (expired on %s)\n", s, origExpiration) return true, nil } func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error { + clientSecret, err := p.GetClientSecret() + if err != nil { + return err + } + params := url.Values{} params.Add("client_id", p.ClientID) - params.Add("client_secret", p.ClientSecret) + params.Add("client_secret", clientSecret) params.Add("refresh_token", s.RefreshToken) params.Add("grant_type", "refresh_token") @@ -267,7 +269,7 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess IDToken string `json:"id_token"` } - err := requests.New(p.RedeemURL.String()). + err = requests.New(p.RedeemURL.String()). WithContext(ctx). WithMethod("POST"). WithBody(bytes.NewBufferString(params.Encode())). diff --git a/providers/azure_test.go b/providers/azure_test.go index 9d539c63..78592df7 100644 --- a/providers/azure_test.go +++ b/providers/azure_test.go @@ -340,17 +340,7 @@ func TestAzureProviderProtectedResourceConfigured(t *testing.T) { assert.Contains(t, result, "resource="+url.QueryEscape("http://my.resource.test")) } -func TestAzureProviderNotRefreshWhenNotExpired(t *testing.T) { - p := testAzureProvider("") - - expires := time.Now().Add(time.Duration(1) * time.Hour) - session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires} - refreshNeeded, err := p.RefreshSession(context.Background(), session) - assert.Equal(t, nil, err) - assert.False(t, refreshNeeded) -} - -func TestAzureProviderRefreshWhenExpired(t *testing.T) { +func TestAzureProviderRefresh(t *testing.T) { email := "foo@example.com" idToken := idTokenClaims{Email: email} idTokenString, err := newSignedTestIDToken(idToken) diff --git a/providers/google.go b/providers/google.go index 0cfd3e1c..a467c50f 100644 --- a/providers/google.go +++ b/providers/google.go @@ -271,7 +271,7 @@ func (p *GoogleProvider) RefreshSession(ctx context.Context, s *sessions.Session return false, nil } - newToken, newIDToken, ttl, err := p.redeemRefreshToken(ctx, s.RefreshToken) + err := p.redeemRefreshToken(ctx, s) if err != nil { return false, err } @@ -284,26 +284,20 @@ func (p *GoogleProvider) RefreshSession(ctx context.Context, s *sessions.Session return false, fmt.Errorf("%s is no longer in the group(s)", s.Email) } - s.AccessToken = newToken - s.IDToken = newIDToken - - s.CreatedAtNow() - s.ExpiresIn(ttl) - return true, nil } -func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken string) (token string, idToken string, expires time.Duration, err error) { +func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error { // https://developers.google.com/identity/protocols/OAuth2WebServer#refresh clientSecret, err := p.GetClientSecret() if err != nil { - return + return err } params := url.Values{} params.Add("client_id", p.ClientID) params.Add("client_secret", clientSecret) - params.Add("refresh_token", refreshToken) + params.Add("refresh_token", s.RefreshToken) params.Add("grant_type", "refresh_token") var data struct { @@ -320,11 +314,14 @@ func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken st Do(). UnmarshalInto(&data) if err != nil { - return "", "", 0, err + return err } - token = data.AccessToken - idToken = data.IDToken - expires = time.Duration(data.ExpiresIn) * time.Second - return + s.AccessToken = data.AccessToken + s.IDToken = data.IDToken + + s.CreatedAtNow() + s.ExpiresIn(time.Duration(data.ExpiresIn) * time.Second) + + return nil } diff --git a/providers/oidc.go b/providers/oidc.go index 2cbbd009..b1711d54 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -154,7 +154,6 @@ func (p *OIDCProvider) RefreshSession(ctx context.Context, s *sessions.SessionSt return false, fmt.Errorf("unable to redeem refresh token: %v", err) } - logger.Printf("refreshed session: %s", s) return true, nil } diff --git a/providers/provider_default.go b/providers/provider_default.go index 0a62c240..d364501b 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -135,8 +135,13 @@ func (p *ProviderData) RefreshSession(_ context.Context, s *sessions.SessionStat return false, nil } - // Pretend `RefreshSession` occured so `ValidateSession` isn't called + // HACK: + // Pretend `RefreshSession` occurred so `ValidateSession` isn't called // on every request after any potential set refresh period elapses. + // See `middleware.refreshSession` for detailed logic & explanation. + // + // Intentionally doesn't use `ErrNotImplemented` since all providers will + // call this and we don't want to force them to implement this dummy logic. return true, nil } diff --git a/providers/provider_default_test.go b/providers/provider_default_test.go index 8474baae..2ba72f25 100644 --- a/providers/provider_default_test.go +++ b/providers/provider_default_test.go @@ -14,12 +14,20 @@ import ( func TestRefresh(t *testing.T) { p := &ProviderData{} - expires := time.Now().Add(time.Duration(-11) * time.Minute) - refreshed, err := p.RefreshSession(context.Background(), &sessions.SessionState{ - ExpiresOn: &expires, - }) - assert.Equal(t, false, refreshed) - assert.Equal(t, nil, err) + now := time.Unix(1234567890, 10) + expires := time.Unix(1234567890, 0) + + ss := &sessions.SessionState{} + ss.Clock.Set(now) + ss.SetExpiresOn(expires) + + refreshed, err := p.RefreshSession(context.Background(), ss) + assert.True(t, refreshed) + assert.NoError(t, err) + + refreshed, err = p.RefreshSession(context.Background(), nil) + assert.False(t, refreshed) + assert.NoError(t, err) } func TestAcrValuesNotConfigured(t *testing.T) {