diff --git a/echo.go b/echo.go index 3a031098..5af66b65 100644 --- a/echo.go +++ b/echo.go @@ -174,6 +174,7 @@ const ( HeaderXXSSProtection = "X-XSS-Protection" HeaderXFrameOptions = "X-Frame-Options" HeaderContentSecurityPolicy = "Content-Security-Policy" + HeaderXCSRFToken = "X-CSRF-Token" ) var ( diff --git a/middleware/body_limit.go b/middleware/body_limit.go index f09aac31..3f92c86d 100644 --- a/middleware/body_limit.go +++ b/middleware/body_limit.go @@ -40,7 +40,7 @@ func BodyLimit(limit string) echo.MiddlewareFunc { } // BodyLimitWithConfig returns a body limit middleware from config. -// See `BodyLimit()`. +// See: `BodyLimit()`. func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc { limit, err := bytes.Parse(config.Limit) if err != nil { diff --git a/middleware/compress.go b/middleware/compress.go index 83d0b2f5..07508579 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -40,7 +40,7 @@ func Gzip() echo.MiddlewareFunc { } // GzipWithConfig return gzip middleware from config. -// See `Gzip()`. +// See: `Gzip()`. func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { // Defaults if config.Level == 0 { diff --git a/middleware/cors.go b/middleware/cors.go index 3cf0d40e..a0aced44 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -53,13 +53,13 @@ var ( ) // CORS returns a Cross-Origin Resource Sharing (CORS) middleware. -// See https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS +// See: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS func CORS() echo.MiddlewareFunc { return CORSWithConfig(DefaultCORSConfig) } // CORSWithConfig returns a CORS middleware from config. -// See `CORS()`. +// See: `CORS()`. func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { // Defaults if len(config.AllowOrigins) == 0 { diff --git a/middleware/csrf.go b/middleware/csrf.go new file mode 100644 index 00000000..95d2380d --- /dev/null +++ b/middleware/csrf.go @@ -0,0 +1,146 @@ +package middleware + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha1" + "encoding/hex" + "fmt" + "net/http" + "strings" + "time" + + "github.com/labstack/echo" +) + +type ( + // CSRFConfig defines the config for CSRF middleware. + CSRFConfig struct { + // Key to create CSRF token. + Secret []byte + + // Name of the request header to extract CSRF token. + HeaderName string + + // Name of the CSRF cookie. This cookie will store CSRF token. + // Optional. Default value "csrf". + CookieName string + + // Domain of the CSRF cookie. + // Optional. Default value none. + CookieDomain string + + // Paht of the CSRF cookie. + // Optional. Default value none. + CookiePath string + + // Expiriation time of the CSRF cookie. + // Optioanl. Default value 24hrs. + CookieExpires time.Time + + // Indicates if CSRF cookie is secure. + CookieSecure bool + + // Indicates if CSRF cookie is HTTP only. + CookieHTTPOnly bool + } +) + +var ( + // DefaultCSRFConfig is the default CSRF middleware config. + DefaultCSRFConfig = CSRFConfig{ + HeaderName: echo.HeaderXCSRFToken, + CookieName: "csrf", + CookieExpires: time.Now().Add(24 * time.Hour), + } +) + +// CSRF returns a Cross-Site Request Forgery (CSRF) middleware. +// See: https://en.wikipedia.org/wiki/Cross-site_request_forgery +func CSRF(secret []byte) echo.MiddlewareFunc { + c := DefaultCSRFConfig + c.Secret = secret + return CSRFWithConfig(c) +} + +// CSRFWithConfig returns a CSRF middleware from config. +// See `CSRF()`. +func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { + // Defaults + if config.Secret == nil { + panic("csrf: secret must be provided") + } + if config.HeaderName == "" { + config.HeaderName = DefaultCSRFConfig.HeaderName + } + if config.CookieName == "" { + config.CookieName = DefaultCSRFConfig.CookieName + } + if config.CookieExpires.IsZero() { + config.CookieExpires = DefaultCSRFConfig.CookieExpires + } + + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + req := c.Request() + + // Set CSRF token + salt, err := generateSalt(8) + if err != nil { + return err + } + token := generateCSRFToken(config.Secret, salt) + cookie := new(echo.Cookie) + cookie.SetName(config.CookieName) + cookie.SetValue(token) + if config.CookiePath != "" { + cookie.SetPath(config.CookiePath) + } + if config.CookieDomain != "" { + cookie.SetDomain(config.CookieDomain) + } + cookie.SetExpires(config.CookieExpires) + cookie.SetSecure(config.CookieSecure) + cookie.SetHTTPOnly(config.CookieHTTPOnly) + c.SetCookie(cookie) + + switch req.Method() { + case echo.GET, echo.HEAD, echo.OPTIONS, echo.TRACE: + default: + token := req.Header().Get(config.HeaderName) + ok, err := validateCSRFToken(token, config.Secret) + if err != nil { + return err + } + if !ok { + return echo.NewHTTPError(http.StatusForbidden, "csrf: invalid token") + } + } + return next(c) + } + } +} + +func generateCSRFToken(secret, salt []byte) string { + h := hmac.New(sha1.New, secret) + h.Write(salt) + return fmt.Sprintf("%s:%s", hex.EncodeToString(h.Sum(nil)), hex.EncodeToString(salt)) +} + +func validateCSRFToken(token string, secret []byte) (bool, error) { + sep := strings.Index(token, ":") + if sep < 0 { + return false, nil + } + salt, err := hex.DecodeString(token[sep+1:]) + if err != nil { + return false, err + } + return token == generateCSRFToken(secret, salt), nil +} + +func generateSalt(len uint8) (salt []byte, err error) { + salt = make([]byte, len) + _, err = rand.Read(salt) + return +} diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go new file mode 100644 index 00000000..c69d828b --- /dev/null +++ b/middleware/csrf_test.go @@ -0,0 +1,40 @@ +package middleware + +import ( + "net/http" + "testing" + + "github.com/labstack/echo" + "github.com/labstack/echo/test" + "github.com/stretchr/testify/assert" +) + +func TestCSRF(t *testing.T) { + e := echo.New() + req := test.NewRequest(echo.GET, "/", nil) + rec := test.NewResponseRecorder() + c := e.NewContext(req, rec) + csrf := CSRF([]byte("secret")) + h := csrf(func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + // Generate CSRF token + h(c) + assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "csrf") + + // Empty/invalid CSRF token + req = test.NewRequest(echo.POST, "/", nil) + rec = test.NewResponseRecorder() + c = e.NewContext(req, rec) + req.Header().Set(echo.HeaderXCSRFToken, "") + he := h(c).(*echo.HTTPError) + assert.Equal(t, http.StatusForbidden, he.Code) + + // Valid CSRF token + salt, _ := generateSalt(8) + token := generateCSRFToken([]byte("secret"), salt) + req.Header().Set(echo.HeaderXCSRFToken, token) + h(c) + assert.Equal(t, http.StatusOK, rec.Status()) +} diff --git a/middleware/jwt.go b/middleware/jwt.go index 2ec9fd1c..935d660a 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -57,7 +57,7 @@ var ( // For invalid token, it sends "401 - Unauthorized" response. // For empty or invalid `Authorization` header, it sends "400 - Bad Request". // -// See https://jwt.io/introduction +// See: https://jwt.io/introduction func JWT(key []byte) echo.MiddlewareFunc { c := DefaultJWTConfig c.SigningKey = key @@ -65,7 +65,7 @@ func JWT(key []byte) echo.MiddlewareFunc { } // JWTWithConfig returns a JWT auth middleware from config. -// See `JWT()`. +// See: `JWT()`. func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { // Defaults if config.SigningKey == nil { diff --git a/middleware/logger.go b/middleware/logger.go index c22e40c1..356f0e34 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -69,7 +69,7 @@ func Logger() echo.MiddlewareFunc { } // LoggerWithConfig returns a logger middleware from config. -// See `Logger()`. +// See: `Logger()`. func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { // Defaults if config.Format == "" { diff --git a/middleware/method_override.go b/middleware/method_override.go index 0b49e5ca..7a0744e9 100644 --- a/middleware/method_override.go +++ b/middleware/method_override.go @@ -33,7 +33,7 @@ func MethodOverride() echo.MiddlewareFunc { } // MethodOverrideWithConfig returns a method override middleware from config. -// See `MethodOverride()`. +// See: `MethodOverride()`. func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc { // Defaults if config.Getter == nil { diff --git a/middleware/recover.go b/middleware/recover.go index 0c250e4e..be25def4 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -42,7 +42,7 @@ func Recover() echo.MiddlewareFunc { } // RecoverWithConfig returns a recover middleware from config. -// See `Recover()`. +// See: `Recover()`. func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { // Defaults if config.StackSize == 0 { diff --git a/middleware/secure.go b/middleware/secure.go index 1702595d..3a12a54f 100644 --- a/middleware/secure.go +++ b/middleware/secure.go @@ -71,7 +71,7 @@ func Secure() echo.MiddlewareFunc { } // SecureWithConfig returns a secure middleware from config. -// See `Secure()`. +// See: `Secure()`. func SecureWithConfig(config SecureConfig) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error {