2016-05-13 02:45:00 +02:00
|
|
|
package middleware
|
|
|
|
|
|
|
|
import (
|
|
|
|
"crypto/hmac"
|
|
|
|
"crypto/rand"
|
|
|
|
"crypto/sha1"
|
|
|
|
"encoding/hex"
|
2016-05-13 17:18:00 +02:00
|
|
|
"errors"
|
2016-05-13 02:45:00 +02:00
|
|
|
"fmt"
|
|
|
|
"net/http"
|
|
|
|
"strings"
|
|
|
|
"time"
|
|
|
|
|
|
|
|
"github.com/labstack/echo"
|
|
|
|
)
|
|
|
|
|
|
|
|
type (
|
|
|
|
// CSRFConfig defines the config for CSRF middleware.
|
|
|
|
CSRFConfig struct {
|
|
|
|
// Key to create CSRF token.
|
2016-05-19 03:53:54 +02:00
|
|
|
Secret []byte `json:"secret"`
|
2016-05-13 02:45:00 +02:00
|
|
|
|
2016-05-27 04:23:46 +02:00
|
|
|
// TokenLookup is a string in the form of "<source>:<key>" that is used
|
|
|
|
// to extract token from the request.
|
2016-05-26 23:06:30 +02:00
|
|
|
// Optional. Default value "header:X-CSRF-Token".
|
|
|
|
// Possible values:
|
|
|
|
// - "header:<name>"
|
|
|
|
// - "form:<name>"
|
|
|
|
// - "header:<name>"
|
2016-05-27 04:23:46 +02:00
|
|
|
TokenLookup string `json:"token_lookup"`
|
2016-05-26 23:06:30 +02:00
|
|
|
|
2016-05-13 17:18:00 +02:00
|
|
|
// Context key to store generated CSRF token into context.
|
|
|
|
// Optional. Default value "csrf".
|
2016-05-19 03:53:54 +02:00
|
|
|
ContextKey string `json:"context_key"`
|
2016-05-13 17:18:00 +02:00
|
|
|
|
2016-05-13 02:45:00 +02:00
|
|
|
// Name of the CSRF cookie. This cookie will store CSRF token.
|
|
|
|
// Optional. Default value "csrf".
|
2016-05-19 03:53:54 +02:00
|
|
|
CookieName string `json:"cookie_name"`
|
2016-05-13 02:45:00 +02:00
|
|
|
|
|
|
|
// Domain of the CSRF cookie.
|
|
|
|
// Optional. Default value none.
|
2016-05-19 03:53:54 +02:00
|
|
|
CookieDomain string `json:"cookie_domain"`
|
2016-05-13 02:45:00 +02:00
|
|
|
|
2016-05-13 03:14:00 +02:00
|
|
|
// Path of the CSRF cookie.
|
2016-05-13 02:45:00 +02:00
|
|
|
// Optional. Default value none.
|
2016-05-19 03:53:54 +02:00
|
|
|
CookiePath string `json:"cookie_path"`
|
2016-05-13 02:45:00 +02:00
|
|
|
|
2016-05-13 03:14:00 +02:00
|
|
|
// Expiration time of the CSRF cookie.
|
|
|
|
// Optional. Default value 24H.
|
2016-05-19 03:53:54 +02:00
|
|
|
CookieExpires time.Time `json:"cookie_expires"`
|
2016-05-13 02:45:00 +02:00
|
|
|
|
|
|
|
// Indicates if CSRF cookie is secure.
|
2016-05-19 03:53:54 +02:00
|
|
|
CookieSecure bool `json:"cookie_secure"`
|
2016-05-13 03:14:00 +02:00
|
|
|
// Optional. Default value false.
|
2016-05-13 02:45:00 +02:00
|
|
|
|
|
|
|
// Indicates if CSRF cookie is HTTP only.
|
2016-05-13 03:14:00 +02:00
|
|
|
// Optional. Default value false.
|
2016-05-19 03:53:54 +02:00
|
|
|
CookieHTTPOnly bool `json:"cookie_http_only"`
|
2016-05-13 02:45:00 +02:00
|
|
|
}
|
2016-05-13 17:18:00 +02:00
|
|
|
|
2016-05-26 23:06:30 +02:00
|
|
|
// csrfTokenExtractor defines a function that takes `echo.Context` and returns
|
2016-05-13 17:18:00 +02:00
|
|
|
// either a token or an error.
|
2016-05-26 23:06:30 +02:00
|
|
|
csrfTokenExtractor func(echo.Context) (string, error)
|
2016-05-13 02:45:00 +02:00
|
|
|
)
|
|
|
|
|
|
|
|
var (
|
|
|
|
// DefaultCSRFConfig is the default CSRF middleware config.
|
|
|
|
DefaultCSRFConfig = CSRFConfig{
|
2016-05-27 04:23:46 +02:00
|
|
|
TokenLookup: "header:" + echo.HeaderXCSRFToken,
|
2016-05-13 17:18:00 +02:00
|
|
|
ContextKey: "csrf",
|
2016-05-13 02:45:00 +02:00
|
|
|
CookieName: "csrf",
|
|
|
|
CookieExpires: time.Now().Add(24 * time.Hour),
|
|
|
|
}
|
|
|
|
)
|
|
|
|
|
|
|
|
// CSRF returns a Cross-Site Request Forgery (CSRF) middleware.
|
|
|
|
// See: https://en.wikipedia.org/wiki/Cross-site_request_forgery
|
|
|
|
func CSRF(secret []byte) echo.MiddlewareFunc {
|
|
|
|
c := DefaultCSRFConfig
|
|
|
|
c.Secret = secret
|
|
|
|
return CSRFWithConfig(c)
|
|
|
|
}
|
|
|
|
|
|
|
|
// CSRFWithConfig returns a CSRF middleware from config.
|
|
|
|
// See `CSRF()`.
|
|
|
|
func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
|
|
|
|
// Defaults
|
|
|
|
if config.Secret == nil {
|
2016-05-13 17:18:00 +02:00
|
|
|
panic("csrf secret must be provided")
|
2016-05-13 02:45:00 +02:00
|
|
|
}
|
2016-05-27 04:23:46 +02:00
|
|
|
if config.TokenLookup == "" {
|
|
|
|
config.TokenLookup = DefaultCSRFConfig.TokenLookup
|
2016-05-26 23:06:30 +02:00
|
|
|
}
|
2016-05-13 17:18:00 +02:00
|
|
|
if config.ContextKey == "" {
|
|
|
|
config.ContextKey = DefaultCSRFConfig.ContextKey
|
|
|
|
}
|
2016-05-13 02:45:00 +02:00
|
|
|
if config.CookieName == "" {
|
|
|
|
config.CookieName = DefaultCSRFConfig.CookieName
|
|
|
|
}
|
|
|
|
if config.CookieExpires.IsZero() {
|
|
|
|
config.CookieExpires = DefaultCSRFConfig.CookieExpires
|
|
|
|
}
|
|
|
|
|
2016-05-26 23:06:30 +02:00
|
|
|
// Initialize
|
2016-05-27 04:23:46 +02:00
|
|
|
parts := strings.Split(config.TokenLookup, ":")
|
2016-05-26 23:06:30 +02:00
|
|
|
extractor := csrfTokenFromHeader(parts[1])
|
|
|
|
switch parts[0] {
|
|
|
|
case "form":
|
|
|
|
extractor = csrfTokenFromForm(parts[1])
|
|
|
|
case "query":
|
|
|
|
extractor = csrfTokenFromQuery(parts[1])
|
|
|
|
}
|
|
|
|
|
2016-05-13 02:45:00 +02:00
|
|
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
|
|
|
return func(c echo.Context) error {
|
|
|
|
req := c.Request()
|
|
|
|
|
|
|
|
// Set CSRF token
|
|
|
|
salt, err := generateSalt(8)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
token := generateCSRFToken(config.Secret, salt)
|
2016-05-13 17:18:00 +02:00
|
|
|
c.Set(config.ContextKey, token)
|
2016-05-13 02:45:00 +02:00
|
|
|
cookie := new(echo.Cookie)
|
|
|
|
cookie.SetName(config.CookieName)
|
|
|
|
cookie.SetValue(token)
|
|
|
|
if config.CookiePath != "" {
|
|
|
|
cookie.SetPath(config.CookiePath)
|
|
|
|
}
|
|
|
|
if config.CookieDomain != "" {
|
|
|
|
cookie.SetDomain(config.CookieDomain)
|
|
|
|
}
|
|
|
|
cookie.SetExpires(config.CookieExpires)
|
|
|
|
cookie.SetSecure(config.CookieSecure)
|
|
|
|
cookie.SetHTTPOnly(config.CookieHTTPOnly)
|
|
|
|
c.SetCookie(cookie)
|
|
|
|
|
|
|
|
switch req.Method() {
|
|
|
|
case echo.GET, echo.HEAD, echo.OPTIONS, echo.TRACE:
|
|
|
|
default:
|
2016-05-26 23:06:30 +02:00
|
|
|
token, err := extractor(c)
|
2016-05-13 17:18:00 +02:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
2016-05-13 02:45:00 +02:00
|
|
|
ok, err := validateCSRFToken(token, config.Secret)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
if !ok {
|
2016-05-13 17:18:00 +02:00
|
|
|
return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")
|
2016-05-13 02:45:00 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
return next(c)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2016-05-26 23:06:30 +02:00
|
|
|
// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
|
2016-05-13 17:18:00 +02:00
|
|
|
// provided request header.
|
2016-05-26 23:06:30 +02:00
|
|
|
func csrfTokenFromHeader(header string) csrfTokenExtractor {
|
2016-05-13 17:18:00 +02:00
|
|
|
return func(c echo.Context) (string, error) {
|
|
|
|
return c.Request().Header().Get(header), nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2016-05-26 23:06:30 +02:00
|
|
|
// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
|
2016-05-13 17:18:00 +02:00
|
|
|
// provided form parameter.
|
2016-05-26 23:06:30 +02:00
|
|
|
func csrfTokenFromForm(param string) csrfTokenExtractor {
|
2016-05-13 17:18:00 +02:00
|
|
|
return func(c echo.Context) (string, error) {
|
|
|
|
token := c.FormValue(param)
|
|
|
|
if token == "" {
|
|
|
|
return "", errors.New("empty csrf token in form param")
|
|
|
|
}
|
|
|
|
return token, nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2016-05-26 23:06:30 +02:00
|
|
|
// csrfTokenFromQuery returns a `csrfTokenExtractor` that extracts token from the
|
2016-05-13 17:18:00 +02:00
|
|
|
// provided query parameter.
|
2016-05-26 23:06:30 +02:00
|
|
|
func csrfTokenFromQuery(param string) csrfTokenExtractor {
|
2016-05-13 17:18:00 +02:00
|
|
|
return func(c echo.Context) (string, error) {
|
|
|
|
token := c.QueryParam(param)
|
|
|
|
if token == "" {
|
|
|
|
return "", errors.New("empty csrf token in query param")
|
|
|
|
}
|
|
|
|
return token, nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2016-05-13 02:45:00 +02:00
|
|
|
// generateCSRFToken signs salt with HMAC-SHA1 keyed by secret and returns the
// token "<hex(mac)>:<hex(salt)>".
func generateCSRFToken(secret, salt []byte) string {
	mac := hmac.New(sha1.New, secret)
	mac.Write(salt) // hash.Hash.Write never returns an error
	signature := hex.EncodeToString(mac.Sum(nil))
	encodedSalt := hex.EncodeToString(salt)
	return fmt.Sprintf("%s:%s", signature, encodedSalt)
}
|
|
|
|
|
|
|
|
func validateCSRFToken(token string, secret []byte) (bool, error) {
|
|
|
|
sep := strings.Index(token, ":")
|
|
|
|
if sep < 0 {
|
|
|
|
return false, nil
|
|
|
|
}
|
|
|
|
salt, err := hex.DecodeString(token[sep+1:])
|
|
|
|
if err != nil {
|
|
|
|
return false, err
|
|
|
|
}
|
|
|
|
return token == generateCSRFToken(secret, salt), nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// generateSalt returns n random bytes read from crypto/rand.
// On error it returns a nil slice together with the error.
func generateSalt(n uint8) ([]byte, error) {
	salt := make([]byte, n)
	if _, err := rand.Read(salt); err != nil {
		return nil, err
	}
	return salt, nil
}
|