mirror of
https://github.com/labstack/echo.git
synced 2024-12-24 20:14:31 +02:00
Added ErrorHandler and ErrorHandlerWithContext in CSRF middleware (#2257)
* feat: add error handler to csrf middleware Co-authored-by: Mojtaba Arezoomand <mojtaba.arezoomand@snapp.cab>
This commit is contained in:
parent
534bbb81e3
commit
d77e8c09b2
@ -61,7 +61,13 @@ type (
|
||||
// Indicates SameSite mode of the CSRF cookie.
|
||||
// Optional. Default value SameSiteDefaultMode.
|
||||
CookieSameSite http.SameSite `yaml:"cookie_same_site"`
|
||||
|
||||
// ErrorHandler defines a function which is executed for returning custom errors.
|
||||
ErrorHandler CSRFErrorHandler
|
||||
}
|
||||
|
||||
// CSRFErrorHandler is a function which is executed for creating custom errors.
|
||||
CSRFErrorHandler func(err error, c echo.Context) error
|
||||
)
|
||||
|
||||
// ErrCSRFInvalid is returned when CSRF check fails
|
||||
@ -154,8 +160,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
|
||||
lastTokenErr = ErrCSRFInvalid
|
||||
}
|
||||
}
|
||||
var finalErr error
|
||||
if lastTokenErr != nil {
|
||||
return lastTokenErr
|
||||
finalErr = lastTokenErr
|
||||
} else if lastExtractorErr != nil {
|
||||
// ugly part to preserve backwards compatible errors. someone could rely on them
|
||||
if lastExtractorErr == errQueryExtractorValueMissing {
|
||||
@ -167,7 +174,14 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
|
||||
} else {
|
||||
lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, lastExtractorErr.Error())
|
||||
}
|
||||
return lastExtractorErr
|
||||
finalErr = lastExtractorErr
|
||||
}
|
||||
|
||||
if finalErr != nil {
|
||||
if config.ErrorHandler != nil {
|
||||
return config.ErrorHandler(finalErr, c)
|
||||
}
|
||||
return finalErr
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -358,3 +358,25 @@ func TestCSRFConfig_skipper(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCSRFErrorHandling(t *testing.T) {
|
||||
cfg := CSRFConfig{
|
||||
ErrorHandler: func(err error, c echo.Context) error {
|
||||
return echo.NewHTTPError(http.StatusTeapot, "error_handler_executed")
|
||||
},
|
||||
}
|
||||
|
||||
e := echo.New()
|
||||
e.POST("/", func(c echo.Context) error {
|
||||
return c.String(http.StatusNotImplemented, "should not end up here")
|
||||
})
|
||||
|
||||
e.Use(CSRFWithConfig(cfg))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
res := httptest.NewRecorder()
|
||||
e.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, http.StatusTeapot, res.Code)
|
||||
assert.Equal(t, "{\"message\":\"error_handler_executed\"}\n", res.Body.String())
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user