1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-03-11 14:59:46 +02:00
Nuno Miguel Micaelo Borges a1ff878fdc
Add flags to define CSRF cookie expiration time and to allow CSRF cookies per request ()
* Add start of state to CSRF cookie name

* Update CHANGELOG.md

* Update CHANGELOG.md

* Support optional flags

* Update CHANGELOG.md

* Update CHANGELOG.md

* Update CHANGELOG.md

* Update overview.md

Add new CSRF flags

* Update overview.md

Describe new CSRF flags

Co-authored-by: Nuno Borges <Nuno.Borges@ctw.bmwgroup.com>
2022-08-31 23:27:56 +01:00

250 lines
6.6 KiB
Go

package cookies
import (
"errors"
"fmt"
"net/http"
"time"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
"github.com/vmihailenco/msgpack/v4"
)
// CSRF manages the values carried in the CSRF cookie while an authentication
// flow is in progress: the OAuth2 state nonce, the OIDC nonce and the PKCE
// code verifier, plus writing and clearing the cookie itself.
type CSRF interface {
	// HashOAuthState returns the hash of the OAuth2 state nonce.
	HashOAuthState() string
	// CheckOAuthState compares the OAuth2 state nonce against a potential
	// hash of it.
	CheckOAuthState(hashed string) bool

	// HashOIDCNonce returns the hash of the OIDC nonce.
	HashOIDCNonce() string
	// CheckOIDCNonce compares the OIDC nonce against a potential hash of it.
	CheckOIDCNonce(hashed string) bool
	// SetSessionNonce stores the OIDC nonce on the given session.
	SetSessionNonce(s *sessions.SessionState)

	// GetCodeVerifier returns the PKCE code verifier.
	GetCodeVerifier() string

	// SetCookie writes the encoded CSRF cookie to rw and returns it.
	SetCookie(rw http.ResponseWriter, req *http.Request) (*http.Cookie, error)
	// ClearCookie overwrites the CSRF cookie on rw with an expired, empty one.
	ClearCookie(rw http.ResponseWriter, req *http.Request)
}
// csrf is the concrete CSRF implementation. Its exported, msgpack-tagged
// fields are what gets serialized into the CSRF cookie (see encodeCookie);
// the unexported fields are runtime-only configuration.
type csrf struct {
	// OAuthState holds the OAuth2 state parameter's nonce component set in the
	// initial authentication request and mirrored back in the callback
	// redirect from the IdP for CSRF protection.
	OAuthState []byte `msgpack:"s,omitempty"`

	// OIDCNonce holds the OIDC nonce parameter used in the initial authentication
	// and then set in all subsequent OIDC ID Tokens as the nonce claim. This
	// is used to mitigate replay attacks.
	OIDCNonce []byte `msgpack:"n,omitempty"`

	// CodeVerifier holds the unobfuscated PKCE code verification string
	// which is used to compare the code challenge when exchanging the
	// authentication code.
	CodeVerifier string `msgpack:"cv,omitempty"`

	// cookieOpts supplies the cookie name, secret, expirations and the
	// CSRFPerRequest switch used when naming, signing and encrypting the cookie.
	cookieOpts *options.Cookie
	// time is the clock used for signing timestamps and cookie expiry.
	// NewCSRF and decodeCSRFCookie leave it at its zero value.
	// NOTE(review): presumably the zero clock.Clock reports real time — confirm.
	time clock.Clock
}
// csrfStateLength is one more than the number of leading state characters
// appended to the CSRF cookie name when CSRFPerRequest is enabled: both
// cookieName (hashed OAuth state) and ExtractStateSubstring (request "state"
// query parameter) take the first csrfStateLength-1 characters (currently 8).
// NOTE(review): this presumes the callback's state parameter starts with the
// hashed OAuth state so both prefixes match — confirm against the caller.
const csrfStateLength int = 9
// NewCSRF builds a CSRF holding a fresh 32-byte OAuth state nonce and a fresh
// 32-byte OIDC nonce, along with the supplied PKCE code verifier and cookie
// options. Any error from nonce generation is returned unchanged.
func NewCSRF(opts *options.Cookie, codeVerifier string) (CSRF, error) {
	c := &csrf{
		CodeVerifier: codeVerifier,
		cookieOpts:   opts,
	}

	var err error
	if c.OAuthState, err = encryption.Nonce(32); err != nil {
		return nil, err
	}
	if c.OIDCNonce, err = encryption.Nonce(32); err != nil {
		return nil, err
	}
	return c, nil
}
// LoadCSRFCookie looks up the CSRF cookie on req (its name is chosen by
// GenerateCookieName) and decodes it into a CSRF.
//
// Errors from req.Cookie (e.g. http.ErrNoCookie) are returned unwrapped so
// callers can compare against them directly. On any error the returned CSRF
// is a true nil interface.
func LoadCSRFCookie(req *http.Request, opts *options.Cookie) (CSRF, error) {
	cookie, err := req.Cookie(GenerateCookieName(req, opts))
	if err != nil {
		return nil, err
	}

	c, err := decodeCSRFCookie(cookie, opts)
	if err != nil {
		// Return a literal nil. Returning the (*csrf)(nil) from
		// decodeCSRFCookie directly would yield a non-nil CSRF interface
		// that wraps a nil pointer.
		return nil, err
	}
	return c, nil
}
// GenerateCookieName returns the CSRF cookie name to read for req. With
// CSRFPerRequest disabled it is the fixed "<Name>_csrf"; otherwise a prefix of
// the request's state parameter is appended so that parallel login flows each
// get their own CSRF cookie.
func GenerateCookieName(req *http.Request, opts *options.Cookie) string {
	if !opts.CSRFPerRequest {
		return csrfCookieName(opts, "")
	}
	return csrfCookieName(opts, ExtractStateSubstring(req))
}
// GetCodeVerifier returns the PKCE code verifier stored in the CSRF.
func (c *csrf) GetCodeVerifier() string {
	verifier := c.CodeVerifier
	return verifier
}
// HashOAuthState returns encryption.HashNonce applied to the OAuth state
// nonce.
func (c *csrf) HashOAuthState() string {
	hashed := encryption.HashNonce(c.OAuthState)
	return hashed
}
// HashOIDCNonce returns encryption.HashNonce applied to the OIDC nonce.
func (c *csrf) HashOIDCNonce() string {
	hashed := encryption.HashNonce(c.OIDCNonce)
	return hashed
}
// CheckOAuthState reports, via encryption.CheckNonce, whether candidate is a
// hash of the OAuth state nonce.
func (c *csrf) CheckOAuthState(candidate string) bool {
	matches := encryption.CheckNonce(c.OAuthState, candidate)
	return matches
}
// CheckOIDCNonce reports, via encryption.CheckNonce, whether candidate is a
// hash of the OIDC nonce.
func (c *csrf) CheckOIDCNonce(candidate string) bool {
	matches := encryption.CheckNonce(c.OIDCNonce, candidate)
	return matches
}
// SetSessionNonce copies the OIDC nonce onto the session's Nonce field.
// The slice header is assigned, so both share the same backing array.
func (c *csrf) SetSessionNonce(s *sessions.SessionState) {
	nonce := c.OIDCNonce
	s.Nonce = nonce
}
// SetCookie encodes, encrypts and signs the CSRF, then writes it to rw as a
// cookie built from the cookie options with the CSRFExpire lifetime. The
// cookie that was set is returned; encoding errors abort before anything is
// written.
func (c *csrf) SetCookie(rw http.ResponseWriter, req *http.Request) (*http.Cookie, error) {
	value, err := c.encodeCookie()
	if err != nil {
		return nil, err
	}

	name := c.cookieName()
	issuedAt := c.time.Now()
	cookie := MakeCookieFromOptions(req, name, value, c.cookieOpts, c.cookieOpts.CSRFExpire, issuedAt)

	http.SetCookie(rw, cookie)
	return cookie, nil
}
// ClearCookie removes the CSRF cookie by writing one with the same name, an
// empty value and a negative (-1h) expiration to rw.
func (c *csrf) ClearCookie(rw http.ResponseWriter, req *http.Request) {
	const alreadyExpired = -time.Hour
	cleared := MakeCookieFromOptions(req, c.cookieName(), "", c.cookieOpts, alreadyExpired, c.time.Now())
	http.SetCookie(rw, cleared)
}
// encodeCookie produces the CSRF cookie value in three steps:
//   - MessagePack-encode c according to its msgpack struct tags;
//   - encrypt the packed bytes using the cookie options;
//   - sign the ciphertext together with the cookie name and the current time.
//
// The signature covers the cookie name, so a value is only valid under the
// name cookieName() returns.
func (c *csrf) encodeCookie() (string, error) {
	packed, err := msgpack.Marshal(c)
	if err != nil {
		// %w (not %v) keeps the underlying msgpack error available to
		// errors.Is / errors.As; the message text is unchanged.
		return "", fmt.Errorf("error marshalling CSRF to msgpack: %w", err)
	}

	encrypted, err := encrypt(packed, c.cookieOpts)
	if err != nil {
		return "", err
	}

	return encryption.SignedValue(c.cookieOpts.Secret, c.cookieName(), encrypted, c.time.Now())
}
// decodeCSRFCookie validates the cookie's signature and timestamp, then
// decrypts and MessagePack-decodes its value into a csrf whose cookie options
// are opts.
//
// The timestamp is checked against CSRFExpire, the lifetime SetCookie gives
// the CSRF cookie. Validating against the session lifetime (Expire) would let
// a captured CSRF cookie be accepted for as long as Expire allows, regardless
// of CSRFExpire. When CSRFExpire is not positive (e.g. options built without
// it), Expire is used, as before.
//
// Returns an error if validation, decryption or unmarshalling fails.
func decodeCSRFCookie(cookie *http.Cookie, opts *options.Cookie) (*csrf, error) {
	maxAge := opts.CSRFExpire
	if maxAge <= 0 {
		maxAge = opts.Expire
	}

	val, _, ok := encryption.Validate(cookie, opts.Secret, maxAge)
	if !ok {
		return nil, errors.New("CSRF cookie failed validation")
	}

	decrypted, err := decrypt(val, opts)
	if err != nil {
		return nil, err
	}

	// Valid cookie: unmarshal into a csrf. The local is named c so it does
	// not shadow the csrf type.
	c := &csrf{cookieOpts: opts}
	if err := msgpack.Unmarshal(decrypted, c); err != nil {
		return nil, fmt.Errorf("error unmarshalling data to CSRF: %w", err)
	}
	return c, nil
}
// cookieName returns the name under which this CSRF is stored. With
// CSRFPerRequest enabled, the first csrfStateLength-1 characters of the
// hashed OAuth state are appended so each flow gets a distinct cookie.
func (c *csrf) cookieName() string {
	if !c.cookieOpts.CSRFPerRequest {
		return csrfCookieName(c.cookieOpts, "")
	}
	prefix := encryption.HashNonce(c.OAuthState)[:csrfStateLength-1]
	return csrfCookieName(c.cookieOpts, prefix)
}
// csrfCookieName builds the CSRF cookie name from the configured cookie name:
// "<Name>_csrf", or "<Name>_csrf_<stateSubstring>" when stateSubstring is
// non-empty.
func csrfCookieName(opts *options.Cookie, stateSubstring string) string {
	name := opts.Name + "_csrf"
	if stateSubstring != "" {
		name += "_" + stateSubstring
	}
	return name
}
// ExtractStateSubstring returns the first csrfStateLength-1 characters of the
// request's "state" query parameter, for use as the per-request suffix of the
// CSRF cookie name. It returns "" when the parameter is missing, empty, or
// shorter than that prefix length.
func ExtractStateSubstring(req *http.Request) string {
	// Query().Get yields "" for a missing key. Indexing Query()["state"][0]
	// panicked with index out of range when a request carried no state
	// parameter.
	state := req.URL.Query().Get("state")

	prefixLen := csrfStateLength - 1
	if len(state) < prefixLen {
		return ""
	}
	return state[:prefixLen]
}
// encrypt encrypts data with the cipher derived from the cookie secret.
func encrypt(data []byte, opts *options.Cookie) ([]byte, error) {
	c, err := makeCipher(opts)
	if err == nil {
		return c.Encrypt(data)
	}
	return nil, err
}
// decrypt decrypts data with the cipher derived from the cookie secret.
func decrypt(data []byte, opts *options.Cookie) ([]byte, error) {
	c, err := makeCipher(opts)
	if err == nil {
		return c.Decrypt(data)
	}
	return nil, err
}
// makeCipher returns the CFB cipher keyed by the bytes of the cookie secret
// (as produced by encryption.SecretBytes).
func makeCipher(opts *options.Cookie) (encryption.Cipher, error) {
	key := encryption.SecretBytes(opts.Secret)
	return encryption.NewCFBCipher(key)
}