1
0
mirror of https://github.com/labstack/echo.git synced 2025-07-05 00:58:47 +02:00
Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
Vishal Rana
2016-07-16 20:13:27 -07:00
parent c1358eda73
commit 0dab439ea4
2 changed files with 54 additions and 71 deletions

View File

@ -1,12 +1,9 @@
package middleware package middleware
import ( import (
"crypto/hmac" "crypto/subtle"
"crypto/rand"
"crypto/sha1"
"encoding/hex"
"errors" "errors"
"fmt" "math/rand"
"net/http" "net/http"
"strings" "strings"
"time" "time"
@ -17,8 +14,9 @@ import (
type ( type (
// CSRFConfig defines the config for CSRF middleware. // CSRFConfig defines the config for CSRF middleware.
CSRFConfig struct { CSRFConfig struct {
// Key to create CSRF token. // TokenLength is the length of the generated token.
Secret []byte `json:"secret"` TokenLength uint8 `json:"token_length"`
// Optional. Default value 32.
// TokenLookup is a string in the form of "<source>:<key>" that is used // TokenLookup is a string in the form of "<source>:<key>" that is used
// to extract token from the request. // to extract token from the request.
@ -52,6 +50,10 @@ type (
// Indicates if CSRF cookie is secure. // Indicates if CSRF cookie is secure.
// Optional. Default value false. // Optional. Default value false.
CookieSecure bool `json:"cookie_secure"` CookieSecure bool `json:"cookie_secure"`
// Indicates if CSRF cookie is HTTP only.
// Optional. Default value false.
CookieHTTPOnly bool `json:"cookie_http_only"`
} }
// csrfTokenExtractor defines a function that takes `echo.Context` and returns // csrfTokenExtractor defines a function that takes `echo.Context` and returns
@ -62,6 +64,7 @@ type (
var ( var (
// DefaultCSRFConfig is the default CSRF middleware config. // DefaultCSRFConfig is the default CSRF middleware config.
DefaultCSRFConfig = CSRFConfig{ DefaultCSRFConfig = CSRFConfig{
TokenLength: 32,
TokenLookup: "header:" + echo.HeaderXCSRFToken, TokenLookup: "header:" + echo.HeaderXCSRFToken,
ContextKey: "csrf", ContextKey: "csrf",
CookieName: "_csrf", CookieName: "_csrf",
@ -71,9 +74,8 @@ var (
// CSRF returns a Cross-Site Request Forgery (CSRF) middleware. // CSRF returns a Cross-Site Request Forgery (CSRF) middleware.
// See: https://en.wikipedia.org/wiki/Cross-site_request_forgery // See: https://en.wikipedia.org/wiki/Cross-site_request_forgery
func CSRF(secret []byte) echo.MiddlewareFunc { func CSRF() echo.MiddlewareFunc {
c := DefaultCSRFConfig c := DefaultCSRFConfig
c.Secret = secret
return CSRFWithConfig(c) return CSRFWithConfig(c)
} }
@ -81,8 +83,8 @@ func CSRF(secret []byte) echo.MiddlewareFunc {
// See `CSRF()`. // See `CSRF()`.
func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
// Defaults // Defaults
if config.Secret == nil { if config.TokenLength == 0 {
panic("csrf secret must be provided") config.TokenLength = DefaultCSRFConfig.TokenLength
} }
if config.TokenLookup == "" { if config.TokenLookup == "" {
config.TokenLookup = DefaultCSRFConfig.TokenLookup config.TokenLookup = DefaultCSRFConfig.TokenLookup
@ -110,51 +112,51 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
req := c.Request() req := c.Request()
cookie, err := c.Cookie(config.CookieName) k, err := c.Cookie(config.CookieName)
token := "" token := ""
if err != nil { if err != nil {
// Token expired, generate it // Generate token
salt, err := generateSalt(8) token = generateCSRFToken(config.TokenLength)
if err != nil {
return err
}
token = generateCSRFToken(config.Secret, salt)
cookie := new(echo.Cookie)
cookie.SetName(config.CookieName)
cookie.SetValue(token)
if config.CookiePath != "" {
cookie.SetPath(config.CookiePath)
}
if config.CookieDomain != "" {
cookie.SetDomain(config.CookieDomain)
}
cookie.SetExpires(time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second))
cookie.SetSecure(config.CookieSecure)
cookie.SetHTTPOnly(true)
c.SetCookie(cookie)
} else { } else {
// Reuse token // Reuse token
token = cookie.Value() token = k.Value()
} }
c.Set(config.ContextKey, token)
switch req.Method() { switch req.Method() {
case echo.GET, echo.HEAD, echo.OPTIONS, echo.TRACE: case echo.GET, echo.HEAD, echo.OPTIONS, echo.TRACE:
default: default:
// Validate token only for requests which are not defined as 'safe' by RFC7231
clientToken, err := extractor(c) clientToken, err := extractor(c)
if err != nil { if err != nil {
return err return err
} }
ok, err := validateCSRFToken(token, clientToken, config.Secret) if !validateCSRFToken(token, clientToken) {
if err != nil { return echo.NewHTTPError(http.StatusForbidden, "csrf token is invalid")
return err
}
if !ok {
return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")
} }
} }
// Set CSRF cookie
cookie := new(echo.Cookie)
cookie.SetName(config.CookieName)
cookie.SetValue(token)
if config.CookiePath != "" {
cookie.SetPath(config.CookiePath)
}
if config.CookieDomain != "" {
cookie.SetDomain(config.CookieDomain)
}
cookie.SetExpires(time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second))
cookie.SetSecure(config.CookieSecure)
cookie.SetHTTPOnly(config.CookieHTTPOnly)
c.SetCookie(cookie)
// Store token in the context
c.Set(config.ContextKey, token)
// Protect clients from caching the response
c.Response().Header().Add(echo.HeaderVary, echo.HeaderCookie)
return next(c) return next(c)
} }
} }
@ -192,29 +194,16 @@ func csrfTokenFromQuery(param string) csrfTokenExtractor {
} }
} }
func generateCSRFToken(secret, salt []byte) string { func generateCSRFToken(n uint8) string {
h := hmac.New(sha1.New, secret) // TODO: From utility library
h.Write(salt) chars := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
return fmt.Sprintf("%s:%s", hex.EncodeToString(h.Sum(nil)), hex.EncodeToString(salt)) b := make([]byte, n)
for i := range b {
b[i] = chars[rand.Int63()%int64(len(chars))]
}
return string(b)
} }
func validateCSRFToken(serverToken, clientToken string, secret []byte) (bool, error) { func validateCSRFToken(token, clientToken string) bool {
if serverToken != clientToken { return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1
return false, nil
}
sep := strings.Index(clientToken, ":")
if sep < 0 {
return false, nil
}
salt, err := hex.DecodeString(clientToken[sep+1:])
if err != nil {
return false, err
}
return clientToken == generateCSRFToken(secret, salt), nil
}
func generateSalt(len uint8) (salt []byte, err error) {
salt = make([]byte, len)
_, err = rand.Read(salt)
return
} }

View File

@ -17,17 +17,12 @@ func TestCSRF(t *testing.T) {
rec := test.NewResponseRecorder() rec := test.NewResponseRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
csrf := CSRFWithConfig(CSRFConfig{ csrf := CSRFWithConfig(CSRFConfig{
Secret: []byte("secret"), TokenLength: 16,
}) })
h := csrf(func(c echo.Context) error { h := csrf(func(c echo.Context) error {
return c.String(http.StatusOK, "test") return c.String(http.StatusOK, "test")
}) })
// No secret
assert.Panics(t, func() {
CSRF(nil)
})
// Generate CSRF token // Generate CSRF token
h(c) h(c)
assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "_csrf") assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "_csrf")
@ -46,8 +41,7 @@ func TestCSRF(t *testing.T) {
assert.Error(t, h(c)) assert.Error(t, h(c))
// Valid CSRF token // Valid CSRF token
salt, _ := generateSalt(8) token := generateCSRFToken(16)
token := generateCSRFToken([]byte("secret"), salt)
req.Header().Set(echo.HeaderCookie, "_csrf="+token) req.Header().Set(echo.HeaderCookie, "_csrf="+token)
req.Header().Set(echo.HeaderXCSRFToken, token) req.Header().Set(echo.HeaderXCSRFToken, token)
if assert.NoError(t, h(c)) { if assert.NoError(t, h(c)) {