You've already forked oauth2-proxy
mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-08-10 22:51:31 +02:00
Move SessionStore tests to independent package
This commit is contained in:
435
pkg/sessions/tests/session_store_tests.go
Normal file
435
pkg/sessions/tests/session_store_tests.go
Normal file
@@ -0,0 +1,435 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
|
||||
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||
cookiesapi "github.com/oauth2-proxy/oauth2-proxy/pkg/cookies"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/encryption"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// testInput is passed to test function as a pointer.
|
||||
// This allows BeforeEach blocks to initialise and use these values after
|
||||
// Ginkgo has unpacked the tests.
|
||||
// Interfaces have to be wrapped in closures otherwise nil pointers are thrown.
|
||||
type testInput struct {
|
||||
cookieOpts *options.CookieOptions
|
||||
ss sessionStoreFunc
|
||||
session *sessionsapi.SessionState
|
||||
request *http.Request
|
||||
response *httptest.ResponseRecorder
|
||||
persistentFastForward PersistentStoreFastForwardFunc
|
||||
}
|
||||
|
||||
// sessionStoreFunc is used in testInput to wrap the SessionStore interface.
|
||||
type sessionStoreFunc func() sessionsapi.SessionStore
|
||||
|
||||
// PersistentStoreFastForwardFunc is used to adjust the time of the persistent
|
||||
// store to fast forward expiry of sessions.
|
||||
type PersistentStoreFastForwardFunc func(time.Duration) error
|
||||
|
||||
// NewSessionStoreFunc allows any session store implementation to configure their
|
||||
// own session store before each test.
|
||||
type NewSessionStoreFunc func(sessionOpts *options.SessionOptions, cookieOpts *options.CookieOptions) (sessionsapi.SessionStore, error)
|
||||
|
||||
func RunSessionStoreTests(newSS NewSessionStoreFunc, persistentFastForward PersistentStoreFastForwardFunc) {
|
||||
Describe("Session Store Suite", func() {
|
||||
var opts *options.SessionOptions
|
||||
var ss sessionsapi.SessionStore
|
||||
var input testInput
|
||||
var cookieSecret []byte
|
||||
|
||||
getSessionStore := func() sessionsapi.SessionStore {
|
||||
return ss
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
ss = nil
|
||||
opts = &options.SessionOptions{}
|
||||
|
||||
// A secret is required to create a Cipher, validation ensures it is the correct
|
||||
// length before a session store is initialised.
|
||||
cookieSecret = make([]byte, 32)
|
||||
_, err := rand.Read(cookieSecret)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Set default options in CookieOptions
|
||||
cookieOpts := &options.CookieOptions{
|
||||
Name: "_oauth2_proxy",
|
||||
Path: "/",
|
||||
Expire: time.Duration(168) * time.Hour,
|
||||
Refresh: time.Duration(1) * time.Hour,
|
||||
Secure: true,
|
||||
HTTPOnly: true,
|
||||
SameSite: "",
|
||||
Secret: string(cookieSecret),
|
||||
}
|
||||
|
||||
expires := time.Now().Add(1 * time.Hour)
|
||||
session := &sessionsapi.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
IDToken: "IDToken",
|
||||
ExpiresOn: &expires,
|
||||
RefreshToken: "RefreshToken",
|
||||
Email: "john.doe@example.com",
|
||||
User: "john.doe",
|
||||
}
|
||||
|
||||
request := httptest.NewRequest("GET", "http://example.com/", nil)
|
||||
response := httptest.NewRecorder()
|
||||
|
||||
input = testInput{
|
||||
cookieOpts: cookieOpts,
|
||||
ss: getSessionStore,
|
||||
session: session,
|
||||
request: request,
|
||||
response: response,
|
||||
persistentFastForward: persistentFastForward,
|
||||
}
|
||||
})
|
||||
|
||||
Context("with default options", func() {
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
ss, err = newSS(opts, input.cookieOpts)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
SessionStoreInterfaceTests(&input)
|
||||
if persistentFastForward != nil {
|
||||
PersistentSessionStoreInterfaceTests(&input)
|
||||
}
|
||||
})
|
||||
|
||||
Context("with non-default options", func() {
|
||||
BeforeEach(func() {
|
||||
input.cookieOpts = &options.CookieOptions{
|
||||
Name: "_cookie_name",
|
||||
Path: "/path",
|
||||
Expire: time.Duration(72) * time.Hour,
|
||||
Refresh: time.Duration(2) * time.Hour,
|
||||
Secure: false,
|
||||
HTTPOnly: false,
|
||||
Domains: []string{"example.com"},
|
||||
SameSite: "strict",
|
||||
Secret: string(cookieSecret),
|
||||
}
|
||||
|
||||
var err error
|
||||
ss, err = newSS(opts, input.cookieOpts)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
SessionStoreInterfaceTests(&input)
|
||||
if persistentFastForward != nil {
|
||||
PersistentSessionStoreInterfaceTests(&input)
|
||||
}
|
||||
})
|
||||
|
||||
Context("with an invalid cookie secret", func() {
|
||||
BeforeEach(func() {
|
||||
input.cookieOpts.Secret = "invalid"
|
||||
})
|
||||
|
||||
It("returns an error when initialising the session store", func() {
|
||||
ss, err := newSS(opts, input.cookieOpts)
|
||||
Expect(err).To(MatchError("error initialising cipher: crypto/aes: invalid key size 7"))
|
||||
Expect(ss).To(BeNil())
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func CheckCookieOptions(in *testInput) {
|
||||
Context("the cookies returned", func() {
|
||||
var cookies []*http.Cookie
|
||||
BeforeEach(func() {
|
||||
cookies = in.response.Result().Cookies()
|
||||
})
|
||||
|
||||
It("have the correct name set", func() {
|
||||
if len(cookies) == 1 {
|
||||
Expect(cookies[0].Name).To(Equal(in.cookieOpts.Name))
|
||||
} else {
|
||||
for _, cookie := range cookies {
|
||||
Expect(cookie.Name).To(ContainSubstring(in.cookieOpts.Name))
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
It("have the correct path set", func() {
|
||||
for _, cookie := range cookies {
|
||||
Expect(cookie.Path).To(Equal(in.cookieOpts.Path))
|
||||
}
|
||||
})
|
||||
|
||||
It("have the correct domain set", func() {
|
||||
for _, cookie := range cookies {
|
||||
specifiedDomain := ""
|
||||
if len(in.cookieOpts.Domains) > 0 {
|
||||
specifiedDomain = in.cookieOpts.Domains[0]
|
||||
}
|
||||
Expect(cookie.Domain).To(Equal(specifiedDomain))
|
||||
}
|
||||
})
|
||||
|
||||
It("have the correct HTTPOnly set", func() {
|
||||
for _, cookie := range cookies {
|
||||
Expect(cookie.HttpOnly).To(Equal(in.cookieOpts.HTTPOnly))
|
||||
}
|
||||
})
|
||||
|
||||
It("have the correct secure set", func() {
|
||||
for _, cookie := range cookies {
|
||||
Expect(cookie.Secure).To(Equal(in.cookieOpts.Secure))
|
||||
}
|
||||
})
|
||||
|
||||
It("have the correct SameSite set", func() {
|
||||
for _, cookie := range cookies {
|
||||
Expect(cookie.SameSite).To(Equal(cookiesapi.ParseSameSite(in.cookieOpts.SameSite)))
|
||||
}
|
||||
})
|
||||
|
||||
It("have a signature timestamp matching session.CreatedAt", func() {
|
||||
for _, cookie := range cookies {
|
||||
if cookie.Value != "" {
|
||||
parts := strings.Split(cookie.Value, "|")
|
||||
Expect(parts).To(HaveLen(3))
|
||||
Expect(parts[1]).To(Equal(strconv.Itoa(int(in.session.CreatedAt.Unix()))))
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
func PersistentSessionStoreInterfaceTests(in *testInput) {
|
||||
// Check that a stale cookie can't load an already cleared session
|
||||
Context("when Clear is called on a persistent store", func() {
|
||||
var resultCookies []*http.Cookie
|
||||
|
||||
BeforeEach(func() {
|
||||
req := httptest.NewRequest("GET", "http://example.com/", nil)
|
||||
saveResp := httptest.NewRecorder()
|
||||
err := in.ss().Save(saveResp, req, in.session)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
resultCookies = saveResp.Result().Cookies()
|
||||
for _, c := range resultCookies {
|
||||
in.request.AddCookie(c)
|
||||
}
|
||||
err = in.ss().Clear(in.response, in.request)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
Context("attempting to Load", func() {
|
||||
var loadedAfterClear *sessionsapi.SessionState
|
||||
var loadErr error
|
||||
|
||||
BeforeEach(func() {
|
||||
loadReq := httptest.NewRequest("GET", "http://example.com/", nil)
|
||||
for _, c := range resultCookies {
|
||||
loadReq.AddCookie(c)
|
||||
}
|
||||
|
||||
loadedAfterClear, loadErr = in.ss().Load(loadReq)
|
||||
})
|
||||
|
||||
It("returns an empty session", func() {
|
||||
Expect(loadedAfterClear).To(BeNil())
|
||||
})
|
||||
|
||||
It("returns an error", func() {
|
||||
Expect(loadErr).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
CheckCookieOptions(in)
|
||||
})
|
||||
|
||||
// Test TTLs and cleanup of persistent session storage
|
||||
// For non-persistent we rely on the browser cookie lifecycle
|
||||
Context("when Load is called on a persistent store", func() {
|
||||
BeforeEach(func() {
|
||||
req := httptest.NewRequest("GET", "http://example.com/", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
err := in.ss().Save(resp, req, in.session)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
for _, cookie := range resp.Result().Cookies() {
|
||||
in.request.AddCookie(cookie)
|
||||
}
|
||||
})
|
||||
|
||||
Context("after the refresh period, but before the cookie expire period", func() {
|
||||
BeforeEach(func() {
|
||||
Expect(in.persistentFastForward(in.cookieOpts.Refresh + time.Minute)).To(Succeed())
|
||||
})
|
||||
|
||||
LoadSessionTests(in)
|
||||
})
|
||||
|
||||
Context("after the cookie expire period", func() {
|
||||
var loadedSession *sessionsapi.SessionState
|
||||
var err error
|
||||
|
||||
BeforeEach(func() {
|
||||
Expect(in.persistentFastForward(in.cookieOpts.Expire + time.Minute)).To(Succeed())
|
||||
|
||||
loadedSession, err = in.ss().Load(in.request)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("returns an error loading the session", func() {
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("returns an empty session", func() {
|
||||
Expect(loadedSession).To(BeNil())
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func SessionStoreInterfaceTests(in *testInput) {
|
||||
Context("when Save is called", func() {
|
||||
Context("with no existing session", func() {
|
||||
BeforeEach(func() {
|
||||
err := in.ss().Save(in.response, in.request, in.session)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("sets a `set-cookie` header in the response", func() {
|
||||
Expect(in.response.Header().Get("set-cookie")).ToNot(BeEmpty())
|
||||
})
|
||||
|
||||
It("Ensures the session CreatedAt is not zero", func() {
|
||||
Expect(in.session.CreatedAt.IsZero()).To(BeFalse())
|
||||
})
|
||||
|
||||
CheckCookieOptions(in)
|
||||
})
|
||||
|
||||
Context("with a broken session", func() {
|
||||
BeforeEach(func() {
|
||||
By("Using a valid cookie with a different providers session encoding")
|
||||
broken := "BrokenSessionFromADifferentSessionImplementation"
|
||||
value := encryption.SignedValue(in.cookieOpts.Secret, in.cookieOpts.Name, []byte(broken), time.Now())
|
||||
cookie := cookiesapi.MakeCookieFromOptions(in.request, in.cookieOpts.Name, value, in.cookieOpts, in.cookieOpts.Expire, time.Now())
|
||||
in.request.AddCookie(cookie)
|
||||
|
||||
err := in.ss().Save(in.response, in.request, in.session)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("sets a `set-cookie` header in the response", func() {
|
||||
Expect(in.response.Header().Get("set-cookie")).ToNot(BeEmpty())
|
||||
})
|
||||
|
||||
It("Ensures the session CreatedAt is not zero", func() {
|
||||
Expect(in.session.CreatedAt.IsZero()).To(BeFalse())
|
||||
})
|
||||
|
||||
CheckCookieOptions(in)
|
||||
})
|
||||
|
||||
Context("with an expired saved session", func() {
|
||||
var err error
|
||||
BeforeEach(func() {
|
||||
By("saving a session")
|
||||
req := httptest.NewRequest("GET", "http://example.com/", nil)
|
||||
saveResp := httptest.NewRecorder()
|
||||
err = in.ss().Save(saveResp, req, in.session)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
By("and clearing the session")
|
||||
for _, c := range saveResp.Result().Cookies() {
|
||||
in.request.AddCookie(c)
|
||||
}
|
||||
clearResp := httptest.NewRecorder()
|
||||
err = in.ss().Clear(clearResp, in.request)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
By("then saving a request with the cleared session")
|
||||
err = in.ss().Save(in.response, in.request, in.session)
|
||||
})
|
||||
|
||||
It("no error should occur", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("when Clear is called", func() {
|
||||
BeforeEach(func() {
|
||||
req := httptest.NewRequest("GET", "http://example.com/", nil)
|
||||
saveResp := httptest.NewRecorder()
|
||||
err := in.ss().Save(saveResp, req, in.session)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
for _, c := range saveResp.Result().Cookies() {
|
||||
in.request.AddCookie(c)
|
||||
}
|
||||
err = in.ss().Clear(in.response, in.request)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("sets a `set-cookie` header in the response", func() {
|
||||
Expect(in.response.Header().Get("Set-Cookie")).ToNot(BeEmpty())
|
||||
})
|
||||
|
||||
CheckCookieOptions(in)
|
||||
})
|
||||
|
||||
Context("when Load is called", func() {
|
||||
BeforeEach(func() {
|
||||
req := httptest.NewRequest("GET", "http://example.com/", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
err := in.ss().Save(resp, req, in.session)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
for _, cookie := range resp.Result().Cookies() {
|
||||
in.request.AddCookie(cookie)
|
||||
}
|
||||
})
|
||||
|
||||
Context("before the refresh period", func() {
|
||||
LoadSessionTests(in)
|
||||
})
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
func LoadSessionTests(in *testInput) {
|
||||
var loadedSession *sessionsapi.SessionState
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
loadedSession, err = in.ss().Load(in.request)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("loads a session equal to the original session", func() {
|
||||
// Can't compare time.Time using Equal() so remove ExpiresOn from sessions
|
||||
l := *loadedSession
|
||||
l.CreatedAt = nil
|
||||
l.ExpiresOn = nil
|
||||
s := *in.session
|
||||
s.CreatedAt = nil
|
||||
s.ExpiresOn = nil
|
||||
Expect(l).To(Equal(s))
|
||||
|
||||
// Compare time.Time separately
|
||||
Expect(loadedSession.CreatedAt.Equal(*in.session.CreatedAt)).To(BeTrue())
|
||||
Expect(loadedSession.ExpiresOn.Equal(*in.session.ExpiresOn)).To(BeTrue())
|
||||
|
||||
})
|
||||
}
|
Reference in New Issue
Block a user