mirror of
https://github.com/labstack/echo.git
synced 2024-12-24 20:14:31 +02:00
Extractor for csrf token
Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
parent
1aa22ce09b
commit
7d1819e5b1
@ -5,6 +5,7 @@ import (
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
@ -19,8 +20,13 @@ type (
|
||||
// Key to create CSRF token.
|
||||
Secret []byte
|
||||
|
||||
// Name of the request header to extract CSRF token.
|
||||
HeaderName string
|
||||
// Context key to store generated CSRF token into context.
|
||||
// Optional. Default value "csrf".
|
||||
ContextKey string
|
||||
|
||||
// Extractor is a function that extracts token from the request.
|
||||
// Optional. Default value CSRFTokenFromHeader(echo.HeaderXCSRFToken).
|
||||
Extractor CSRFTokenExtractor
|
||||
|
||||
// Name of the CSRF cookie. This cookie will store CSRF token.
|
||||
// Optional. Default value "csrf".
|
||||
@ -46,12 +52,17 @@ type (
|
||||
// Optional. Default value false.
|
||||
CookieHTTPOnly bool
|
||||
}
|
||||
|
||||
// CSRFTokenExtractor defines a function that takes `echo.Context` and returns
|
||||
// either a token or an error.
|
||||
CSRFTokenExtractor func(echo.Context) (string, error)
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultCSRFConfig is the default CSRF middleware config.
|
||||
DefaultCSRFConfig = CSRFConfig{
|
||||
HeaderName: echo.HeaderXCSRFToken,
|
||||
ContextKey: "csrf",
|
||||
Extractor: CSRFTokenFromHeader(echo.HeaderXCSRFToken),
|
||||
CookieName: "csrf",
|
||||
CookieExpires: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
@ -70,10 +81,13 @@ func CSRF(secret []byte) echo.MiddlewareFunc {
|
||||
func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
|
||||
// Defaults
|
||||
if config.Secret == nil {
|
||||
panic("csrf: secret must be provided")
|
||||
panic("csrf secret must be provided")
|
||||
}
|
||||
if config.HeaderName == "" {
|
||||
config.HeaderName = DefaultCSRFConfig.HeaderName
|
||||
if config.ContextKey == "" {
|
||||
config.ContextKey = DefaultCSRFConfig.ContextKey
|
||||
}
|
||||
if config.Extractor == nil {
|
||||
config.Extractor = DefaultCSRFConfig.Extractor
|
||||
}
|
||||
if config.CookieName == "" {
|
||||
config.CookieName = DefaultCSRFConfig.CookieName
|
||||
@ -92,6 +106,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
|
||||
return err
|
||||
}
|
||||
token := generateCSRFToken(config.Secret, salt)
|
||||
c.Set(config.ContextKey, token)
|
||||
cookie := new(echo.Cookie)
|
||||
cookie.SetName(config.CookieName)
|
||||
cookie.SetValue(token)
|
||||
@ -109,13 +124,16 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
|
||||
switch req.Method() {
|
||||
case echo.GET, echo.HEAD, echo.OPTIONS, echo.TRACE:
|
||||
default:
|
||||
token := req.Header().Get(config.HeaderName)
|
||||
token, err := config.Extractor(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ok, err := validateCSRFToken(token, config.Secret)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusForbidden, "csrf: invalid token")
|
||||
return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")
|
||||
}
|
||||
}
|
||||
return next(c)
|
||||
@ -123,6 +141,38 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// CSRFTokenFromHeader returns a `CSRFTokenExtractor` that extracts token from the
|
||||
// provided request header.
|
||||
func CSRFTokenFromHeader(header string) CSRFTokenExtractor {
|
||||
return func(c echo.Context) (string, error) {
|
||||
return c.Request().Header().Get(header), nil
|
||||
}
|
||||
}
|
||||
|
||||
// CSRFTokenFromForm returns a `CSRFTokenExtractor` that extracts token from the
|
||||
// provided form parameter.
|
||||
func CSRFTokenFromForm(param string) CSRFTokenExtractor {
|
||||
return func(c echo.Context) (string, error) {
|
||||
token := c.FormValue(param)
|
||||
if token == "" {
|
||||
return "", errors.New("empty csrf token in form param")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
// CSRFTokenFromQuery returns a `CSRFTokenExtractor` that extracts token from the
|
||||
// provided query parameter.
|
||||
func CSRFTokenFromQuery(param string) CSRFTokenExtractor {
|
||||
return func(c echo.Context) (string, error) {
|
||||
token := c.QueryParam(param)
|
||||
if token == "" {
|
||||
return "", errors.New("empty csrf token in query param")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
func generateCSRFToken(secret, salt []byte) string {
|
||||
h := hmac.New(sha1.New, secret)
|
||||
h.Write(salt)
|
||||
|
@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
@ -113,13 +114,17 @@ func JWTFromHeader(c echo.Context) (string, error) {
|
||||
if len(auth) > l+1 && auth[:l] == bearer {
|
||||
return auth[l+1:], nil
|
||||
}
|
||||
return "", echo.NewHTTPError(http.StatusBadRequest, "empty or invalid authorization header="+auth)
|
||||
return "", errors.New("empty or invalid jwt in authorization header")
|
||||
}
|
||||
|
||||
// JWTFromQuery returns a `JWTExtractor` that extracts token from the provided query
|
||||
// parameter.
|
||||
func JWTFromQuery(param string) JWTExtractor {
|
||||
return func(c echo.Context) (string, error) {
|
||||
return c.QueryParam(param), nil
|
||||
token := c.QueryParam(param)
|
||||
if token == "" {
|
||||
return "", errors.New("empty jwt in query param")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user