diff --git a/middleware/csrf.go b/middleware/csrf.go index 2fce9d04..318da42e 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -1,12 +1,9 @@ package middleware import ( - "crypto/hmac" - "crypto/rand" - "crypto/sha1" - "encoding/hex" + "crypto/subtle" "errors" - "fmt" + "math/rand" "net/http" "strings" "time" @@ -17,8 +14,9 @@ import ( type ( // CSRFConfig defines the config for CSRF middleware. CSRFConfig struct { - // Key to create CSRF token. - Secret []byte `json:"secret"` + // TokenLength is the length of the generated token. + TokenLength uint8 `json:"token_length"` + // Optional. Default value 32. // TokenLookup is a string in the form of ":" that is used // to extract token from the request. @@ -52,6 +50,10 @@ type ( // Indicates if CSRF cookie is secure. // Optional. Default value false. CookieSecure bool `json:"cookie_secure"` + + // Indicates if CSRF cookie is HTTP only. + // Optional. Default value false. + CookieHTTPOnly bool `json:"cookie_http_only"` } // csrfTokenExtractor defines a function that takes `echo.Context` and returns @@ -62,6 +64,7 @@ type ( var ( // DefaultCSRFConfig is the default CSRF middleware config. DefaultCSRFConfig = CSRFConfig{ + TokenLength: 32, TokenLookup: "header:" + echo.HeaderXCSRFToken, ContextKey: "csrf", CookieName: "_csrf", @@ -71,9 +74,8 @@ var ( // CSRF returns a Cross-Site Request Forgery (CSRF) middleware. // See: https://en.wikipedia.org/wiki/Cross-site_request_forgery -func CSRF(secret []byte) echo.MiddlewareFunc { +func CSRF() echo.MiddlewareFunc { c := DefaultCSRFConfig - c.Secret = secret return CSRFWithConfig(c) } @@ -81,8 +83,8 @@ func CSRF(secret []byte) echo.MiddlewareFunc { // See `CSRF()`. func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { // Defaults - if config.Secret == nil { - panic("csrf secret must be provided") + if config.TokenLength == 0 { + config.TokenLength = DefaultCSRFConfig.TokenLength } if config.TokenLookup == "" { config.TokenLookup = DefaultCSRFConfig.TokenLookup @@ -110,51 +112,51 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { req := c.Request() - cookie, err := c.Cookie(config.CookieName) + k, err := c.Cookie(config.CookieName) token := "" if err != nil { - // Token expired, generate it - salt, err := generateSalt(8) - if err != nil { - return err - } - token = generateCSRFToken(config.Secret, salt) - cookie := new(echo.Cookie) - cookie.SetName(config.CookieName) - cookie.SetValue(token) - if config.CookiePath != "" { - cookie.SetPath(config.CookiePath) - } - if config.CookieDomain != "" { - cookie.SetDomain(config.CookieDomain) - } - cookie.SetExpires(time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second)) - cookie.SetSecure(config.CookieSecure) - cookie.SetHTTPOnly(true) - c.SetCookie(cookie) + // Generate token + token = generateCSRFToken(config.TokenLength) } else { // Reuse token - token = cookie.Value() + token = k.Value() } - c.Set(config.ContextKey, token) - switch req.Method() { case echo.GET, echo.HEAD, echo.OPTIONS, echo.TRACE: default: + // Validate token only for requests which are not defined as 'safe' by RFC7231 clientToken, err := extractor(c) if err != nil { return err } - ok, err := validateCSRFToken(token, clientToken, config.Secret) - if err != nil { - return err - } - if !ok { - return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token") + if !validateCSRFToken(token, clientToken) { + return echo.NewHTTPError(http.StatusForbidden, "csrf token is invalid") } } + + // Set CSRF cookie + cookie := new(echo.Cookie) + cookie.SetName(config.CookieName) + cookie.SetValue(token) + if config.CookiePath != "" { + cookie.SetPath(config.CookiePath) + } + if config.CookieDomain != "" { + cookie.SetDomain(config.CookieDomain) + } + cookie.SetExpires(time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second)) + cookie.SetSecure(config.CookieSecure) + cookie.SetHTTPOnly(config.CookieHTTPOnly) + c.SetCookie(cookie) + + // Store token in the context + c.Set(config.ContextKey, token) + + // Protect clients from caching the response + c.Response().Header().Add(echo.HeaderVary, echo.HeaderCookie) + return next(c) } } @@ -192,29 +194,16 @@ func csrfTokenFromQuery(param string) csrfTokenExtractor { } } -func generateCSRFToken(secret, salt []byte) string { - h := hmac.New(sha1.New, secret) - h.Write(salt) - return fmt.Sprintf("%s:%s", hex.EncodeToString(h.Sum(nil)), hex.EncodeToString(salt)) +func generateCSRFToken(n uint8) string { + // TODO: From utility library + chars := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + b := make([]byte, n) + for i := range b { + b[i] = chars[rand.Int63()%int64(len(chars))] + } + return string(b) } -func validateCSRFToken(serverToken, clientToken string, secret []byte) (bool, error) { - if serverToken != clientToken { - return false, nil - } - sep := strings.Index(clientToken, ":") - if sep < 0 { - return false, nil - } - salt, err := hex.DecodeString(clientToken[sep+1:]) - if err != nil { - return false, err - } - return clientToken == generateCSRFToken(secret, salt), nil -} - -func generateSalt(len uint8) (salt []byte, err error) { - salt = make([]byte, len) - _, err = rand.Read(salt) - return +func validateCSRFToken(token, clientToken string) bool { + return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1 } diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index 73b10253..6b3df9cc 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -17,17 +17,12 @@ func TestCSRF(t *testing.T) { rec := test.NewResponseRecorder() c := e.NewContext(req, rec) csrf := CSRFWithConfig(CSRFConfig{ - Secret: []byte("secret"), + TokenLength: 16, }) h := csrf(func(c echo.Context) error { return c.String(http.StatusOK, "test") }) - // No secret - assert.Panics(t, func() { - CSRF(nil) - }) - // Generate CSRF token h(c) assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "_csrf") @@ -46,8 +41,7 @@ func TestCSRF(t *testing.T) { assert.Error(t, h(c)) // Valid CSRF token - salt, _ := generateSalt(8) - token := generateCSRFToken([]byte("secret"), salt) + token := generateCSRFToken(16) req.Header().Set(echo.HeaderCookie, "_csrf="+token) req.Header().Set(echo.HeaderXCSRFToken, token) if assert.NoError(t, h(c)) {