mirror of
https://github.com/labstack/echo.git
synced 2025-07-05 00:58:47 +02:00
@ -1,12 +1,9 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/hmac"
|
"crypto/subtle"
|
||||||
"crypto/rand"
|
|
||||||
"crypto/sha1"
|
|
||||||
"encoding/hex"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@ -17,8 +14,9 @@ import (
|
|||||||
type (
|
type (
|
||||||
// CSRFConfig defines the config for CSRF middleware.
|
// CSRFConfig defines the config for CSRF middleware.
|
||||||
CSRFConfig struct {
|
CSRFConfig struct {
|
||||||
// Key to create CSRF token.
|
// TokenLength is the length of the generated token.
|
||||||
Secret []byte `json:"secret"`
|
TokenLength uint8 `json:"token_length"`
|
||||||
|
// Optional. Default value 32.
|
||||||
|
|
||||||
// TokenLookup is a string in the form of "<source>:<key>" that is used
|
// TokenLookup is a string in the form of "<source>:<key>" that is used
|
||||||
// to extract token from the request.
|
// to extract token from the request.
|
||||||
@ -52,6 +50,10 @@ type (
|
|||||||
// Indicates if CSRF cookie is secure.
|
// Indicates if CSRF cookie is secure.
|
||||||
// Optional. Default value false.
|
// Optional. Default value false.
|
||||||
CookieSecure bool `json:"cookie_secure"`
|
CookieSecure bool `json:"cookie_secure"`
|
||||||
|
|
||||||
|
// Indicates if CSRF cookie is HTTP only.
|
||||||
|
// Optional. Default value false.
|
||||||
|
CookieHTTPOnly bool `json:"cookie_http_only"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// csrfTokenExtractor defines a function that takes `echo.Context` and returns
|
// csrfTokenExtractor defines a function that takes `echo.Context` and returns
|
||||||
@ -62,6 +64,7 @@ type (
|
|||||||
var (
|
var (
|
||||||
// DefaultCSRFConfig is the default CSRF middleware config.
|
// DefaultCSRFConfig is the default CSRF middleware config.
|
||||||
DefaultCSRFConfig = CSRFConfig{
|
DefaultCSRFConfig = CSRFConfig{
|
||||||
|
TokenLength: 32,
|
||||||
TokenLookup: "header:" + echo.HeaderXCSRFToken,
|
TokenLookup: "header:" + echo.HeaderXCSRFToken,
|
||||||
ContextKey: "csrf",
|
ContextKey: "csrf",
|
||||||
CookieName: "_csrf",
|
CookieName: "_csrf",
|
||||||
@ -71,9 +74,8 @@ var (
|
|||||||
|
|
||||||
// CSRF returns a Cross-Site Request Forgery (CSRF) middleware.
|
// CSRF returns a Cross-Site Request Forgery (CSRF) middleware.
|
||||||
// See: https://en.wikipedia.org/wiki/Cross-site_request_forgery
|
// See: https://en.wikipedia.org/wiki/Cross-site_request_forgery
|
||||||
func CSRF(secret []byte) echo.MiddlewareFunc {
|
func CSRF() echo.MiddlewareFunc {
|
||||||
c := DefaultCSRFConfig
|
c := DefaultCSRFConfig
|
||||||
c.Secret = secret
|
|
||||||
return CSRFWithConfig(c)
|
return CSRFWithConfig(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -81,8 +83,8 @@ func CSRF(secret []byte) echo.MiddlewareFunc {
|
|||||||
// See `CSRF()`.
|
// See `CSRF()`.
|
||||||
func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
|
func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
|
||||||
// Defaults
|
// Defaults
|
||||||
if config.Secret == nil {
|
if config.TokenLength == 0 {
|
||||||
panic("csrf secret must be provided")
|
config.TokenLength = DefaultCSRFConfig.TokenLength
|
||||||
}
|
}
|
||||||
if config.TokenLookup == "" {
|
if config.TokenLookup == "" {
|
||||||
config.TokenLookup = DefaultCSRFConfig.TokenLookup
|
config.TokenLookup = DefaultCSRFConfig.TokenLookup
|
||||||
@ -110,51 +112,51 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
|
|||||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
req := c.Request()
|
req := c.Request()
|
||||||
cookie, err := c.Cookie(config.CookieName)
|
k, err := c.Cookie(config.CookieName)
|
||||||
token := ""
|
token := ""
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Token expired, generate it
|
// Generate token
|
||||||
salt, err := generateSalt(8)
|
token = generateCSRFToken(config.TokenLength)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
token = generateCSRFToken(config.Secret, salt)
|
|
||||||
cookie := new(echo.Cookie)
|
|
||||||
cookie.SetName(config.CookieName)
|
|
||||||
cookie.SetValue(token)
|
|
||||||
if config.CookiePath != "" {
|
|
||||||
cookie.SetPath(config.CookiePath)
|
|
||||||
}
|
|
||||||
if config.CookieDomain != "" {
|
|
||||||
cookie.SetDomain(config.CookieDomain)
|
|
||||||
}
|
|
||||||
cookie.SetExpires(time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second))
|
|
||||||
cookie.SetSecure(config.CookieSecure)
|
|
||||||
cookie.SetHTTPOnly(true)
|
|
||||||
c.SetCookie(cookie)
|
|
||||||
} else {
|
} else {
|
||||||
// Reuse token
|
// Reuse token
|
||||||
token = cookie.Value()
|
token = k.Value()
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Set(config.ContextKey, token)
|
|
||||||
|
|
||||||
switch req.Method() {
|
switch req.Method() {
|
||||||
case echo.GET, echo.HEAD, echo.OPTIONS, echo.TRACE:
|
case echo.GET, echo.HEAD, echo.OPTIONS, echo.TRACE:
|
||||||
default:
|
default:
|
||||||
|
// Validate token only for requests which are not defined as 'safe' by RFC7231
|
||||||
clientToken, err := extractor(c)
|
clientToken, err := extractor(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
ok, err := validateCSRFToken(token, clientToken, config.Secret)
|
if !validateCSRFToken(token, clientToken) {
|
||||||
if err != nil {
|
return echo.NewHTTPError(http.StatusForbidden, "csrf token is invalid")
|
||||||
return err
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set CSRF cookie
|
||||||
|
cookie := new(echo.Cookie)
|
||||||
|
cookie.SetName(config.CookieName)
|
||||||
|
cookie.SetValue(token)
|
||||||
|
if config.CookiePath != "" {
|
||||||
|
cookie.SetPath(config.CookiePath)
|
||||||
|
}
|
||||||
|
if config.CookieDomain != "" {
|
||||||
|
cookie.SetDomain(config.CookieDomain)
|
||||||
|
}
|
||||||
|
cookie.SetExpires(time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second))
|
||||||
|
cookie.SetSecure(config.CookieSecure)
|
||||||
|
cookie.SetHTTPOnly(config.CookieHTTPOnly)
|
||||||
|
c.SetCookie(cookie)
|
||||||
|
|
||||||
|
// Store token in the context
|
||||||
|
c.Set(config.ContextKey, token)
|
||||||
|
|
||||||
|
// Protect clients from caching the response
|
||||||
|
c.Response().Header().Add(echo.HeaderVary, echo.HeaderCookie)
|
||||||
|
|
||||||
return next(c)
|
return next(c)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -192,29 +194,16 @@ func csrfTokenFromQuery(param string) csrfTokenExtractor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateCSRFToken(secret, salt []byte) string {
|
func generateCSRFToken(n uint8) string {
|
||||||
h := hmac.New(sha1.New, secret)
|
// TODO: From utility library
|
||||||
h.Write(salt)
|
chars := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||||
return fmt.Sprintf("%s:%s", hex.EncodeToString(h.Sum(nil)), hex.EncodeToString(salt))
|
b := make([]byte, n)
|
||||||
|
for i := range b {
|
||||||
|
b[i] = chars[rand.Int63()%int64(len(chars))]
|
||||||
|
}
|
||||||
|
return string(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateCSRFToken(serverToken, clientToken string, secret []byte) (bool, error) {
|
func validateCSRFToken(token, clientToken string) bool {
|
||||||
if serverToken != clientToken {
|
return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
sep := strings.Index(clientToken, ":")
|
|
||||||
if sep < 0 {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
salt, err := hex.DecodeString(clientToken[sep+1:])
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
return clientToken == generateCSRFToken(secret, salt), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateSalt(len uint8) (salt []byte, err error) {
|
|
||||||
salt = make([]byte, len)
|
|
||||||
_, err = rand.Read(salt)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
@ -17,17 +17,12 @@ func TestCSRF(t *testing.T) {
|
|||||||
rec := test.NewResponseRecorder()
|
rec := test.NewResponseRecorder()
|
||||||
c := e.NewContext(req, rec)
|
c := e.NewContext(req, rec)
|
||||||
csrf := CSRFWithConfig(CSRFConfig{
|
csrf := CSRFWithConfig(CSRFConfig{
|
||||||
Secret: []byte("secret"),
|
TokenLength: 16,
|
||||||
})
|
})
|
||||||
h := csrf(func(c echo.Context) error {
|
h := csrf(func(c echo.Context) error {
|
||||||
return c.String(http.StatusOK, "test")
|
return c.String(http.StatusOK, "test")
|
||||||
})
|
})
|
||||||
|
|
||||||
// No secret
|
|
||||||
assert.Panics(t, func() {
|
|
||||||
CSRF(nil)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Generate CSRF token
|
// Generate CSRF token
|
||||||
h(c)
|
h(c)
|
||||||
assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "_csrf")
|
assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "_csrf")
|
||||||
@ -46,8 +41,7 @@ func TestCSRF(t *testing.T) {
|
|||||||
assert.Error(t, h(c))
|
assert.Error(t, h(c))
|
||||||
|
|
||||||
// Valid CSRF token
|
// Valid CSRF token
|
||||||
salt, _ := generateSalt(8)
|
token := generateCSRFToken(16)
|
||||||
token := generateCSRFToken([]byte("secret"), salt)
|
|
||||||
req.Header().Set(echo.HeaderCookie, "_csrf="+token)
|
req.Header().Set(echo.HeaderCookie, "_csrf="+token)
|
||||||
req.Header().Set(echo.HeaderXCSRFToken, token)
|
req.Header().Set(echo.HeaderXCSRFToken, token)
|
||||||
if assert.NoError(t, h(c)) {
|
if assert.NoError(t, h(c)) {
|
||||||
|
Reference in New Issue
Block a user