diff --git a/pkg/cookies/csrf.go b/pkg/cookies/csrf.go index 0af74173..7ad0c2c9 100644 --- a/pkg/cookies/csrf.go +++ b/pkg/cookies/csrf.go @@ -1,12 +1,10 @@ package cookies import ( - "errors" "fmt" "net/http" "time" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption" @@ -38,12 +36,13 @@ type csrf struct { // is used to mitigate replay attacks. OIDCNonce []byte `msgpack:"n,omitempty"` - cookieOpts *options.Cookie - time clock.Clock + builder Builder + encryptionSecret string + time clock.Clock } // NewCSRF creates a CSRF with random nonces -func NewCSRF(opts *options.Cookie) (CSRF, error) { +func NewCSRF(builder Builder, secret string) (CSRF, error) { state, err := encryption.Nonce() if err != nil { return nil, err @@ -57,18 +56,19 @@ func NewCSRF(opts *options.Cookie) (CSRF, error) { OAuthState: state, OIDCNonce: nonce, - cookieOpts: opts, + builder: builder, + encryptionSecret: secret, }, nil } // LoadCSRFCookie loads a CSRF object from a request's CSRF cookie -func LoadCSRFCookie(req *http.Request, opts *options.Cookie) (CSRF, error) { - cookie, err := req.Cookie(csrfCookieName(opts)) +func LoadCSRFCookie(req *http.Request, builder Builder, secret string) (CSRF, error) { + cookieValue, err := builder.ValidateRequest(req) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to validate CSRF cookie value: %v", err) } - return decodeCSRFCookie(cookie, opts) + return decodeCSRFCookie(cookieValue, builder, secret) } // HashOAuthState returns the hash of the OAuth state nonce @@ -104,29 +104,27 @@ func (c *csrf) SetCookie(rw http.ResponseWriter, req *http.Request) (*http.Cooki return nil, err } - cookie := MakeCookieFromOptions( - req, - c.cookieName(), - encoded, - c.cookieOpts, - c.cookieOpts.Expire, - c.time.Now(), - ) + cookie, err := c.builder. + WithStart(c.time.Now()). + MakeCookie(req, encoded) + if err != nil { + return nil, err + } http.SetCookie(rw, cookie) return cookie, nil } // ClearCookie removes the CSRF cookie -func (c *csrf) ClearCookie(rw http.ResponseWriter, req *http.Request) { - http.SetCookie(rw, MakeCookieFromOptions( - req, - c.cookieName(), - "", - c.cookieOpts, - time.Hour*-1, - c.time.Now(), - )) +func (c *csrf) ClearCookie(rw http.ResponseWriter, req *http.Request) error { + cookie, err := c.builder. + WithExpiration(time.Hour*-1). + WithStart(c.time.Now()). + MakeCookie(req, "") + if err != nil { + return fmt.Errorf("could not create cookie: %v", err) + } + http.SetCookie(rw, cookie) } // encodeCookie MessagePack encodes and encrypts the CSRF and then creates a @@ -142,58 +140,42 @@ func (c *csrf) encodeCookie() (string, error) { return "", err } - return encryption.SignedValue(c.cookieOpts.Secret, c.cookieName(), encrypted, c.time.Now()) + return string(encrypted), nil } // decodeCSRFCookie validates the signature then decrypts and decodes a CSRF // cookie into a CSRF struct -func decodeCSRFCookie(cookie *http.Cookie, opts *options.Cookie) (*csrf, error) { - val, _, ok := encryption.Validate(cookie, opts.Secret, opts.Expire) - if !ok { - return nil, errors.New("CSRF cookie failed validation") - } - - decrypted, err := decrypt(val, opts) +func decodeCSRFCookie(cookieValue string, builder Builder, secret string) (*csrf, error) { + decrypted, err := decrypt([]byte(cookieValue), secret) if err != nil { return nil, err } // Valid cookie, Unmarshal the CSRF - csrf := &csrf{cookieOpts: opts} - err = msgpack.Unmarshal(decrypted, csrf) - if err != nil { + csrf := &csrf{builder: builder, encryptionSecret: secret} + if err := msgpack.Unmarshal(decrypted, csrf); err != nil { return nil, fmt.Errorf("error unmarshalling data to CSRF: %v", err) } return csrf, nil } -// cookieName returns the CSRF cookie's name derived from the base -// session cookie name -func (c *csrf) cookieName() string { - return csrfCookieName(c.cookieOpts) -} - -func csrfCookieName(opts *options.Cookie) string { - return fmt.Sprintf("%v_csrf", opts.Name) -} - -func encrypt(data []byte, opts *options.Cookie) ([]byte, error) { - cipher, err := makeCipher(opts) +func encrypt(data []byte, secret string) ([]byte, error) { + cipher, err := makeCipher(secret) if err != nil { return nil, err } return cipher.Encrypt(data) } -func decrypt(data []byte, opts *options.Cookie) ([]byte, error) { - cipher, err := makeCipher(opts) +func decrypt(data []byte, secret string) ([]byte, error) { + cipher, err := makeCipher(secret) if err != nil { return nil, err } return cipher.Decrypt(data) } -func makeCipher(opts *options.Cookie) (encryption.Cipher, error) { - return encryption.NewCFBCipher(encryption.SecretBytes(opts.Secret)) +func makeCipher(secret string) (encryption.Cipher, error) { + return encryption.NewCFBCipher(encryption.SecretBytes(secret)) }