1
0
mirror of https://github.com/labstack/echo.git synced 2024-12-24 20:14:31 +02:00

Added ErrorHandler and ErrorHandlerWithContext in CSRF middleware (#2257)

* feat: add error handler to csrf middleware

Co-authored-by: Mojtaba Arezoomand <mojtaba.arezoomand@snapp.cab>
This commit is contained in:
Mojtaba Arezoumand 2022-09-01 12:21:55 +04:30 committed by GitHub
parent 534bbb81e3
commit d77e8c09b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 38 additions and 2 deletions

View File

@ -61,7 +61,13 @@ type (
// Indicates SameSite mode of the CSRF cookie.
// Optional. Default value SameSiteDefaultMode.
CookieSameSite http.SameSite `yaml:"cookie_same_site"`
// ErrorHandler defines a function which is executed for returning custom errors.
ErrorHandler CSRFErrorHandler
}
// CSRFErrorHandler is a function which is executed for creating custom errors.
CSRFErrorHandler func(err error, c echo.Context) error
)
// ErrCSRFInvalid is returned when CSRF check fails
@ -154,8 +160,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
lastTokenErr = ErrCSRFInvalid
}
}
var finalErr error
if lastTokenErr != nil {
return lastTokenErr
finalErr = lastTokenErr
} else if lastExtractorErr != nil {
// ugly part to preserve backwards compatible errors. someone could rely on them
if lastExtractorErr == errQueryExtractorValueMissing {
@ -167,7 +174,14 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
} else {
lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, lastExtractorErr.Error())
}
return lastExtractorErr
finalErr = lastExtractorErr
}
if finalErr != nil {
if config.ErrorHandler != nil {
return config.ErrorHandler(finalErr, c)
}
return finalErr
}
}

View File

@ -358,3 +358,25 @@ func TestCSRFConfig_skipper(t *testing.T) {
})
}
}
func TestCSRFErrorHandling(t *testing.T) {
cfg := CSRFConfig{
ErrorHandler: func(err error, c echo.Context) error {
return echo.NewHTTPError(http.StatusTeapot, "error_handler_executed")
},
}
e := echo.New()
e.POST("/", func(c echo.Context) error {
return c.String(http.StatusNotImplemented, "should not end up here")
})
e.Use(CSRFWithConfig(cfg))
req := httptest.NewRequest(http.MethodPost, "/", nil)
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, http.StatusTeapot, res.Code)
assert.Equal(t, "{\"message\":\"error_handler_executed\"}\n", res.Body.String())
}