1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-08-06 22:42:56 +02:00

Move SessionStore tests to independent package

This commit is contained in:
Joel Speed
2020-05-10 16:59:17 +01:00
parent d9a45a3b47
commit 34137f7305
9 changed files with 576 additions and 490 deletions

View File

@ -5,9 +5,30 @@ import (
"testing"
"time"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/tests"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/stretchr/testify/assert"
)
func TestSessionStore(t *testing.T) {
logger.SetOutput(GinkgoWriter)
RegisterFailHandler(Fail)
RunSpecs(t, "Cookie SessionStore")
}
var _ = Describe("Cookie SessionStore Tests", func() {
tests.RunSessionStoreTests(
func(opts *options.SessionOptions, cookieOpts *options.CookieOptions) (sessionsapi.SessionStore, error) {
// Set the connection URL
opts.Type = options.CookieSessionStoreType
return NewCookieSessionStore(opts, cookieOpts)
}, nil)
})
func Test_copyCookie(t *testing.T) {
expire, _ := time.Parse(time.RFC3339, "2020-03-17T00:00:00Z")
c := &http.Cookie{

View File

@ -1,69 +0,0 @@
package redis
import (
"crypto/rand"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/Bose/minisentinel"
"github.com/alicebob/miniredis/v2"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRedisStore(t *testing.T) {
secret := make([]byte, 32)
_, err := rand.Read(secret)
assert.NoError(t, err)
t.Run("save session on redis standalone", func(t *testing.T) {
redisServer, err := miniredis.Run()
require.NoError(t, err)
defer redisServer.Close()
opts := options.NewOptions()
redisURL := url.URL{
Scheme: "redis",
Host: redisServer.Addr(),
}
opts.Session.Redis.ConnectionURL = redisURL.String()
opts.Cookie.Secret = string(secret)
redisStore, err := NewRedisSessionStore(&opts.Session, &opts.Cookie)
require.NoError(t, err)
err = redisStore.Save(
httptest.NewRecorder(),
httptest.NewRequest(http.MethodGet, "/", nil),
&sessions.SessionState{})
assert.NoError(t, err)
})
t.Run("save session on redis sentinel", func(t *testing.T) {
redisServer, err := miniredis.Run()
require.NoError(t, err)
defer redisServer.Close()
sentinel := minisentinel.NewSentinel(redisServer)
err = sentinel.Start()
require.NoError(t, err)
defer sentinel.Close()
opts := options.NewOptions()
sentinelURL := url.URL{
Scheme: "redis",
Host: sentinel.Addr(),
}
opts.Session.Redis.SentinelConnectionURLs = []string{sentinelURL.String()}
opts.Session.Redis.UseSentinel = true
opts.Session.Redis.SentinelMasterName = sentinel.MasterInfo().Name
opts.Cookie.Secret = string(secret)
redisStore, err := NewRedisSessionStore(&opts.Session, &opts.Cookie)
require.NoError(t, err)
err = redisStore.Save(
httptest.NewRecorder(),
httptest.NewRequest(http.MethodGet, "/", nil),
&sessions.SessionState{})
assert.NoError(t, err)
})
}

View File

@ -0,0 +1,101 @@
package redis
import (
"log"
"os"
"testing"
"time"
"github.com/Bose/minisentinel"
"github.com/alicebob/miniredis/v2"
"github.com/go-redis/redis/v7"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/tests"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
func TestSessionStore(t *testing.T) {
logger.SetOutput(GinkgoWriter)
redisLogger := log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile)
redisLogger.SetOutput(GinkgoWriter)
redis.SetLogger(redisLogger)
RegisterFailHandler(Fail)
RunSpecs(t, "Redis SessionStore")
}
var _ = Describe("Redis SessionStore Tests", func() {
var mr *miniredis.Miniredis
BeforeEach(func() {
var err error
mr, err = miniredis.Run()
Expect(err).ToNot(HaveOccurred())
})
AfterEach(func() {
mr.Close()
})
tests.RunSessionStoreTests(
func(opts *options.SessionOptions, cookieOpts *options.CookieOptions) (sessionsapi.SessionStore, error) {
// Set the connection URL
opts.Type = options.RedisSessionStoreType
opts.Redis.ConnectionURL = "redis://" + mr.Addr()
return NewRedisSessionStore(opts, cookieOpts)
},
func(d time.Duration) error {
mr.FastForward(d)
return nil
},
)
Context("with sentinel", func() {
var ms *minisentinel.Sentinel
BeforeEach(func() {
ms = minisentinel.NewSentinel(mr)
Expect(ms.Start()).To(Succeed())
})
AfterEach(func() {
go ms.Close()
})
tests.RunSessionStoreTests(
func(opts *options.SessionOptions, cookieOpts *options.CookieOptions) (sessionsapi.SessionStore, error) {
// Set the sentinel connection URL
sentinelAddr := "redis://" + ms.Addr()
opts.Type = options.RedisSessionStoreType
opts.Redis.SentinelConnectionURLs = []string{sentinelAddr}
opts.Redis.UseSentinel = true
opts.Redis.SentinelMasterName = ms.MasterInfo().Name
return NewRedisSessionStore(opts, cookieOpts)
},
func(d time.Duration) error {
mr.FastForward(d)
return nil
},
)
})
Context("with cluster", func() {
tests.RunSessionStoreTests(
func(opts *options.SessionOptions, cookieOpts *options.CookieOptions) (sessionsapi.SessionStore, error) {
clusterAddr := "redis://" + mr.Addr()
opts.Type = options.RedisSessionStoreType
opts.Redis.ClusterConnectionURLs = []string{clusterAddr}
opts.Redis.UseCluster = true
return NewRedisSessionStore(opts, cookieOpts)
},
func(d time.Duration) error {
mr.FastForward(d)
return nil
},
)
})
})

View File

@ -1,20 +1,12 @@
package sessions_test
import (
"crypto/rand"
"encoding/base64"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"math/rand"
"testing"
"time"
miniredis "github.com/alicebob/miniredis/v2"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
cookiesapi "github.com/oauth2-proxy/oauth2-proxy/pkg/cookies"
"github.com/oauth2-proxy/oauth2-proxy/pkg/encryption"
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions"
sessionscookie "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/cookie"
@ -34,341 +26,7 @@ var _ = Describe("NewSessionStore", func() {
var opts *options.SessionOptions
var cookieOpts *options.CookieOptions
var request *http.Request
var response *httptest.ResponseRecorder
var session *sessionsapi.SessionState
var ss sessionsapi.SessionStore
var mr *miniredis.Miniredis
CheckCookieOptions := func() {
Context("the cookies returned", func() {
var cookies []*http.Cookie
BeforeEach(func() {
cookies = response.Result().Cookies()
})
It("have the correct name set", func() {
if len(cookies) == 1 {
Expect(cookies[0].Name).To(Equal(cookieOpts.Name))
} else {
for _, cookie := range cookies {
Expect(cookie.Name).To(ContainSubstring(cookieOpts.Name))
}
}
})
It("have the correct path set", func() {
for _, cookie := range cookies {
Expect(cookie.Path).To(Equal(cookieOpts.Path))
}
})
It("have the correct domain set", func() {
for _, cookie := range cookies {
specifiedDomain := ""
if len(cookieOpts.Domains) > 0 {
specifiedDomain = cookieOpts.Domains[0]
}
Expect(cookie.Domain).To(Equal(specifiedDomain))
}
})
It("have the correct HTTPOnly set", func() {
for _, cookie := range cookies {
Expect(cookie.HttpOnly).To(Equal(cookieOpts.HTTPOnly))
}
})
It("have the correct secure set", func() {
for _, cookie := range cookies {
Expect(cookie.Secure).To(Equal(cookieOpts.Secure))
}
})
It("have the correct SameSite set", func() {
for _, cookie := range cookies {
Expect(cookie.SameSite).To(Equal(cookiesapi.ParseSameSite(cookieOpts.SameSite)))
}
})
It("have a signature timestamp matching session.CreatedAt", func() {
for _, cookie := range cookies {
if cookie.Value != "" {
parts := strings.Split(cookie.Value, "|")
Expect(parts).To(HaveLen(3))
Expect(parts[1]).To(Equal(strconv.Itoa(int(session.CreatedAt.Unix()))))
}
}
})
})
}
// The following should only be for server stores
PersistentSessionStoreTests := func() {
Context("when Clear is called on a persistent store", func() {
var resultCookies []*http.Cookie
BeforeEach(func() {
req := httptest.NewRequest("GET", "http://example.com/", nil)
saveResp := httptest.NewRecorder()
err := ss.Save(saveResp, req, session)
Expect(err).ToNot(HaveOccurred())
resultCookies = saveResp.Result().Cookies()
for _, c := range resultCookies {
request.AddCookie(c)
}
err = ss.Clear(response, request)
Expect(err).ToNot(HaveOccurred())
})
Context("attempting to Load", func() {
var loadedAfterClear *sessionsapi.SessionState
var loadErr error
BeforeEach(func() {
loadReq := httptest.NewRequest("GET", "http://example.com/", nil)
for _, c := range resultCookies {
loadReq.AddCookie(c)
}
loadedAfterClear, loadErr = ss.Load(loadReq)
})
It("returns an empty session", func() {
Expect(loadedAfterClear).To(BeNil())
})
It("returns an error", func() {
Expect(loadErr).To(HaveOccurred())
})
})
CheckCookieOptions()
})
}
SessionStoreInterfaceTests := func(persistent bool) {
Context("when Save is called", func() {
Context("with no existing session", func() {
BeforeEach(func() {
err := ss.Save(response, request, session)
Expect(err).ToNot(HaveOccurred())
})
It("sets a `set-cookie` header in the response", func() {
Expect(response.Header().Get("set-cookie")).ToNot(BeEmpty())
})
It("Ensures the session CreatedAt is not zero", func() {
Expect(session.CreatedAt.IsZero()).To(BeFalse())
})
})
Context("with a broken session", func() {
BeforeEach(func() {
By("Using a valid cookie with a different providers session encoding")
broken := "BrokenSessionFromADifferentSessionImplementation"
value := encryption.SignedValue(cookieOpts.Secret, cookieOpts.Name, []byte(broken), time.Now())
cookie := cookiesapi.MakeCookieFromOptions(request, cookieOpts.Name, value, cookieOpts, cookieOpts.Expire, time.Now())
request.AddCookie(cookie)
err := ss.Save(response, request, session)
Expect(err).ToNot(HaveOccurred())
})
It("sets a `set-cookie` header in the response", func() {
Expect(response.Header().Get("set-cookie")).ToNot(BeEmpty())
})
It("Ensures the session CreatedAt is not zero", func() {
Expect(session.CreatedAt.IsZero()).To(BeFalse())
})
})
Context("with an expired saved session", func() {
var err error
BeforeEach(func() {
By("saving a session")
req := httptest.NewRequest("GET", "http://example.com/", nil)
saveResp := httptest.NewRecorder()
err = ss.Save(saveResp, req, session)
Expect(err).ToNot(HaveOccurred())
By("and clearing the session")
for _, c := range saveResp.Result().Cookies() {
request.AddCookie(c)
}
clearResp := httptest.NewRecorder()
err = ss.Clear(clearResp, request)
Expect(err).ToNot(HaveOccurred())
By("then saving a request with the cleared session")
err = ss.Save(response, request, session)
})
It("no error should occur", func() {
Expect(err).ToNot(HaveOccurred())
})
})
CheckCookieOptions()
})
Context("when Clear is called", func() {
BeforeEach(func() {
req := httptest.NewRequest("GET", "http://example.com/", nil)
saveResp := httptest.NewRecorder()
err := ss.Save(saveResp, req, session)
Expect(err).ToNot(HaveOccurred())
for _, c := range saveResp.Result().Cookies() {
request.AddCookie(c)
}
err = ss.Clear(response, request)
Expect(err).ToNot(HaveOccurred())
})
It("sets a `set-cookie` header in the response", func() {
Expect(response.Header().Get("Set-Cookie")).ToNot(BeEmpty())
})
CheckCookieOptions()
})
Context("when Load is called", func() {
LoadSessionTests := func() {
var loadedSession *sessionsapi.SessionState
BeforeEach(func() {
var err error
loadedSession, err = ss.Load(request)
Expect(err).ToNot(HaveOccurred())
})
It("loads a session equal to the original session", func() {
if cookieOpts.Secret == "" {
// Only Email and User stored in session when encrypted
Expect(loadedSession.Email).To(Equal(session.Email))
Expect(loadedSession.User).To(Equal(session.User))
} else {
// All fields stored in session if encrypted
// Can't compare time.Time using Equal() so remove ExpiresOn from sessions
l := *loadedSession
l.CreatedAt = nil
l.ExpiresOn = nil
s := *session
s.CreatedAt = nil
s.ExpiresOn = nil
Expect(l).To(Equal(s))
// Compare time.Time separately
Expect(loadedSession.CreatedAt.Equal(*session.CreatedAt)).To(BeTrue())
Expect(loadedSession.ExpiresOn.Equal(*session.ExpiresOn)).To(BeTrue())
}
})
}
BeforeEach(func() {
req := httptest.NewRequest("GET", "http://example.com/", nil)
resp := httptest.NewRecorder()
err := ss.Save(resp, req, session)
Expect(err).ToNot(HaveOccurred())
for _, cookie := range resp.Result().Cookies() {
request.AddCookie(cookie)
}
})
Context("before the refresh period", func() {
LoadSessionTests()
})
// Test TTLs and cleanup of persistent session storage
// For non-persistent we rely on the browser cookie lifecycle
if persistent {
Context("after the refresh period, but before the cookie expire period", func() {
BeforeEach(func() {
switch ss.(type) {
case *redis.SessionStore:
mr.FastForward(cookieOpts.Refresh + time.Minute)
}
})
LoadSessionTests()
})
Context("after the cookie expire period", func() {
var loadedSession *sessionsapi.SessionState
var err error
BeforeEach(func() {
switch ss.(type) {
case *redis.SessionStore:
mr.FastForward(cookieOpts.Expire + time.Minute)
}
loadedSession, err = ss.Load(request)
Expect(err).To(HaveOccurred())
})
It("returns an error loading the session", func() {
Expect(err).To(HaveOccurred())
})
It("returns an empty session", func() {
Expect(loadedSession).To(BeNil())
})
})
}
})
if persistent {
PersistentSessionStoreTests()
}
}
RunSessionTests := func(persistent bool) {
Context("with default options", func() {
BeforeEach(func() {
var err error
ss, err = sessions.NewSessionStore(opts, cookieOpts)
Expect(err).ToNot(HaveOccurred())
})
SessionStoreInterfaceTests(persistent)
})
Context("with non-default options", func() {
BeforeEach(func() {
cookieOpts = &options.CookieOptions{
Name: "_cookie_name",
Path: "/path",
Expire: time.Duration(72) * time.Hour,
Refresh: time.Duration(2) * time.Hour,
Secure: false,
HTTPOnly: false,
Domains: []string{"example.com"},
SameSite: "strict",
}
// A secret is required but not defaulted
secret := make([]byte, 32)
_, err := rand.Read(secret)
Expect(err).ToNot(HaveOccurred())
cookieOpts.Secret = base64.URLEncoding.EncodeToString(secret)
ss, err = sessions.NewSessionStore(opts, cookieOpts)
Expect(err).ToNot(HaveOccurred())
})
SessionStoreInterfaceTests(persistent)
})
}
BeforeEach(func() {
ss = nil
opts = &options.SessionOptions{}
// A secret is required to create a Cipher, validation ensures it is the correct
@ -388,19 +46,6 @@ var _ = Describe("NewSessionStore", func() {
HTTPOnly: true,
SameSite: "",
}
expires := time.Now().Add(1 * time.Hour)
session = &sessionsapi.SessionState{
AccessToken: "AccessToken",
IDToken: "IDToken",
ExpiresOn: &expires,
RefreshToken: "RefreshToken",
Email: "john.doe@example.com",
User: "john.doe",
}
request = httptest.NewRequest("GET", "http://example.com/", nil)
response = httptest.NewRecorder()
})
Context("with type 'cookie'", func() {
@ -413,36 +58,12 @@ var _ = Describe("NewSessionStore", func() {
Expect(err).NotTo(HaveOccurred())
Expect(ss).To(BeAssignableToTypeOf(&sessionscookie.SessionStore{}))
})
Context("the cookie.SessionStore", func() {
RunSessionTests(false)
})
Context("with an invalid cookie secret", func() {
BeforeEach(func() {
cookieOpts.Secret = "invalid"
})
It("returns an error", func() {
ss, err := sessions.NewSessionStore(opts, cookieOpts)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal("error initialising cipher: crypto/aes: invalid key size 7"))
Expect(ss).To(BeNil())
})
})
})
Context("with type 'redis'", func() {
BeforeEach(func() {
var err error
mr, err = miniredis.Run()
Expect(err).ToNot(HaveOccurred())
opts.Type = options.RedisSessionStoreType
opts.Redis.ConnectionURL = "redis://" + mr.Addr()
})
AfterEach(func() {
mr.Close()
opts.Redis.ConnectionURL = "redis://"
})
It("creates a redis.SessionStore", func() {
@ -450,23 +71,6 @@ var _ = Describe("NewSessionStore", func() {
Expect(err).NotTo(HaveOccurred())
Expect(ss).To(BeAssignableToTypeOf(&redis.SessionStore{}))
})
Context("the redis.SessionStore", func() {
RunSessionTests(true)
})
Context("with an invalid cookie secret", func() {
BeforeEach(func() {
cookieOpts.Secret = "invalid"
})
It("returns an error", func() {
ss, err := sessions.NewSessionStore(opts, cookieOpts)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal("error initialising cipher: crypto/aes: invalid key size 7"))
Expect(ss).To(BeNil())
})
})
})
Context("with an invalid type", func() {

View File

@ -0,0 +1,435 @@
package tests
import (
"crypto/rand"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"time"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
cookiesapi "github.com/oauth2-proxy/oauth2-proxy/pkg/cookies"
"github.com/oauth2-proxy/oauth2-proxy/pkg/encryption"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
// testInput is passed to test function as a pointer.
// This allows BeforeEach blocks to initialise and use these values after
// Ginkgo has unpacked the tests.
// Interfaces have to be wrapped in closures otherwise nil pointers are thrown.
type testInput struct {
cookieOpts *options.CookieOptions
ss sessionStoreFunc
session *sessionsapi.SessionState
request *http.Request
response *httptest.ResponseRecorder
persistentFastForward PersistentStoreFastForwardFunc
}
// sessionStoreFunc is used in testInput to wrap the SessionStore interface.
type sessionStoreFunc func() sessionsapi.SessionStore
// PersistentStoreFastForwardFunc is used to adjust the time of the persistent
// store to fast forward expiry of sessions.
type PersistentStoreFastForwardFunc func(time.Duration) error
// NewSessionStoreFunc allows any session store implementation to configure their
// own session store before each test.
type NewSessionStoreFunc func(sessionOpts *options.SessionOptions, cookieOpts *options.CookieOptions) (sessionsapi.SessionStore, error)
func RunSessionStoreTests(newSS NewSessionStoreFunc, persistentFastForward PersistentStoreFastForwardFunc) {
Describe("Session Store Suite", func() {
var opts *options.SessionOptions
var ss sessionsapi.SessionStore
var input testInput
var cookieSecret []byte
getSessionStore := func() sessionsapi.SessionStore {
return ss
}
BeforeEach(func() {
ss = nil
opts = &options.SessionOptions{}
// A secret is required to create a Cipher, validation ensures it is the correct
// length before a session store is initialised.
cookieSecret = make([]byte, 32)
_, err := rand.Read(cookieSecret)
Expect(err).ToNot(HaveOccurred())
// Set default options in CookieOptions
cookieOpts := &options.CookieOptions{
Name: "_oauth2_proxy",
Path: "/",
Expire: time.Duration(168) * time.Hour,
Refresh: time.Duration(1) * time.Hour,
Secure: true,
HTTPOnly: true,
SameSite: "",
Secret: string(cookieSecret),
}
expires := time.Now().Add(1 * time.Hour)
session := &sessionsapi.SessionState{
AccessToken: "AccessToken",
IDToken: "IDToken",
ExpiresOn: &expires,
RefreshToken: "RefreshToken",
Email: "john.doe@example.com",
User: "john.doe",
}
request := httptest.NewRequest("GET", "http://example.com/", nil)
response := httptest.NewRecorder()
input = testInput{
cookieOpts: cookieOpts,
ss: getSessionStore,
session: session,
request: request,
response: response,
persistentFastForward: persistentFastForward,
}
})
Context("with default options", func() {
BeforeEach(func() {
var err error
ss, err = newSS(opts, input.cookieOpts)
Expect(err).ToNot(HaveOccurred())
})
SessionStoreInterfaceTests(&input)
if persistentFastForward != nil {
PersistentSessionStoreInterfaceTests(&input)
}
})
Context("with non-default options", func() {
BeforeEach(func() {
input.cookieOpts = &options.CookieOptions{
Name: "_cookie_name",
Path: "/path",
Expire: time.Duration(72) * time.Hour,
Refresh: time.Duration(2) * time.Hour,
Secure: false,
HTTPOnly: false,
Domains: []string{"example.com"},
SameSite: "strict",
Secret: string(cookieSecret),
}
var err error
ss, err = newSS(opts, input.cookieOpts)
Expect(err).ToNot(HaveOccurred())
})
SessionStoreInterfaceTests(&input)
if persistentFastForward != nil {
PersistentSessionStoreInterfaceTests(&input)
}
})
Context("with an invalid cookie secret", func() {
BeforeEach(func() {
input.cookieOpts.Secret = "invalid"
})
It("returns an error when initialising the session store", func() {
ss, err := newSS(opts, input.cookieOpts)
Expect(err).To(MatchError("error initialising cipher: crypto/aes: invalid key size 7"))
Expect(ss).To(BeNil())
})
})
})
}
func CheckCookieOptions(in *testInput) {
Context("the cookies returned", func() {
var cookies []*http.Cookie
BeforeEach(func() {
cookies = in.response.Result().Cookies()
})
It("have the correct name set", func() {
if len(cookies) == 1 {
Expect(cookies[0].Name).To(Equal(in.cookieOpts.Name))
} else {
for _, cookie := range cookies {
Expect(cookie.Name).To(ContainSubstring(in.cookieOpts.Name))
}
}
})
It("have the correct path set", func() {
for _, cookie := range cookies {
Expect(cookie.Path).To(Equal(in.cookieOpts.Path))
}
})
It("have the correct domain set", func() {
for _, cookie := range cookies {
specifiedDomain := ""
if len(in.cookieOpts.Domains) > 0 {
specifiedDomain = in.cookieOpts.Domains[0]
}
Expect(cookie.Domain).To(Equal(specifiedDomain))
}
})
It("have the correct HTTPOnly set", func() {
for _, cookie := range cookies {
Expect(cookie.HttpOnly).To(Equal(in.cookieOpts.HTTPOnly))
}
})
It("have the correct secure set", func() {
for _, cookie := range cookies {
Expect(cookie.Secure).To(Equal(in.cookieOpts.Secure))
}
})
It("have the correct SameSite set", func() {
for _, cookie := range cookies {
Expect(cookie.SameSite).To(Equal(cookiesapi.ParseSameSite(in.cookieOpts.SameSite)))
}
})
It("have a signature timestamp matching session.CreatedAt", func() {
for _, cookie := range cookies {
if cookie.Value != "" {
parts := strings.Split(cookie.Value, "|")
Expect(parts).To(HaveLen(3))
Expect(parts[1]).To(Equal(strconv.Itoa(int(in.session.CreatedAt.Unix()))))
}
}
})
})
}
func PersistentSessionStoreInterfaceTests(in *testInput) {
// Check that a stale cookie can't load an already cleared session
Context("when Clear is called on a persistent store", func() {
var resultCookies []*http.Cookie
BeforeEach(func() {
req := httptest.NewRequest("GET", "http://example.com/", nil)
saveResp := httptest.NewRecorder()
err := in.ss().Save(saveResp, req, in.session)
Expect(err).ToNot(HaveOccurred())
resultCookies = saveResp.Result().Cookies()
for _, c := range resultCookies {
in.request.AddCookie(c)
}
err = in.ss().Clear(in.response, in.request)
Expect(err).ToNot(HaveOccurred())
})
Context("attempting to Load", func() {
var loadedAfterClear *sessionsapi.SessionState
var loadErr error
BeforeEach(func() {
loadReq := httptest.NewRequest("GET", "http://example.com/", nil)
for _, c := range resultCookies {
loadReq.AddCookie(c)
}
loadedAfterClear, loadErr = in.ss().Load(loadReq)
})
It("returns an empty session", func() {
Expect(loadedAfterClear).To(BeNil())
})
It("returns an error", func() {
Expect(loadErr).To(HaveOccurred())
})
})
CheckCookieOptions(in)
})
// Test TTLs and cleanup of persistent session storage
// For non-persistent we rely on the browser cookie lifecycle
Context("when Load is called on a persistent store", func() {
BeforeEach(func() {
req := httptest.NewRequest("GET", "http://example.com/", nil)
resp := httptest.NewRecorder()
err := in.ss().Save(resp, req, in.session)
Expect(err).ToNot(HaveOccurred())
for _, cookie := range resp.Result().Cookies() {
in.request.AddCookie(cookie)
}
})
Context("after the refresh period, but before the cookie expire period", func() {
BeforeEach(func() {
Expect(in.persistentFastForward(in.cookieOpts.Refresh + time.Minute)).To(Succeed())
})
LoadSessionTests(in)
})
Context("after the cookie expire period", func() {
var loadedSession *sessionsapi.SessionState
var err error
BeforeEach(func() {
Expect(in.persistentFastForward(in.cookieOpts.Expire + time.Minute)).To(Succeed())
loadedSession, err = in.ss().Load(in.request)
Expect(err).To(HaveOccurred())
})
It("returns an error loading the session", func() {
Expect(err).To(HaveOccurred())
})
It("returns an empty session", func() {
Expect(loadedSession).To(BeNil())
})
})
})
}
func SessionStoreInterfaceTests(in *testInput) {
Context("when Save is called", func() {
Context("with no existing session", func() {
BeforeEach(func() {
err := in.ss().Save(in.response, in.request, in.session)
Expect(err).ToNot(HaveOccurred())
})
It("sets a `set-cookie` header in the response", func() {
Expect(in.response.Header().Get("set-cookie")).ToNot(BeEmpty())
})
It("Ensures the session CreatedAt is not zero", func() {
Expect(in.session.CreatedAt.IsZero()).To(BeFalse())
})
CheckCookieOptions(in)
})
Context("with a broken session", func() {
BeforeEach(func() {
By("Using a valid cookie with a different providers session encoding")
broken := "BrokenSessionFromADifferentSessionImplementation"
value := encryption.SignedValue(in.cookieOpts.Secret, in.cookieOpts.Name, []byte(broken), time.Now())
cookie := cookiesapi.MakeCookieFromOptions(in.request, in.cookieOpts.Name, value, in.cookieOpts, in.cookieOpts.Expire, time.Now())
in.request.AddCookie(cookie)
err := in.ss().Save(in.response, in.request, in.session)
Expect(err).ToNot(HaveOccurred())
})
It("sets a `set-cookie` header in the response", func() {
Expect(in.response.Header().Get("set-cookie")).ToNot(BeEmpty())
})
It("Ensures the session CreatedAt is not zero", func() {
Expect(in.session.CreatedAt.IsZero()).To(BeFalse())
})
CheckCookieOptions(in)
})
Context("with an expired saved session", func() {
var err error
BeforeEach(func() {
By("saving a session")
req := httptest.NewRequest("GET", "http://example.com/", nil)
saveResp := httptest.NewRecorder()
err = in.ss().Save(saveResp, req, in.session)
Expect(err).ToNot(HaveOccurred())
By("and clearing the session")
for _, c := range saveResp.Result().Cookies() {
in.request.AddCookie(c)
}
clearResp := httptest.NewRecorder()
err = in.ss().Clear(clearResp, in.request)
Expect(err).ToNot(HaveOccurred())
By("then saving a request with the cleared session")
err = in.ss().Save(in.response, in.request, in.session)
})
It("no error should occur", func() {
Expect(err).ToNot(HaveOccurred())
})
})
})
Context("when Clear is called", func() {
BeforeEach(func() {
req := httptest.NewRequest("GET", "http://example.com/", nil)
saveResp := httptest.NewRecorder()
err := in.ss().Save(saveResp, req, in.session)
Expect(err).ToNot(HaveOccurred())
for _, c := range saveResp.Result().Cookies() {
in.request.AddCookie(c)
}
err = in.ss().Clear(in.response, in.request)
Expect(err).ToNot(HaveOccurred())
})
It("sets a `set-cookie` header in the response", func() {
Expect(in.response.Header().Get("Set-Cookie")).ToNot(BeEmpty())
})
CheckCookieOptions(in)
})
Context("when Load is called", func() {
BeforeEach(func() {
req := httptest.NewRequest("GET", "http://example.com/", nil)
resp := httptest.NewRecorder()
err := in.ss().Save(resp, req, in.session)
Expect(err).ToNot(HaveOccurred())
for _, cookie := range resp.Result().Cookies() {
in.request.AddCookie(cookie)
}
})
Context("before the refresh period", func() {
LoadSessionTests(in)
})
})
}
func LoadSessionTests(in *testInput) {
var loadedSession *sessionsapi.SessionState
BeforeEach(func() {
var err error
loadedSession, err = in.ss().Load(in.request)
Expect(err).ToNot(HaveOccurred())
})
It("loads a session equal to the original session", func() {
// Can't compare time.Time using Equal() so remove ExpiresOn from sessions
l := *loadedSession
l.CreatedAt = nil
l.ExpiresOn = nil
s := *in.session
s.CreatedAt = nil
s.ExpiresOn = nil
Expect(l).To(Equal(s))
// Compare time.Time separately
Expect(loadedSession.CreatedAt.Equal(*in.session.CreatedAt)).To(BeTrue())
Expect(loadedSession.ExpiresOn.Equal(*in.session.ExpiresOn)).To(BeTrue())
})
}