1
0
mirror of https://github.com/labstack/echo.git synced 2024-12-24 20:14:31 +02:00

Added CSRF middleware, #341.

Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
Vishal Rana 2016-05-12 17:45:00 -07:00
parent 1afaa6ec0b
commit 98dd8bf9e9
11 changed files with 197 additions and 10 deletions

View File

@ -174,6 +174,7 @@ const (
HeaderXXSSProtection = "X-XSS-Protection" HeaderXXSSProtection = "X-XSS-Protection"
HeaderXFrameOptions = "X-Frame-Options" HeaderXFrameOptions = "X-Frame-Options"
HeaderContentSecurityPolicy = "Content-Security-Policy" HeaderContentSecurityPolicy = "Content-Security-Policy"
HeaderXCSRFToken = "X-CSRF-Token"
) )
var ( var (

View File

@ -40,7 +40,7 @@ func BodyLimit(limit string) echo.MiddlewareFunc {
} }
// BodyLimitWithConfig returns a body limit middleware from config. // BodyLimitWithConfig returns a body limit middleware from config.
// See `BodyLimit()`. // See: `BodyLimit()`.
func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc { func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc {
limit, err := bytes.Parse(config.Limit) limit, err := bytes.Parse(config.Limit)
if err != nil { if err != nil {

View File

@ -40,7 +40,7 @@ func Gzip() echo.MiddlewareFunc {
} }
// GzipWithConfig return gzip middleware from config. // GzipWithConfig return gzip middleware from config.
// See `Gzip()`. // See: `Gzip()`.
func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
// Defaults // Defaults
if config.Level == 0 { if config.Level == 0 {

View File

@ -53,13 +53,13 @@ var (
) )
// CORS returns a Cross-Origin Resource Sharing (CORS) middleware. // CORS returns a Cross-Origin Resource Sharing (CORS) middleware.
// See https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS // See: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS
func CORS() echo.MiddlewareFunc { func CORS() echo.MiddlewareFunc {
return CORSWithConfig(DefaultCORSConfig) return CORSWithConfig(DefaultCORSConfig)
} }
// CORSWithConfig returns a CORS middleware from config. // CORSWithConfig returns a CORS middleware from config.
// See `CORS()`. // See: `CORS()`.
func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
// Defaults // Defaults
if len(config.AllowOrigins) == 0 { if len(config.AllowOrigins) == 0 {

146
middleware/csrf.go Normal file
View File

@ -0,0 +1,146 @@
package middleware
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha1"
"encoding/hex"
"fmt"
"net/http"
"strings"
"time"
"github.com/labstack/echo"
)
type (
// CSRFConfig defines the config for CSRF middleware.
CSRFConfig struct {
// Key to create CSRF token.
Secret []byte
// Name of the request header to extract CSRF token.
HeaderName string
// Name of the CSRF cookie. This cookie will store CSRF token.
// Optional. Default value "csrf".
CookieName string
// Domain of the CSRF cookie.
// Optional. Default value none.
CookieDomain string
// Paht of the CSRF cookie.
// Optional. Default value none.
CookiePath string
// Expiriation time of the CSRF cookie.
// Optioanl. Default value 24hrs.
CookieExpires time.Time
// Indicates if CSRF cookie is secure.
CookieSecure bool
// Indicates if CSRF cookie is HTTP only.
CookieHTTPOnly bool
}
)
var (
// DefaultCSRFConfig is the default CSRF middleware config.
DefaultCSRFConfig = CSRFConfig{
HeaderName: echo.HeaderXCSRFToken,
CookieName: "csrf",
CookieExpires: time.Now().Add(24 * time.Hour),
}
)
// CSRF returns a Cross-Site Request Forgery (CSRF) middleware.
// See: https://en.wikipedia.org/wiki/Cross-site_request_forgery
func CSRF(secret []byte) echo.MiddlewareFunc {
c := DefaultCSRFConfig
c.Secret = secret
return CSRFWithConfig(c)
}
// CSRFWithConfig returns a CSRF middleware from config.
// See `CSRF()`.
func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
// Defaults
if config.Secret == nil {
panic("csrf: secret must be provided")
}
if config.HeaderName == "" {
config.HeaderName = DefaultCSRFConfig.HeaderName
}
if config.CookieName == "" {
config.CookieName = DefaultCSRFConfig.CookieName
}
if config.CookieExpires.IsZero() {
config.CookieExpires = DefaultCSRFConfig.CookieExpires
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
req := c.Request()
// Set CSRF token
salt, err := generateSalt(8)
if err != nil {
return err
}
token := generateCSRFToken(config.Secret, salt)
cookie := new(echo.Cookie)
cookie.SetName(config.CookieName)
cookie.SetValue(token)
if config.CookiePath != "" {
cookie.SetPath(config.CookiePath)
}
if config.CookieDomain != "" {
cookie.SetDomain(config.CookieDomain)
}
cookie.SetExpires(config.CookieExpires)
cookie.SetSecure(config.CookieSecure)
cookie.SetHTTPOnly(config.CookieHTTPOnly)
c.SetCookie(cookie)
switch req.Method() {
case echo.GET, echo.HEAD, echo.OPTIONS, echo.TRACE:
default:
token := req.Header().Get(config.HeaderName)
ok, err := validateCSRFToken(token, config.Secret)
if err != nil {
return err
}
if !ok {
return echo.NewHTTPError(http.StatusForbidden, "csrf: invalid token")
}
}
return next(c)
}
}
}
func generateCSRFToken(secret, salt []byte) string {
h := hmac.New(sha1.New, secret)
h.Write(salt)
return fmt.Sprintf("%s:%s", hex.EncodeToString(h.Sum(nil)), hex.EncodeToString(salt))
}
func validateCSRFToken(token string, secret []byte) (bool, error) {
sep := strings.Index(token, ":")
if sep < 0 {
return false, nil
}
salt, err := hex.DecodeString(token[sep+1:])
if err != nil {
return false, err
}
return token == generateCSRFToken(secret, salt), nil
}
func generateSalt(len uint8) (salt []byte, err error) {
salt = make([]byte, len)
_, err = rand.Read(salt)
return
}

40
middleware/csrf_test.go Normal file
View File

@ -0,0 +1,40 @@
package middleware
import (
"net/http"
"testing"
"github.com/labstack/echo"
"github.com/labstack/echo/test"
"github.com/stretchr/testify/assert"
)
func TestCSRF(t *testing.T) {
e := echo.New()
req := test.NewRequest(echo.GET, "/", nil)
rec := test.NewResponseRecorder()
c := e.NewContext(req, rec)
csrf := CSRF([]byte("secret"))
h := csrf(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
// Generate CSRF token
h(c)
assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "csrf")
// Empty/invalid CSRF token
req = test.NewRequest(echo.POST, "/", nil)
rec = test.NewResponseRecorder()
c = e.NewContext(req, rec)
req.Header().Set(echo.HeaderXCSRFToken, "")
he := h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusForbidden, he.Code)
// Valid CSRF token
salt, _ := generateSalt(8)
token := generateCSRFToken([]byte("secret"), salt)
req.Header().Set(echo.HeaderXCSRFToken, token)
h(c)
assert.Equal(t, http.StatusOK, rec.Status())
}

View File

@ -57,7 +57,7 @@ var (
// For invalid token, it sends "401 - Unauthorized" response. // For invalid token, it sends "401 - Unauthorized" response.
// For empty or invalid `Authorization` header, it sends "400 - Bad Request". // For empty or invalid `Authorization` header, it sends "400 - Bad Request".
// //
// See https://jwt.io/introduction // See: https://jwt.io/introduction
func JWT(key []byte) echo.MiddlewareFunc { func JWT(key []byte) echo.MiddlewareFunc {
c := DefaultJWTConfig c := DefaultJWTConfig
c.SigningKey = key c.SigningKey = key
@ -65,7 +65,7 @@ func JWT(key []byte) echo.MiddlewareFunc {
} }
// JWTWithConfig returns a JWT auth middleware from config. // JWTWithConfig returns a JWT auth middleware from config.
// See `JWT()`. // See: `JWT()`.
func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
// Defaults // Defaults
if config.SigningKey == nil { if config.SigningKey == nil {

View File

@ -69,7 +69,7 @@ func Logger() echo.MiddlewareFunc {
} }
// LoggerWithConfig returns a logger middleware from config. // LoggerWithConfig returns a logger middleware from config.
// See `Logger()`. // See: `Logger()`.
func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
// Defaults // Defaults
if config.Format == "" { if config.Format == "" {

View File

@ -33,7 +33,7 @@ func MethodOverride() echo.MiddlewareFunc {
} }
// MethodOverrideWithConfig returns a method override middleware from config. // MethodOverrideWithConfig returns a method override middleware from config.
// See `MethodOverride()`. // See: `MethodOverride()`.
func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc { func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc {
// Defaults // Defaults
if config.Getter == nil { if config.Getter == nil {

View File

@ -42,7 +42,7 @@ func Recover() echo.MiddlewareFunc {
} }
// RecoverWithConfig returns a recover middleware from config. // RecoverWithConfig returns a recover middleware from config.
// See `Recover()`. // See: `Recover()`.
func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc {
// Defaults // Defaults
if config.StackSize == 0 { if config.StackSize == 0 {

View File

@ -71,7 +71,7 @@ func Secure() echo.MiddlewareFunc {
} }
// SecureWithConfig returns a secure middleware from config. // SecureWithConfig returns a secure middleware from config.
// See `Secure()`. // See: `Secure()`.
func SecureWithConfig(config SecureConfig) echo.MiddlewareFunc { func SecureWithConfig(config SecureConfig) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {