diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go index 6c6c45ed..972a4c6f 100644 --- a/pkg/middleware/stored_session.go +++ b/pkg/middleware/stored_session.go @@ -103,13 +103,7 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h err = s.refreshSessionIfNeeded(rw, req, session) if err != nil { - logger.Errorf("error refreshing access token for session (%s): %v", session, err) - } - - // Validate all sessions after any Redeem/Refresh operation (fail or success) - err = s.validateSession(req.Context(), session) - if err != nil { - return nil, err + return nil, fmt.Errorf("error refreshing access token for session (%s): %v", session, err) } return session, nil @@ -133,11 +127,22 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req // it should be updated after lock is released. if wasLocked { logger.Printf("Update session from store instead of refreshing") - return s.updateSessionFromStore(req, session) + err = s.updateSessionFromStore(req, session) + if err != nil { + logger.Errorf("Unable to update session from store: %v", err) + } + } else { + logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age()) + err = s.refreshSession(rw, req, session) + if err != nil { + // If a preemptive refresh fails, we still keep the session + // if validateSession succeeds. + logger.Errorf("Unable to refresh session: %v", err) + } } - logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age()) - return s.refreshSession(rw, req, session) + // Validate all sessions after any Redeem/Refresh operation (fail or success) + return s.validateSession(req.Context(), session) } // refreshSession attempts to refresh the session with the provider diff --git a/pkg/middleware/stored_session_test.go b/pkg/middleware/stored_session_test.go index 6ce816ab..7c74116a 100644 --- a/pkg/middleware/stored_session_test.go +++ b/pkg/middleware/stored_session_test.go @@ -460,6 +460,7 @@ var _ = Describe("Stored Session Suite", func() { session *sessionsapi.SessionState expectedErr error expectRefreshed bool + expectValidated bool expectedLockState TestLock } @@ -469,6 +470,7 @@ var _ = Describe("Stored Session Suite", func() { DescribeTable("with a session", func(in refreshSessionIfNeededTableInput) { refreshed := false + validated := false store := &fakeSessionStore{} if in.sessionStored { @@ -496,6 +498,7 @@ var _ = Describe("Stored Session Suite", func() { } }, sessionValidator: func(_ context.Context, ss *sessionsapi.SessionState) bool { + validated = true return ss.AccessToken != "Invalid" }, } @@ -508,6 +511,7 @@ var _ = Describe("Stored Session Suite", func() { Expect(err).ToNot(HaveOccurred()) } Expect(refreshed).To(Equal(in.expectRefreshed)) + Expect(validated).To(Equal(in.expectValidated)) testLock, ok := in.session.Lock.(*TestLock) Expect(ok).To(Equal(true)) @@ -522,6 +526,7 @@ var _ = Describe("Stored Session Suite", func() { }, expectedErr: nil, expectRefreshed: false, + expectValidated: false, expectedLockState: TestLock{}, }), Entry("when the refresh period is 0, and the session needs refreshing", refreshSessionIfNeededTableInput{ @@ -533,6 +538,7 @@ var _ = Describe("Stored Session Suite", func() { }, expectedErr: nil, expectRefreshed: false, + expectValidated: false, expectedLockState: TestLock{}, }), Entry("when the session does not need refreshing", refreshSessionIfNeededTableInput{ @@ -544,6 +550,7 @@ var _ = Describe("Stored Session Suite", func() { }, expectedErr: nil, expectRefreshed: false, + expectValidated: false, expectedLockState: TestLock{}, }), Entry("when the session is refreshed by the provider", refreshSessionIfNeededTableInput{ @@ -555,6 +562,7 @@ var _ = Describe("Stored Session Suite", func() { }, expectedErr: nil, expectRefreshed: true, + expectValidated: true, expectedLockState: TestLock{ Locked: false, WasObtained: true, @@ -574,6 +582,7 @@ var _ = Describe("Stored Session Suite", func() { sessionStored: true, expectedErr: nil, expectRefreshed: false, + expectValidated: true, expectedLockState: TestLock{ Locked: false, PeekedCount: 2, @@ -590,6 +599,7 @@ var _ = Describe("Stored Session Suite", func() { }, expectedErr: nil, expectRefreshed: false, + expectValidated: true, expectedLockState: TestLock{ PeekedCount: 1, ObtainError: errors.New("not able to obtain lock"), @@ -605,6 +615,7 @@ var _ = Describe("Stored Session Suite", func() { }, expectedErr: nil, expectRefreshed: true, + expectValidated: true, expectedLockState: TestLock{ Locked: false, WasObtained: true, @@ -621,6 +632,7 @@ var _ = Describe("Stored Session Suite", func() { }, expectedErr: nil, expectRefreshed: true, + expectValidated: true, expectedLockState: TestLock{ Locked: false, WasObtained: true, @@ -637,8 +649,9 @@ var _ = Describe("Stored Session Suite", func() { ExpiresOn: &createdFuture, Lock: &TestLock{}, }, - expectedErr: nil, + expectedErr: errors.New("session is invalid"), expectRefreshed: true, + expectValidated: true, expectedLockState: TestLock{ Locked: false, WasObtained: true,