You've already forked oauth2-proxy
mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-08-08 22:46:33 +02:00
* feat: add feature support for cookie-secret-file --------- Signed-off-by: Jan Larwig <jan@larwig.com> Co-Authored-By: Sandy Chen <Yuxuan.Chen@morganstanley.com> Co-authored-by: Jan Larwig <jan@larwig.com>
311 lines
8.5 KiB
Go
311 lines
8.5 KiB
Go
package cookies
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"slices"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock"
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
|
|
"github.com/vmihailenco/msgpack/v5"
|
|
)
|
|
|
|
// CSRF manages various nonces stored in the CSRF cookie during the initial
|
|
// authentication flows.
|
|
type CSRF interface {
|
|
HashOAuthState() string
|
|
HashOIDCNonce() string
|
|
CheckOAuthState(string) bool
|
|
CheckOIDCNonce(string) bool
|
|
GetCodeVerifier() string
|
|
|
|
SetSessionNonce(s *sessions.SessionState)
|
|
|
|
SetCookie(http.ResponseWriter, *http.Request) (*http.Cookie, error)
|
|
ClearCookie(http.ResponseWriter, *http.Request)
|
|
}
|
|
|
|
type csrf struct {
|
|
// OAuthState holds the OAuth2 state parameter's nonce component set in the
|
|
// initial authentication request and mirrored back in the callback
|
|
// redirect from the IdP for CSRF protection.
|
|
OAuthState []byte `msgpack:"s,omitempty"`
|
|
|
|
// OIDCNonce holds the OIDC nonce parameter used in the initial authentication
|
|
// and then set in all subsequent OIDC ID Tokens as the nonce claim. This
|
|
// is used to mitigate replay attacks.
|
|
OIDCNonce []byte `msgpack:"n,omitempty"`
|
|
|
|
// CodeVerifier holds the unobfuscated PKCE code verification string
|
|
// which is used to compare the code challenge when exchanging the
|
|
// authentication code.
|
|
CodeVerifier string `msgpack:"cv,omitempty"`
|
|
|
|
cookieOpts *options.Cookie
|
|
time clock.Clock
|
|
}
|
|
|
|
// csrfStateLength governs how much of the (hashed) OAuth state is embedded in
// per-request CSRF cookie names. Note that ExtractStateSubstring and
// cookieName use only the first csrfStateLength-1 (that is, 8) characters.
// Changing this value renames every per-request CSRF cookie.
const csrfStateLength int = 9
|
|
|
|
// NewCSRF creates a CSRF with random nonces
|
|
func NewCSRF(opts *options.Cookie, codeVerifier string) (CSRF, error) {
|
|
state, err := encryption.Nonce(32)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
nonce, err := encryption.Nonce(32)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &csrf{
|
|
OAuthState: state,
|
|
OIDCNonce: nonce,
|
|
CodeVerifier: codeVerifier,
|
|
|
|
cookieOpts: opts,
|
|
}, nil
|
|
}
|
|
|
|
// LoadCSRFCookie loads a CSRF object from a request's CSRF cookie
|
|
func LoadCSRFCookie(req *http.Request, cookieName string, opts *options.Cookie) (CSRF, error) {
|
|
cookies := req.Cookies()
|
|
for _, cookie := range cookies {
|
|
if cookie.Name != cookieName {
|
|
continue
|
|
}
|
|
|
|
csrf, err := decodeCSRFCookie(cookie, opts)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
return csrf, nil
|
|
}
|
|
|
|
return nil, fmt.Errorf("CSRF cookie with name '%v' was not found", cookieName)
|
|
}
|
|
|
|
// GenerateCookieName in case cookie options state that CSRF cookie has fixed name then set fixed name, otherwise
|
|
// build name based on the state
|
|
func GenerateCookieName(opts *options.Cookie, state string) string {
|
|
stateSubstring := ""
|
|
if opts.CSRFPerRequest {
|
|
// csrfCookieName will include a substring of the state to enable multiple csrf cookies
|
|
// in case of parallel requests
|
|
stateSubstring = ExtractStateSubstring(state)
|
|
}
|
|
return csrfCookieName(opts, stateSubstring)
|
|
}
|
|
|
|
func (c *csrf) GetCodeVerifier() string {
|
|
return c.CodeVerifier
|
|
}
|
|
|
|
// HashOAuthState returns the hash of the OAuth state nonce
|
|
func (c *csrf) HashOAuthState() string {
|
|
return encryption.HashNonce(c.OAuthState)
|
|
}
|
|
|
|
// HashOIDCNonce returns the hash of the OIDC nonce
|
|
func (c *csrf) HashOIDCNonce() string {
|
|
return encryption.HashNonce(c.OIDCNonce)
|
|
}
|
|
|
|
// CheckOAuthState compares the OAuth state nonce against a potential
|
|
// hash of it
|
|
func (c *csrf) CheckOAuthState(hashed string) bool {
|
|
return encryption.CheckNonce(c.OAuthState, hashed)
|
|
}
|
|
|
|
// CheckOIDCNonce compares the OIDC nonce against a potential hash of it
|
|
func (c *csrf) CheckOIDCNonce(hashed string) bool {
|
|
return encryption.CheckNonce(c.OIDCNonce, hashed)
|
|
}
|
|
|
|
// SetSessionNonce sets the OIDCNonce on a SessionState
|
|
func (c *csrf) SetSessionNonce(s *sessions.SessionState) {
|
|
s.Nonce = c.OIDCNonce
|
|
}
|
|
|
|
// SetCookie encodes the CSRF to a signed cookie and sets it on the ResponseWriter
|
|
func (c *csrf) SetCookie(rw http.ResponseWriter, req *http.Request) (*http.Cookie, error) {
|
|
encoded, err := c.encodeCookie()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
cookie := MakeCookieFromOptions(
|
|
req,
|
|
c.cookieName(),
|
|
encoded,
|
|
c.cookieOpts,
|
|
c.cookieOpts.CSRFExpire,
|
|
)
|
|
http.SetCookie(rw, cookie)
|
|
|
|
return cookie, nil
|
|
}
|
|
|
|
// ClearExtraCsrfCookies limits the amount of existing CSRF cookies by deleting
|
|
// an excess of cookies controlled through the option CSRFPerRequestLimit
|
|
func ClearExtraCsrfCookies(opts *options.Cookie, rw http.ResponseWriter, req *http.Request) {
|
|
if !opts.CSRFPerRequest || opts.CSRFPerRequestLimit <= 0 {
|
|
return
|
|
}
|
|
|
|
cookies := req.Cookies()
|
|
existingCsrfCookies := []*http.Cookie{}
|
|
startsWith := fmt.Sprintf("%v_", opts.Name)
|
|
|
|
// determine how many csrf cookies we have
|
|
for _, cookie := range cookies {
|
|
if strings.HasPrefix(cookie.Name, startsWith) && strings.HasSuffix(cookie.Name, "_csrf") {
|
|
existingCsrfCookies = append(existingCsrfCookies, cookie)
|
|
}
|
|
}
|
|
|
|
// short circuit return
|
|
if len(existingCsrfCookies) <= opts.CSRFPerRequestLimit {
|
|
return
|
|
}
|
|
|
|
decodedCookies := []*csrf{}
|
|
for _, cookie := range existingCsrfCookies {
|
|
decodedCookie, err := decodeCSRFCookie(cookie, opts)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
decodedCookies = append(decodedCookies, decodedCookie)
|
|
}
|
|
|
|
// delete the X oldest cookies
|
|
slices.SortStableFunc(decodedCookies, func(a, b *csrf) int {
|
|
return a.time.Now().Compare(b.time.Now())
|
|
})
|
|
|
|
for i := 0; i < len(decodedCookies)-opts.CSRFPerRequestLimit; i++ {
|
|
decodedCookies[i].ClearCookie(rw, req)
|
|
}
|
|
}
|
|
|
|
// ClearCookie removes the CSRF cookie
|
|
func (c *csrf) ClearCookie(rw http.ResponseWriter, req *http.Request) {
|
|
http.SetCookie(rw, MakeCookieFromOptions(
|
|
req,
|
|
c.cookieName(),
|
|
"",
|
|
c.cookieOpts,
|
|
time.Hour*-1,
|
|
))
|
|
}
|
|
|
|
// encodeCookie MessagePack encodes and encrypts the CSRF and then creates a
|
|
// signed cookie value
|
|
func (c *csrf) encodeCookie() (string, error) {
|
|
packed, err := msgpack.Marshal(c)
|
|
if err != nil {
|
|
return "", fmt.Errorf("error marshalling CSRF to msgpack: %v", err)
|
|
}
|
|
|
|
encrypted, err := encrypt(packed, c.cookieOpts)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
secret, err := c.cookieOpts.GetSecret()
|
|
if err != nil {
|
|
return "", fmt.Errorf("error getting cookie secret: %v", err)
|
|
}
|
|
return encryption.SignedValue(secret, c.cookieName(), encrypted, c.time.Now())
|
|
}
|
|
|
|
// decodeCSRFCookie validates the signature then decrypts and decodes a CSRF
|
|
// cookie into a CSRF struct
|
|
func decodeCSRFCookie(cookie *http.Cookie, opts *options.Cookie) (*csrf, error) {
|
|
secret, err := opts.GetSecret()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error getting cookie secret: %v", err)
|
|
}
|
|
|
|
val, t, ok := encryption.Validate(cookie, secret, opts.Expire)
|
|
if !ok {
|
|
return nil, errors.New("CSRF cookie failed validation")
|
|
}
|
|
|
|
decrypted, err := decrypt(val, opts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return unmarshalCSRF(decrypted, opts, t)
|
|
}
|
|
|
|
// unmarshalCSRF unmarshals decrypted data into a CSRF struct
|
|
func unmarshalCSRF(decrypted []byte, opts *options.Cookie, csrfTime time.Time) (*csrf, error) {
|
|
clock := clock.Clock{}
|
|
clock.Set(csrfTime)
|
|
|
|
csrf := &csrf{cookieOpts: opts, time: clock}
|
|
if err := msgpack.Unmarshal(decrypted, csrf); err != nil {
|
|
return nil, fmt.Errorf("error unmarshalling data to CSRF: %v", err)
|
|
}
|
|
return csrf, nil
|
|
}
|
|
|
|
// cookieName returns the CSRF cookie's name
|
|
func (c *csrf) cookieName() string {
|
|
stateSubstring := ""
|
|
if c.cookieOpts.CSRFPerRequest {
|
|
stateSubstring = encryption.HashNonce(c.OAuthState)[0 : csrfStateLength-1]
|
|
}
|
|
return csrfCookieName(c.cookieOpts, stateSubstring)
|
|
}
|
|
|
|
func csrfCookieName(opts *options.Cookie, stateSubstring string) string {
|
|
if stateSubstring == "" {
|
|
return fmt.Sprintf("%v_csrf", opts.Name)
|
|
}
|
|
return fmt.Sprintf("%v_%v_csrf", opts.Name, stateSubstring)
|
|
}
|
|
|
|
// ExtractStateSubstring extract the initial state characters, to add it to the CSRF cookie name
|
|
func ExtractStateSubstring(state string) string {
|
|
lastChar := csrfStateLength - 1
|
|
stateSubstring := ""
|
|
if lastChar <= len(state) {
|
|
stateSubstring = state[0:lastChar]
|
|
}
|
|
return stateSubstring
|
|
}
|
|
|
|
func encrypt(data []byte, opts *options.Cookie) ([]byte, error) {
|
|
cipher, err := makeCipher(opts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return cipher.Encrypt(data)
|
|
}
|
|
|
|
func decrypt(data []byte, opts *options.Cookie) ([]byte, error) {
|
|
cipher, err := makeCipher(opts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return cipher.Decrypt(data)
|
|
}
|
|
|
|
func makeCipher(opts *options.Cookie) (encryption.Cipher, error) {
|
|
secret, err := opts.GetSecret()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error getting cookie secret: %v", err)
|
|
}
|
|
return encryption.NewCFBCipher(encryption.SecretBytes(secret))
|
|
}
|