diff --git a/pkg/sessions/redis/redis_store.go b/pkg/sessions/redis/redis_store.go index b32d5c2f..6a5ffbbd 100644 --- a/pkg/sessions/redis/redis_store.go +++ b/pkg/sessions/redis/redis_store.go @@ -82,7 +82,7 @@ func (store *SessionStore) Save(rw http.ResponseWriter, req *http.Request, s *se if err != nil { return err } - ticketString, err := store.storeValue(value, s.ExpiresOn, requestCookie) + ticketString, err := store.storeValue(value, store.CookieOptions.CookieExpire, requestCookie) if err != nil { return err } @@ -191,7 +191,7 @@ func (store *SessionStore) makeCookie(req *http.Request, value string, expires t ) } -func (store *SessionStore) storeValue(value string, expiresOn time.Time, requestCookie *http.Cookie) (string, error) { +func (store *SessionStore) storeValue(value string, expiration time.Duration, requestCookie *http.Cookie) (string, error) { var ticket *TicketData if requestCookie != nil { var err error @@ -225,7 +225,6 @@ func (store *SessionStore) storeValue(value string, expiresOn time.Time, request stream.XORKeyStream(ciphertext, []byte(value)) handle := ticket.asHandle(store.CookieOptions.CookieName) - expiration := expiresOn.Sub(time.Now()) err = store.Client.Set(handle, ciphertext, expiration).Err() if err != nil { return "", err diff --git a/pkg/sessions/session_store_test.go b/pkg/sessions/session_store_test.go index 85e63b3a..2ffc0bdf 100644 --- a/pkg/sessions/session_store_test.go +++ b/pkg/sessions/session_store_test.go @@ -35,6 +35,7 @@ var _ = Describe("NewSessionStore", func() { var response *httptest.ResponseRecorder var session *sessionsapi.SessionState var ss sessionsapi.SessionStore + var mr *miniredis.Miniredis CheckCookieOptions := func() { Context("the cookies returned", func() { @@ -203,7 +204,38 @@ var _ = Describe("NewSessionStore", func() { }) Context("when Load is called", func() { - var loadedSession *sessionsapi.SessionState + LoadSessionTests := func() { + var loadedSession *sessionsapi.SessionState + BeforeEach(func() { + var err error + loadedSession, err = ss.Load(request) + Expect(err).ToNot(HaveOccurred()) + }) + + It("loads a session equal to the original session", func() { + if cookieOpts.CookieSecret == "" { + // Only Email and User stored in session when encrypted + Expect(loadedSession.Email).To(Equal(session.Email)) + Expect(loadedSession.User).To(Equal(session.User)) + } else { + // All fields stored in session if encrypted + + // Can't compare time.Time using Equal() so remove ExpiresOn from sessions + l := *loadedSession + l.CreatedAt = time.Time{} + l.ExpiresOn = time.Time{} + s := *session + s.CreatedAt = time.Time{} + s.ExpiresOn = time.Time{} + Expect(l).To(Equal(s)) + + // Compare time.Time separately + Expect(loadedSession.CreatedAt.Equal(session.CreatedAt)).To(BeTrue()) + Expect(loadedSession.ExpiresOn.Equal(session.ExpiresOn)).To(BeTrue()) + } + }) + } + BeforeEach(func() { req := httptest.NewRequest("GET", "http://example.com/", nil) resp := httptest.NewRecorder() @@ -213,32 +245,49 @@ var _ = Describe("NewSessionStore", func() { for _, cookie := range resp.Result().Cookies() { request.AddCookie(cookie) } - loadedSession, err = ss.Load(request) - Expect(err).ToNot(HaveOccurred()) }) - It("loads a session equal to the original session", func() { - if cookieOpts.CookieSecret == "" { - // Only Email and User stored in session when encrypted - Expect(loadedSession.Email).To(Equal(session.Email)) - Expect(loadedSession.User).To(Equal(session.User)) - } else { - // All fields stored in session if encrypted - - // Can't compare time.Time using Equal() so remove ExpiresOn from sessions - l := *loadedSession - l.CreatedAt = time.Time{} - l.ExpiresOn = time.Time{} - s := *session - s.CreatedAt = time.Time{} - s.ExpiresOn = time.Time{} - Expect(l).To(Equal(s)) - - // Compare time.Time separately - Expect(loadedSession.CreatedAt.Equal(session.CreatedAt)).To(BeTrue()) - Expect(loadedSession.ExpiresOn.Equal(session.ExpiresOn)).To(BeTrue()) - } + Context("before the refresh period", func() { + LoadSessionTests() }) + + // Test TTLs and cleanup of persistent session storage + // For non-persistent we rely on the browser cookie lifecycle + if persistent { + Context("after the refresh period, but before the cookie expire period", func() { + BeforeEach(func() { + switch ss.(type) { + case *redis.SessionStore: + mr.FastForward(cookieOpts.CookieRefresh + time.Minute) + } + }) + + LoadSessionTests() + }) + + Context("after the cookie expire period", func() { + var loadedSession *sessionsapi.SessionState + var err error + + BeforeEach(func() { + switch ss.(type) { + case *redis.SessionStore: + mr.FastForward(cookieOpts.CookieExpire + time.Minute) + } + + loadedSession, err = ss.Load(request) + Expect(err).To(HaveOccurred()) + }) + + It("returns an error loading the session", func() { + Expect(err).To(HaveOccurred()) + }) + + It("returns an empty session", func() { + Expect(loadedSession).To(BeNil()) + }) + }) + } }) if persistent { @@ -263,7 +312,7 @@ var _ = Describe("NewSessionStore", func() { CookieName: "_cookie_name", CookiePath: "/path", CookieExpire: time.Duration(72) * time.Hour, - CookieRefresh: time.Duration(3600), + CookieRefresh: time.Duration(2) * time.Hour, CookieSecure: false, CookieHTTPOnly: false, CookieDomain: "example.com", @@ -305,7 +354,7 @@ var _ = Describe("NewSessionStore", func() { CookieName: "_oauth2_proxy", CookiePath: "/", CookieExpire: time.Duration(168) * time.Hour, - CookieRefresh: time.Duration(0), + CookieRefresh: time.Duration(1) * time.Hour, CookieSecure: true, CookieHTTPOnly: true, } @@ -340,7 +389,6 @@ var _ = Describe("NewSessionStore", func() { }) Context("with type 'redis'", func() { - var mr *miniredis.Miniredis BeforeEach(func() { var err error mr, err = miniredis.Run()