1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-01-24 05:26:55 +02:00
oauth2-proxy/pkg/cookies/csrf_test.go

191 lines
4.9 KiB
Go
Raw Normal View History

package cookies
import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"time"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("CSRF Cookie Tests", func() {
var (
cookieOpts *options.Cookie
publicCSRF CSRF
privateCSRF *csrf
)
BeforeEach(func() {
cookieOpts = &options.Cookie{
Name: cookieName,
Secret: cookieSecret,
Domains: []string{cookieDomain},
Path: cookiePath,
Expire: time.Hour,
Secure: true,
HTTPOnly: true,
}
var err error
publicCSRF, err = NewCSRF(cookieOpts)
Expect(err).ToNot(HaveOccurred())
privateCSRF = publicCSRF.(*csrf)
})
Context("NewCSRF", func() {
It("makes unique nonces for OAuth and OIDC", func() {
Expect(privateCSRF.OAuthState).ToNot(BeEmpty())
Expect(privateCSRF.OIDCNonce).ToNot(BeEmpty())
Expect(privateCSRF.OAuthState).ToNot(Equal(privateCSRF.OIDCNonce))
})
It("makes unique nonces between multiple CSRFs", func() {
other, err := NewCSRF(cookieOpts)
Expect(err).ToNot(HaveOccurred())
Expect(privateCSRF.OAuthState).ToNot(Equal(other.(*csrf).OAuthState))
Expect(privateCSRF.OIDCNonce).ToNot(Equal(other.(*csrf).OIDCNonce))
})
})
Context("CheckOAuthState and CheckOIDCNonce", func() {
It("checks that hashed versions match", func() {
privateCSRF.OAuthState = []byte(csrfState)
privateCSRF.OIDCNonce = []byte(csrfNonce)
stateHashed := encryption.HashNonce([]byte(csrfState))
nonceHashed := encryption.HashNonce([]byte(csrfNonce))
Expect(publicCSRF.CheckOAuthState(stateHashed)).To(BeTrue())
Expect(publicCSRF.CheckOIDCNonce(nonceHashed)).To(BeTrue())
Expect(publicCSRF.CheckOAuthState(csrfNonce)).To(BeFalse())
Expect(publicCSRF.CheckOIDCNonce(csrfState)).To(BeFalse())
Expect(publicCSRF.CheckOAuthState(csrfState + csrfNonce)).To(BeFalse())
Expect(publicCSRF.CheckOIDCNonce(csrfNonce + csrfState)).To(BeFalse())
Expect(publicCSRF.CheckOAuthState("")).To(BeFalse())
Expect(publicCSRF.CheckOIDCNonce("")).To(BeFalse())
})
})
Context("SetSessionNonce", func() {
It("sets the session.Nonce", func() {
session := &sessions.SessionState{}
publicCSRF.SetSessionNonce(session)
Expect(session.Nonce).To(Equal(privateCSRF.OIDCNonce))
})
})
Context("encodeCookie and decodeCSRFCookie", func() {
It("encodes and decodes to the same nonces", func() {
privateCSRF.OAuthState = []byte(csrfState)
privateCSRF.OIDCNonce = []byte(csrfNonce)
encoded, err := privateCSRF.encodeCookie()
Expect(err).ToNot(HaveOccurred())
cookie := &http.Cookie{
Name: privateCSRF.cookieName(),
Value: encoded,
}
decoded, err := decodeCSRFCookie(cookie, cookieOpts)
Expect(err).ToNot(HaveOccurred())
Expect(decoded).ToNot(BeNil())
Expect(decoded.OAuthState).To(Equal([]byte(csrfState)))
Expect(decoded.OIDCNonce).To(Equal([]byte(csrfNonce)))
})
It("signs the encoded cookie value", func() {
encoded, err := privateCSRF.encodeCookie()
Expect(err).ToNot(HaveOccurred())
cookie := &http.Cookie{
Name: privateCSRF.cookieName(),
Value: encoded,
}
_, _, valid := encryption.Validate(cookie, cookieOpts.Secret, cookieOpts.Expire)
Expect(valid).To(BeTrue())
})
})
Context("Cookie Management", func() {
var req *http.Request
testNow := time.Unix(nowEpoch, 0)
BeforeEach(func() {
privateCSRF.time.Set(testNow)
req = &http.Request{
Method: http.MethodGet,
Proto: "HTTP/1.1",
Host: cookieDomain,
URL: &url.URL{
Scheme: "https",
Host: cookieDomain,
Path: cookiePath,
},
}
})
AfterEach(func() {
privateCSRF.time.Reset()
})
Context("SetCookie", func() {
It("adds the encoded CSRF cookie to a ResponseWriter", func() {
rw := httptest.NewRecorder()
_, err := publicCSRF.SetCookie(rw, req)
Expect(err).ToNot(HaveOccurred())
Expect(rw.Header().Get("Set-Cookie")).To(ContainSubstring(
fmt.Sprintf("%s=", privateCSRF.cookieName()),
))
Expect(rw.Header().Get("Set-Cookie")).To(ContainSubstring(
fmt.Sprintf(
"; Path=%s; Domain=%s; Expires=%s; HttpOnly; Secure",
cookiePath,
cookieDomain,
testCookieExpires(testNow.Add(cookieOpts.Expire)),
),
))
})
})
Context("ClearCookie", func() {
It("sets a cookie with an empty value in the past", func() {
rw := httptest.NewRecorder()
publicCSRF.ClearCookie(rw, req)
Expect(rw.Header().Get("Set-Cookie")).To(Equal(
fmt.Sprintf(
"%s=; Path=%s; Domain=%s; Expires=%s; HttpOnly; Secure",
privateCSRF.cookieName(),
cookiePath,
cookieDomain,
testCookieExpires(testNow.Add(time.Hour*-1)),
),
))
})
})
Context("cookieName", func() {
It("has the cookie options name as a base", func() {
Expect(privateCSRF.cookieName()).To(ContainSubstring(cookieName))
})
})
})
})