1
0
mirror of https://github.com/labstack/echo.git synced 2024-12-22 20:06:21 +02:00
echo/middleware/csrf.go
toimtoimtoim 6ef5f77bf2 WIP: logger examples
WIP: make default logger implemented custom writer for jsonlike logs
WIP: improve examples
WIP: defaultErrorHandler use errors.As to unwrap errors. Update readme
WIP: default logger logs json, restore e.Start method
WIP: clean router.Match a bit
WIP: func types/fields have echo.Context has first element
WIP: remove yaml tags as functions etc can not be serialized anyway
WIP: change BindPathParams,BindQueryParams,BindHeaders from methods to functions and reverse arguments to be like DefaultBinder.Bind is
WIP: improved comments, logger now extracts status from error
WIP: go mod tidy
WIP: rebase with 4.5.0
WIP:
* removed todos.
* removed StartAutoTLS and StartH2CServer methods from `StartConfig`
* KeyAuth middleware errorhandler can swallow the error and resume next middleware
WIP: add RouterConfig.UseEscapedPathForMatching to use escaped path for matching request against routes
WIP: FIXMEs
WIP: upgrade golang-jwt/jwt to `v4`
WIP: refactor http methods to return RouteInfo
WIP: refactor static not creating multiple routes
WIP: refactor route and middleware adding functions not to return error directly
WIP: Use 401 for problematic/missing headers for key auth and JWT middleware (#1552, #1402).
> In summary, a 401 Unauthorized response should be used for missing or bad authentication
WIP: replace `HTTPError.SetInternal` with `HTTPError.WithInternal` so we could not mutate global error variables
WIP: add RouteInfo and RouteMatchType into Context what we could know from in middleware what route was matched and/or type of that match (200/404/405)
WIP: make notFoundHandler and methodNotAllowedHandler private. encourage that all errors be handled in Echo.HTTPErrorHandler
WIP: server cleanup ideas
WIP: routable.ForGroup
WIP: note about logger middleware
WIP: bind should not default values on second try. use crypto rand for better randomness
WIP: router add route as interface and returns info as interface
WIP: improve flaky test (remains still flaky)
WIP: add notes about bind default values
WIP: every route can have their own path params names
WIP: routerCreator and different tests
WIP: different things
WIP: remove route implementation
WIP: support custom method types
WIP: extractor tests
WIP: v5.0.x proposal
over v4.4.0
2021-10-02 18:36:42 +03:00

226 lines
6.2 KiB
Go

package middleware
import (
"crypto/subtle"
"errors"
"net/http"
"strings"
"time"
"github.com/labstack/echo/v4"
)
// CSRFConfig defines the config for CSRF middleware.
// Fields documented as optional are filled in by ToMiddleware when left at
// their zero value.
type CSRFConfig struct {
// Skipper defines a function to skip middleware.
// Optional. Default value DefaultSkipper.
Skipper Skipper
// TokenLength is the length of the generated token.
// Optional. Default value 32.
TokenLength uint8
// Generator defines a function to generate token.
// Optional. Defaults to createRandomStringGenerator(TokenLength).
Generator func() string
// TokenLookup is a string in the form of "<source>:<key>" that is used
// to extract token from the request.
// Optional. Default value "header:X-CSRF-Token".
// Possible values:
// - "header:<name>"
// - "form:<name>"
// - "query:<name>"
// Any <source> other than "form" or "query" is looked up as a request header.
TokenLookup string
// Context key to store generated CSRF token into context.
// Optional. Default value "csrf".
ContextKey string
// Name of the CSRF cookie. This cookie will store CSRF token.
// Optional. Default value "_csrf".
CookieName string
// Domain of the CSRF cookie.
// Optional. Default value none.
CookieDomain string
// Path of the CSRF cookie.
// Optional. Default value none.
CookiePath string
// Max age (in seconds) of the CSRF cookie. It is used to compute the
// cookie's Expires timestamp.
// Optional. Default value 86400 (24hr).
CookieMaxAge int
// Indicates if CSRF cookie is secure.
// Optional. Default value false. Forced to true when CookieSameSite is
// http.SameSiteNoneMode.
CookieSecure bool
// Indicates if CSRF cookie is HTTP only.
// Optional. Default value false.
CookieHTTPOnly bool
// Indicates SameSite mode of the CSRF cookie. The attribute is only set on
// the cookie when the value differs from http.SameSiteDefaultMode.
// Optional. Default value SameSiteDefaultMode.
CookieSameSite http.SameSite
}
// csrfTokenExtractor defines a function that takes `echo.Context` and returns either
// the CSRF token sent by the client or an error when the token can not be extracted.
type csrfTokenExtractor func(echo.Context) (string, error)
// DefaultCSRFConfig is the default CSRF middleware config.
// ToMiddleware reads it to fill in zero-valued fields of a user supplied CSRFConfig.
var DefaultCSRFConfig = CSRFConfig{
Skipper: DefaultSkipper,
TokenLength: 32,
// The token is read from the X-CSRF-Token request header.
TokenLookup: "header:" + echo.HeaderXCSRFToken,
ContextKey: "csrf",
CookieName: "_csrf",
// 86400 seconds = 24 hours.
CookieMaxAge: 86400,
CookieSameSite: http.SameSiteDefaultMode,
}
// CSRF returns a Cross-Site Request Forgery (CSRF) middleware built from
// DefaultCSRFConfig.
// See: https://en.wikipedia.org/wiki/Cross-site_request_forgery
func CSRF() echo.MiddlewareFunc {
	mw := CSRFWithConfig(DefaultCSRFConfig)
	return mw
}
// CSRFWithConfig returns a CSRF middleware built from the given config.
// Construction is delegated to toMiddlewareOrPanic, so an invalid
// configuration results in a panic instead of a returned error.
func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
	mw := toMiddlewareOrPanic(config)
	return mw
}
// ToMiddleware converts CSRFConfig to middleware or returns an error for invalid configuration.
//
// Zero-valued Skipper, TokenLength, Generator, TokenLookup, ContextKey, CookieName and
// CookieMaxAge are replaced with their defaults. An error is returned when TokenLookup
// is not in the form of "<source>:<key>" or when its key part is empty.
//
// The returned middleware reuses the token stored in the CSRF cookie (or generates a
// new one), validates the client supplied token for every request method other than
// GET, HEAD, OPTIONS and TRACE, refreshes the CSRF cookie and stores the token in the
// context under ContextKey before calling the next handler.
func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
	// Defaults
	if config.Skipper == nil {
		config.Skipper = DefaultCSRFConfig.Skipper
	}
	if config.TokenLength == 0 {
		config.TokenLength = DefaultCSRFConfig.TokenLength
	}
	if config.Generator == nil {
		config.Generator = createRandomStringGenerator(config.TokenLength)
	}
	if config.TokenLookup == "" {
		config.TokenLookup = DefaultCSRFConfig.TokenLookup
	}
	if config.ContextKey == "" {
		config.ContextKey = DefaultCSRFConfig.ContextKey
	}
	if config.CookieName == "" {
		config.CookieName = DefaultCSRFConfig.CookieName
	}
	if config.CookieMaxAge == 0 {
		config.CookieMaxAge = DefaultCSRFConfig.CookieMaxAge
	}
	if config.CookieSameSite == http.SameSiteNoneMode {
		// SameSite=None cookies are always sent with the Secure attribute.
		config.CookieSecure = true
	}

	// Initialize
	parts := strings.Split(config.TokenLookup, ":")
	if len(parts) < 2 || parts[1] == "" {
		// Report the misconfiguration instead of panicking on parts[1] below.
		return nil, errors.New("csrf: invalid TokenLookup \"" + config.TokenLookup + "\", expected \"<source>:<key>\"")
	}

	var extractor csrfTokenExtractor
	switch parts[0] {
	case "form":
		extractor = csrfTokenFromForm(parts[1])
	case "query":
		extractor = csrfTokenFromQuery(parts[1])
	default:
		// Every other source, including "header", is looked up as a request header.
		extractor = csrfTokenFromHeader(parts[1])
	}

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			if config.Skipper(c) {
				return next(c)
			}

			// Reuse the token from the CSRF cookie, otherwise generate a new one.
			// A cookie with an empty value is treated as missing: reusing it would
			// make an empty client token compare equal and pass validation.
			var token string
			if k, err := c.Cookie(config.CookieName); err != nil || k.Value == "" {
				token = config.Generator()
			} else {
				token = k.Value
			}

			switch c.Request().Method {
			case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
			default:
				// Validate token only for requests which are not defined as 'safe' by RFC7231
				clientToken, err := extractor(c)
				if err != nil {
					return echo.NewHTTPError(http.StatusBadRequest, err.Error())
				}
				if !validateCSRFToken(token, clientToken) {
					return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")
				}
			}

			// Set CSRF cookie
			cookie := &http.Cookie{
				Name:     config.CookieName,
				Value:    token,
				Path:     config.CookiePath,
				Domain:   config.CookieDomain,
				Expires:  time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second),
				Secure:   config.CookieSecure,
				HttpOnly: config.CookieHTTPOnly,
			}
			// Leave SameSite at its zero value unless a non-default mode is configured.
			if config.CookieSameSite != http.SameSiteDefaultMode {
				cookie.SameSite = config.CookieSameSite
			}
			c.SetCookie(cookie)

			// Store token in the context
			c.Set(config.ContextKey, token)

			// Protect clients from caching the response
			c.Response().Header().Add(echo.HeaderVary, echo.HeaderCookie)

			return next(c)
		}
	}, nil
}
// csrfTokenFromHeader returns a `csrfTokenExtractor` that extracts token from the
// provided request header.
//
// The extractor returns an error when the header is absent or empty, in the same
// way the form and query extractors treat an empty value, so that an empty token
// is never handed to validation.
func csrfTokenFromHeader(header string) csrfTokenExtractor {
	return func(c echo.Context) (string, error) {
		token := c.Request().Header.Get(header)
		if token == "" {
			return "", errors.New("missing csrf token in the request header")
		}
		return token, nil
	}
}
// csrfTokenFromForm returns a `csrfTokenExtractor` that reads the token from the
// form parameter named param. The extractor returns an error when that value
// is empty.
func csrfTokenFromForm(param string) csrfTokenExtractor {
	return func(c echo.Context) (string, error) {
		if value := c.FormValue(param); value != "" {
			return value, nil
		}
		return "", errors.New("missing csrf token in the form parameter")
	}
}
// csrfTokenFromQuery returns a `csrfTokenExtractor` that reads the token from the
// query parameter named param. The extractor returns an error when that value
// is empty.
func csrfTokenFromQuery(param string) csrfTokenExtractor {
	return func(c echo.Context) (string, error) {
		if value := c.QueryParam(param); value != "" {
			return value, nil
		}
		return "", errors.New("missing csrf token in the query string")
	}
}
// validateCSRFToken reports whether clientToken is equal to token. The comparison
// is done with subtle.ConstantTimeCompare rather than a plain string equality check.
func validateCSRFToken(token, clientToken string) bool {
	expected := []byte(token)
	provided := []byte(clientToken)
	matched := subtle.ConstantTimeCompare(expected, provided)
	return matched == 1
}