1
0
mirror of https://github.com/labstack/echo.git synced 2024-12-24 20:14:31 +02:00

Extractor for csrf token

Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
Vishal Rana 2016-05-13 08:18:00 -07:00
parent 1aa22ce09b
commit 7d1819e5b1
2 changed files with 65 additions and 10 deletions

View File

@ -5,6 +5,7 @@ import (
"crypto/rand"
"crypto/sha1"
"encoding/hex"
"errors"
"fmt"
"net/http"
"strings"
@ -19,8 +20,13 @@ type (
// Key to create CSRF token.
Secret []byte
// Name of the request header to extract CSRF token.
HeaderName string
// Context key to store generated CSRF token into context.
// Optional. Default value "csrf".
ContextKey string
// Extractor is a function that extracts token from the request.
// Optional. Default value CSRFTokenFromHeader(echo.HeaderXCSRFToken).
Extractor CSRFTokenExtractor
// Name of the CSRF cookie. This cookie will store CSRF token.
// Optional. Default value "csrf".
@ -46,12 +52,17 @@ type (
// Optional. Default value false.
CookieHTTPOnly bool
}
// CSRFTokenExtractor defines a function that takes `echo.Context` and returns
// either a token or an error.
CSRFTokenExtractor func(echo.Context) (string, error)
)
var (
// DefaultCSRFConfig is the default CSRF middleware config.
DefaultCSRFConfig = CSRFConfig{
HeaderName: echo.HeaderXCSRFToken,
ContextKey: "csrf",
Extractor: CSRFTokenFromHeader(echo.HeaderXCSRFToken),
CookieName: "csrf",
CookieExpires: time.Now().Add(24 * time.Hour),
}
@ -70,10 +81,13 @@ func CSRF(secret []byte) echo.MiddlewareFunc {
func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
// Defaults
if config.Secret == nil {
panic("csrf: secret must be provided")
panic("csrf secret must be provided")
}
if config.HeaderName == "" {
config.HeaderName = DefaultCSRFConfig.HeaderName
if config.ContextKey == "" {
config.ContextKey = DefaultCSRFConfig.ContextKey
}
if config.Extractor == nil {
config.Extractor = DefaultCSRFConfig.Extractor
}
if config.CookieName == "" {
config.CookieName = DefaultCSRFConfig.CookieName
@ -92,6 +106,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
return err
}
token := generateCSRFToken(config.Secret, salt)
c.Set(config.ContextKey, token)
cookie := new(echo.Cookie)
cookie.SetName(config.CookieName)
cookie.SetValue(token)
@ -109,13 +124,16 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
switch req.Method() {
case echo.GET, echo.HEAD, echo.OPTIONS, echo.TRACE:
default:
token := req.Header().Get(config.HeaderName)
token, err := config.Extractor(c)
if err != nil {
return err
}
ok, err := validateCSRFToken(token, config.Secret)
if err != nil {
return err
}
if !ok {
return echo.NewHTTPError(http.StatusForbidden, "csrf: invalid token")
return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")
}
}
return next(c)
@ -123,6 +141,38 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
}
}
// CSRFTokenFromHeader returns a `CSRFTokenExtractor` that extracts token from the
// provided request header.
func CSRFTokenFromHeader(header string) CSRFTokenExtractor {
return func(c echo.Context) (string, error) {
return c.Request().Header().Get(header), nil
}
}
// CSRFTokenFromForm returns a `CSRFTokenExtractor` that extracts token from the
// provided form parameter.
func CSRFTokenFromForm(param string) CSRFTokenExtractor {
return func(c echo.Context) (string, error) {
token := c.FormValue(param)
if token == "" {
return "", errors.New("empty csrf token in form param")
}
return token, nil
}
}
// CSRFTokenFromQuery returns a `CSRFTokenExtractor` that extracts token from the
// provided query parameter.
func CSRFTokenFromQuery(param string) CSRFTokenExtractor {
return func(c echo.Context) (string, error) {
token := c.QueryParam(param)
if token == "" {
return "", errors.New("empty csrf token in query param")
}
return token, nil
}
}
func generateCSRFToken(secret, salt []byte) string {
h := hmac.New(sha1.New, secret)
h.Write(salt)

View File

@ -1,6 +1,7 @@
package middleware
import (
"errors"
"fmt"
"net/http"
@ -113,13 +114,17 @@ func JWTFromHeader(c echo.Context) (string, error) {
if len(auth) > l+1 && auth[:l] == bearer {
return auth[l+1:], nil
}
return "", echo.NewHTTPError(http.StatusBadRequest, "empty or invalid authorization header="+auth)
return "", errors.New("empty or invalid jwt in authorization header")
}
// JWTFromQuery returns a `JWTExtractor` that extracts token from the provided query
// parameter.
func JWTFromQuery(param string) JWTExtractor {
return func(c echo.Context) (string, error) {
return c.QueryParam(param), nil
token := c.QueryParam(param)
if token == "" {
return "", errors.New("empty jwt in query param")
}
return token, nil
}
}