diff --git a/middleware/jwt.go b/middleware/jwt.go index cad68396..c8d29f96 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -61,7 +61,7 @@ type ( JWTSuccessHandler func(echo.Context) // JWTErrorHandler defines a function which is executed for an invalid token. - JWTErrorHandler func(echo.Context, echo.HandlerFunc) error + JWTErrorHandler func(error, echo.Context, echo.HandlerFunc) error jwtExtractor func(echo.Context) (string, error) ) @@ -74,7 +74,6 @@ const ( // Errors var ( ErrJWTMissing = echo.NewHTTPError(http.StatusBadRequest, "missing or malformed jwt") - ErrJWTInvalid = echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired jwt") ) var ( @@ -158,6 +157,9 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { auth, err := extractor(c) if err != nil { + if config.ErrorHandler != nil { + return config.ErrorHandler(err, c, next) + } return err } token := new(jwt.Token) @@ -178,11 +180,11 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { return next(c) } if config.ErrorHandler != nil { - return config.ErrorHandler(c, next) + return config.ErrorHandler(err, c, next) } return &echo.HTTPError{ - Code: ErrJWTInvalid.Code, - Message: ErrJWTInvalid.Message, + Code: http.StatusUnauthorized, + Message: "Invalid or expired jwt", Internal: err, } }