mirror of
https://github.com/labstack/echo.git
synced 2024-12-24 20:14:31 +02:00
Added CSRF middleware, #341.
Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
parent
1afaa6ec0b
commit
98dd8bf9e9
1
echo.go
1
echo.go
@ -174,6 +174,7 @@ const (
|
|||||||
HeaderXXSSProtection = "X-XSS-Protection"
|
HeaderXXSSProtection = "X-XSS-Protection"
|
||||||
HeaderXFrameOptions = "X-Frame-Options"
|
HeaderXFrameOptions = "X-Frame-Options"
|
||||||
HeaderContentSecurityPolicy = "Content-Security-Policy"
|
HeaderContentSecurityPolicy = "Content-Security-Policy"
|
||||||
|
HeaderXCSRFToken = "X-CSRF-Token"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -40,7 +40,7 @@ func BodyLimit(limit string) echo.MiddlewareFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// BodyLimitWithConfig returns a body limit middleware from config.
|
// BodyLimitWithConfig returns a body limit middleware from config.
|
||||||
// See `BodyLimit()`.
|
// See: `BodyLimit()`.
|
||||||
func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc {
|
func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc {
|
||||||
limit, err := bytes.Parse(config.Limit)
|
limit, err := bytes.Parse(config.Limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -40,7 +40,7 @@ func Gzip() echo.MiddlewareFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GzipWithConfig return gzip middleware from config.
|
// GzipWithConfig return gzip middleware from config.
|
||||||
// See `Gzip()`.
|
// See: `Gzip()`.
|
||||||
func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
|
func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
|
||||||
// Defaults
|
// Defaults
|
||||||
if config.Level == 0 {
|
if config.Level == 0 {
|
||||||
|
@ -53,13 +53,13 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// CORS returns a Cross-Origin Resource Sharing (CORS) middleware.
|
// CORS returns a Cross-Origin Resource Sharing (CORS) middleware.
|
||||||
// See https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS
|
// See: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS
|
||||||
func CORS() echo.MiddlewareFunc {
|
func CORS() echo.MiddlewareFunc {
|
||||||
return CORSWithConfig(DefaultCORSConfig)
|
return CORSWithConfig(DefaultCORSConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CORSWithConfig returns a CORS middleware from config.
|
// CORSWithConfig returns a CORS middleware from config.
|
||||||
// See `CORS()`.
|
// See: `CORS()`.
|
||||||
func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
|
func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
|
||||||
// Defaults
|
// Defaults
|
||||||
if len(config.AllowOrigins) == 0 {
|
if len(config.AllowOrigins) == 0 {
|
||||||
|
146
middleware/csrf.go
Normal file
146
middleware/csrf.go
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha1"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/labstack/echo"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
// CSRFConfig defines the config for CSRF middleware.
|
||||||
|
CSRFConfig struct {
|
||||||
|
// Key to create CSRF token.
|
||||||
|
Secret []byte
|
||||||
|
|
||||||
|
// Name of the request header to extract CSRF token.
|
||||||
|
HeaderName string
|
||||||
|
|
||||||
|
// Name of the CSRF cookie. This cookie will store CSRF token.
|
||||||
|
// Optional. Default value "csrf".
|
||||||
|
CookieName string
|
||||||
|
|
||||||
|
// Domain of the CSRF cookie.
|
||||||
|
// Optional. Default value none.
|
||||||
|
CookieDomain string
|
||||||
|
|
||||||
|
// Paht of the CSRF cookie.
|
||||||
|
// Optional. Default value none.
|
||||||
|
CookiePath string
|
||||||
|
|
||||||
|
// Expiriation time of the CSRF cookie.
|
||||||
|
// Optioanl. Default value 24hrs.
|
||||||
|
CookieExpires time.Time
|
||||||
|
|
||||||
|
// Indicates if CSRF cookie is secure.
|
||||||
|
CookieSecure bool
|
||||||
|
|
||||||
|
// Indicates if CSRF cookie is HTTP only.
|
||||||
|
CookieHTTPOnly bool
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// DefaultCSRFConfig is the default CSRF middleware config.
|
||||||
|
DefaultCSRFConfig = CSRFConfig{
|
||||||
|
HeaderName: echo.HeaderXCSRFToken,
|
||||||
|
CookieName: "csrf",
|
||||||
|
CookieExpires: time.Now().Add(24 * time.Hour),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// CSRF returns a Cross-Site Request Forgery (CSRF) middleware.
|
||||||
|
// See: https://en.wikipedia.org/wiki/Cross-site_request_forgery
|
||||||
|
func CSRF(secret []byte) echo.MiddlewareFunc {
|
||||||
|
c := DefaultCSRFConfig
|
||||||
|
c.Secret = secret
|
||||||
|
return CSRFWithConfig(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CSRFWithConfig returns a CSRF middleware from config.
|
||||||
|
// See `CSRF()`.
|
||||||
|
func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
|
||||||
|
// Defaults
|
||||||
|
if config.Secret == nil {
|
||||||
|
panic("csrf: secret must be provided")
|
||||||
|
}
|
||||||
|
if config.HeaderName == "" {
|
||||||
|
config.HeaderName = DefaultCSRFConfig.HeaderName
|
||||||
|
}
|
||||||
|
if config.CookieName == "" {
|
||||||
|
config.CookieName = DefaultCSRFConfig.CookieName
|
||||||
|
}
|
||||||
|
if config.CookieExpires.IsZero() {
|
||||||
|
config.CookieExpires = DefaultCSRFConfig.CookieExpires
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
req := c.Request()
|
||||||
|
|
||||||
|
// Set CSRF token
|
||||||
|
salt, err := generateSalt(8)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
token := generateCSRFToken(config.Secret, salt)
|
||||||
|
cookie := new(echo.Cookie)
|
||||||
|
cookie.SetName(config.CookieName)
|
||||||
|
cookie.SetValue(token)
|
||||||
|
if config.CookiePath != "" {
|
||||||
|
cookie.SetPath(config.CookiePath)
|
||||||
|
}
|
||||||
|
if config.CookieDomain != "" {
|
||||||
|
cookie.SetDomain(config.CookieDomain)
|
||||||
|
}
|
||||||
|
cookie.SetExpires(config.CookieExpires)
|
||||||
|
cookie.SetSecure(config.CookieSecure)
|
||||||
|
cookie.SetHTTPOnly(config.CookieHTTPOnly)
|
||||||
|
c.SetCookie(cookie)
|
||||||
|
|
||||||
|
switch req.Method() {
|
||||||
|
case echo.GET, echo.HEAD, echo.OPTIONS, echo.TRACE:
|
||||||
|
default:
|
||||||
|
token := req.Header().Get(config.HeaderName)
|
||||||
|
ok, err := validateCSRFToken(token, config.Secret)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return echo.NewHTTPError(http.StatusForbidden, "csrf: invalid token")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return next(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateCSRFToken(secret, salt []byte) string {
|
||||||
|
h := hmac.New(sha1.New, secret)
|
||||||
|
h.Write(salt)
|
||||||
|
return fmt.Sprintf("%s:%s", hex.EncodeToString(h.Sum(nil)), hex.EncodeToString(salt))
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateCSRFToken(token string, secret []byte) (bool, error) {
|
||||||
|
sep := strings.Index(token, ":")
|
||||||
|
if sep < 0 {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
salt, err := hex.DecodeString(token[sep+1:])
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return token == generateCSRFToken(secret, salt), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateSalt(len uint8) (salt []byte, err error) {
|
||||||
|
salt = make([]byte, len)
|
||||||
|
_, err = rand.Read(salt)
|
||||||
|
return
|
||||||
|
}
|
40
middleware/csrf_test.go
Normal file
40
middleware/csrf_test.go
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/labstack/echo"
|
||||||
|
"github.com/labstack/echo/test"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCSRF(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
req := test.NewRequest(echo.GET, "/", nil)
|
||||||
|
rec := test.NewResponseRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
csrf := CSRF([]byte("secret"))
|
||||||
|
h := csrf(func(c echo.Context) error {
|
||||||
|
return c.String(http.StatusOK, "test")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Generate CSRF token
|
||||||
|
h(c)
|
||||||
|
assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "csrf")
|
||||||
|
|
||||||
|
// Empty/invalid CSRF token
|
||||||
|
req = test.NewRequest(echo.POST, "/", nil)
|
||||||
|
rec = test.NewResponseRecorder()
|
||||||
|
c = e.NewContext(req, rec)
|
||||||
|
req.Header().Set(echo.HeaderXCSRFToken, "")
|
||||||
|
he := h(c).(*echo.HTTPError)
|
||||||
|
assert.Equal(t, http.StatusForbidden, he.Code)
|
||||||
|
|
||||||
|
// Valid CSRF token
|
||||||
|
salt, _ := generateSalt(8)
|
||||||
|
token := generateCSRFToken([]byte("secret"), salt)
|
||||||
|
req.Header().Set(echo.HeaderXCSRFToken, token)
|
||||||
|
h(c)
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Status())
|
||||||
|
}
|
@ -57,7 +57,7 @@ var (
|
|||||||
// For invalid token, it sends "401 - Unauthorized" response.
|
// For invalid token, it sends "401 - Unauthorized" response.
|
||||||
// For empty or invalid `Authorization` header, it sends "400 - Bad Request".
|
// For empty or invalid `Authorization` header, it sends "400 - Bad Request".
|
||||||
//
|
//
|
||||||
// See https://jwt.io/introduction
|
// See: https://jwt.io/introduction
|
||||||
func JWT(key []byte) echo.MiddlewareFunc {
|
func JWT(key []byte) echo.MiddlewareFunc {
|
||||||
c := DefaultJWTConfig
|
c := DefaultJWTConfig
|
||||||
c.SigningKey = key
|
c.SigningKey = key
|
||||||
@ -65,7 +65,7 @@ func JWT(key []byte) echo.MiddlewareFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// JWTWithConfig returns a JWT auth middleware from config.
|
// JWTWithConfig returns a JWT auth middleware from config.
|
||||||
// See `JWT()`.
|
// See: `JWT()`.
|
||||||
func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
|
func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
|
||||||
// Defaults
|
// Defaults
|
||||||
if config.SigningKey == nil {
|
if config.SigningKey == nil {
|
||||||
|
@ -69,7 +69,7 @@ func Logger() echo.MiddlewareFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// LoggerWithConfig returns a logger middleware from config.
|
// LoggerWithConfig returns a logger middleware from config.
|
||||||
// See `Logger()`.
|
// See: `Logger()`.
|
||||||
func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
|
func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
|
||||||
// Defaults
|
// Defaults
|
||||||
if config.Format == "" {
|
if config.Format == "" {
|
||||||
|
@ -33,7 +33,7 @@ func MethodOverride() echo.MiddlewareFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// MethodOverrideWithConfig returns a method override middleware from config.
|
// MethodOverrideWithConfig returns a method override middleware from config.
|
||||||
// See `MethodOverride()`.
|
// See: `MethodOverride()`.
|
||||||
func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc {
|
func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc {
|
||||||
// Defaults
|
// Defaults
|
||||||
if config.Getter == nil {
|
if config.Getter == nil {
|
||||||
|
@ -42,7 +42,7 @@ func Recover() echo.MiddlewareFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RecoverWithConfig returns a recover middleware from config.
|
// RecoverWithConfig returns a recover middleware from config.
|
||||||
// See `Recover()`.
|
// See: `Recover()`.
|
||||||
func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc {
|
func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc {
|
||||||
// Defaults
|
// Defaults
|
||||||
if config.StackSize == 0 {
|
if config.StackSize == 0 {
|
||||||
|
@ -71,7 +71,7 @@ func Secure() echo.MiddlewareFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SecureWithConfig returns a secure middleware from config.
|
// SecureWithConfig returns a secure middleware from config.
|
||||||
// See `Secure()`.
|
// See: `Secure()`.
|
||||||
func SecureWithConfig(config SecureConfig) echo.MiddlewareFunc {
|
func SecureWithConfig(config SecureConfig) echo.MiddlewareFunc {
|
||||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
|
Loading…
Reference in New Issue
Block a user