mirror of
https://github.com/labstack/echo.git
synced 2025-01-24 03:16:14 +02:00
Added panic recover middleware
Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
parent
609879bf39
commit
73fa05f826
@ -90,14 +90,21 @@ func main() {
|
||||
// Echo instance
|
||||
e := echo.New()
|
||||
|
||||
//------------
|
||||
// Middleware
|
||||
//------------
|
||||
|
||||
// Recover
|
||||
e.Use(mw.Recover())
|
||||
|
||||
// Logger
|
||||
e.Use(mw.Logger())
|
||||
|
||||
// Routes
|
||||
e.Get("/", hello)
|
||||
|
||||
// Start server
|
||||
e.Run(":1323)
|
||||
e.Run(":1323")
|
||||
}
|
||||
```
|
||||
|
||||
|
@ -7,7 +7,7 @@ import (
|
||||
|
||||
type (
|
||||
// Context represents context for the current request. It holds request and
|
||||
// response references, path parameters, data and registered handler.
|
||||
// response objects, path parameters, data and registered handler.
|
||||
Context struct {
|
||||
Request *http.Request
|
||||
Response *Response
|
||||
|
53
echo.go
53
echo.go
@ -22,12 +22,12 @@ type (
|
||||
prefix string
|
||||
middleware []MiddlewareFunc
|
||||
maxParam byte
|
||||
notFoundHandler HandlerFunc
|
||||
httpErrorHandler HTTPErrorHandler
|
||||
binder BindFunc
|
||||
renderer Renderer
|
||||
uris map[Handler]string
|
||||
pool sync.Pool
|
||||
debug bool
|
||||
}
|
||||
HTTPError struct {
|
||||
Code int
|
||||
@ -115,8 +115,8 @@ var (
|
||||
// Errors
|
||||
//--------
|
||||
|
||||
UnsupportedMediaType = errors.New("echo: unsupported media type")
|
||||
RendererNotRegistered = errors.New("echo: renderer not registered")
|
||||
UnsupportedMediaType = errors.New("echo ⇒ unsupported media type")
|
||||
RendererNotRegistered = errors.New("echo ⇒ renderer not registered")
|
||||
)
|
||||
|
||||
// New creates an Echo instance.
|
||||
@ -134,19 +134,14 @@ func New() (e *Echo) {
|
||||
//----------
|
||||
|
||||
e.MaxParam(5)
|
||||
e.NotFoundHandler(func(c *Context) *HTTPError {
|
||||
http.Error(c.Response, http.StatusText(http.StatusNotFound), http.StatusNotFound)
|
||||
return nil
|
||||
})
|
||||
e.HTTPErrorHandler(func(he *HTTPError, c *Context) {
|
||||
if he.Code == 0 {
|
||||
he.Code = http.StatusInternalServerError
|
||||
}
|
||||
if he.Message == "" {
|
||||
if he.Error != nil {
|
||||
he.Message = http.StatusText(he.Code)
|
||||
if e.debug {
|
||||
he.Message = he.Error.Error()
|
||||
} else {
|
||||
he.Message = http.StatusText(he.Code)
|
||||
}
|
||||
}
|
||||
http.Error(c.Response, he.Message, he.Code)
|
||||
@ -185,12 +180,6 @@ func (e *Echo) MaxParam(n uint8) {
|
||||
e.maxParam = n
|
||||
}
|
||||
|
||||
// NotFoundHandler registers a custom NotFound handler used by router in case it
|
||||
// doesn't find any registered handler for HTTP method and path.
|
||||
func (e *Echo) NotFoundHandler(h Handler) {
|
||||
e.notFoundHandler = wrapHandler(h)
|
||||
}
|
||||
|
||||
// HTTPErrorHandler registers an HTTP error handler.
|
||||
func (e *Echo) HTTPErrorHandler(h HTTPErrorHandler) {
|
||||
e.httpErrorHandler = h
|
||||
@ -207,6 +196,11 @@ func (e *Echo) Renderer(r Renderer) {
|
||||
e.renderer = r
|
||||
}
|
||||
|
||||
// Debug runs the application in debug mode.
|
||||
func (e *Echo) Debug(on bool) {
|
||||
e.debug = on
|
||||
}
|
||||
|
||||
// Use adds handler to the middleware chain.
|
||||
func (e *Echo) Use(m ...Middleware) {
|
||||
for _, h := range m {
|
||||
@ -325,21 +319,20 @@ func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if echo != nil {
|
||||
e = echo
|
||||
}
|
||||
if h == nil {
|
||||
h = e.notFoundHandler
|
||||
}
|
||||
c.reset(w, r, e)
|
||||
if h == nil {
|
||||
c.Error(&HTTPError{Code: http.StatusNotFound})
|
||||
} else {
|
||||
// Chain middleware with handler in the end
|
||||
for i := len(e.middleware) - 1; i >= 0; i-- {
|
||||
h = e.middleware[i](h)
|
||||
}
|
||||
|
||||
// Chain middleware with handler in the end
|
||||
for i := len(e.middleware) - 1; i >= 0; i-- {
|
||||
h = e.middleware[i](h)
|
||||
// Execute chain
|
||||
if he := h(c); he != nil {
|
||||
e.httpErrorHandler(he, c)
|
||||
}
|
||||
}
|
||||
|
||||
// Execute chain
|
||||
if he := h(c); he != nil {
|
||||
e.httpErrorHandler(he, c)
|
||||
}
|
||||
|
||||
e.pool.Put(c)
|
||||
}
|
||||
|
||||
@ -394,7 +387,7 @@ func wrapMiddleware(m Middleware) MiddlewareFunc {
|
||||
case func(http.ResponseWriter, *http.Request):
|
||||
return wrapHTTPHandlerFuncMW(m)
|
||||
default:
|
||||
panic("echo: unknown middleware")
|
||||
panic("echo ⇒ unknown middleware")
|
||||
}
|
||||
}
|
||||
|
||||
@ -440,7 +433,7 @@ func wrapHandler(h Handler) HandlerFunc {
|
||||
return nil
|
||||
}
|
||||
default:
|
||||
panic("echo: unknown handler")
|
||||
panic("echo ⇒ unknown handler")
|
||||
}
|
||||
}
|
||||
|
||||
|
10
echo_test.go
10
echo_test.go
@ -285,16 +285,6 @@ func TestEchoNotFound(t *testing.T) {
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("status code should be 404, found %d", w.Code)
|
||||
}
|
||||
|
||||
// Customized NotFound handler
|
||||
e.NotFoundHandler(func(c *Context) *HTTPError {
|
||||
return c.String(http.StatusNotFound, "not found")
|
||||
})
|
||||
w = httptest.NewRecorder()
|
||||
e.ServeHTTP(w, r)
|
||||
if w.Body.String() != "not found" {
|
||||
t.Errorf("body should be `not found`")
|
||||
}
|
||||
}
|
||||
|
||||
func verifyUser(u2 *user, t *testing.T) {
|
||||
|
@ -61,6 +61,7 @@ func main() {
|
||||
e := echo.New()
|
||||
|
||||
// Middleware
|
||||
e.Use(mw.Recover())
|
||||
e.Use(mw.Logger())
|
||||
|
||||
// Routes
|
||||
|
@ -16,7 +16,14 @@ func main() {
|
||||
// Echo instance
|
||||
e := echo.New()
|
||||
|
||||
//------------
|
||||
// Middleware
|
||||
//------------
|
||||
|
||||
// Recover
|
||||
e.Use(mw.Recover())
|
||||
|
||||
// Logger
|
||||
e.Use(mw.Logger())
|
||||
|
||||
// Routes
|
||||
|
@ -16,10 +16,16 @@ func main() {
|
||||
// Echo instance
|
||||
e := echo.New()
|
||||
|
||||
// Debug mode
|
||||
e.Debug(true)
|
||||
|
||||
//------------
|
||||
// Middleware
|
||||
//------------
|
||||
|
||||
// Recover
|
||||
e.Use(mw.Recover())
|
||||
|
||||
// Logger
|
||||
e.Use(mw.Logger())
|
||||
|
||||
|
@ -65,6 +65,7 @@ func main() {
|
||||
e := echo.New()
|
||||
|
||||
// Middleware
|
||||
e.Use(mw.Recover())
|
||||
e.Use(mw.Logger())
|
||||
|
||||
//------------------------
|
||||
|
@ -14,7 +14,7 @@ const (
|
||||
Basic = "Basic"
|
||||
)
|
||||
|
||||
// BasicAuth provides HTTP basic authentication.
|
||||
// BasicAuth returns an HTTP basic authentication middleware.
|
||||
func BasicAuth(fn AuthFunc) echo.HandlerFunc {
|
||||
return func(c *echo.Context) (he *echo.HTTPError) {
|
||||
auth := c.Request.Header.Get(echo.Authorization)
|
||||
|
@ -2,15 +2,15 @@ package middleware
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"github.com/labstack/echo"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo"
|
||||
)
|
||||
|
||||
func TestBasicAuth(t *testing.T) {
|
||||
req, _ := http.NewRequest(echo.POST, "/", nil)
|
||||
res := &echo.Response{Writer: httptest.NewRecorder()}
|
||||
res := &echo.Response{}
|
||||
c := echo.NewContext(req, res, echo.New())
|
||||
fn := func(u, p string) bool {
|
||||
if u == "joe" && p == "secret" {
|
||||
@ -34,7 +34,7 @@ func TestBasicAuth(t *testing.T) {
|
||||
auth = "basic " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
|
||||
req.Header.Set(echo.Authorization, auth)
|
||||
if ba(c) != nil {
|
||||
t.Error("expected `pass` with case insensitive header")
|
||||
t.Error("expected `pass`, with case insensitive header.")
|
||||
}
|
||||
|
||||
//---------------------
|
||||
@ -46,15 +46,22 @@ func TestBasicAuth(t *testing.T) {
|
||||
req.Header.Set(echo.Authorization, auth)
|
||||
ba = BasicAuth(fn)
|
||||
if ba(c) == nil {
|
||||
t.Error("expected `fail` with incorrect password")
|
||||
t.Error("expected `fail`, with incorrect password.")
|
||||
}
|
||||
|
||||
// Invalid header
|
||||
// Empty Authorization header
|
||||
req.Header.Set(echo.Authorization, "")
|
||||
ba = BasicAuth(fn)
|
||||
if ba(c) == nil {
|
||||
t.Error("expected `fail`, with empty Authorization header.")
|
||||
}
|
||||
|
||||
// Invalid Authorization header
|
||||
auth = base64.StdEncoding.EncodeToString([]byte(" :secret"))
|
||||
req.Header.Set(echo.Authorization, auth)
|
||||
ba = BasicAuth(fn)
|
||||
if ba(c) == nil {
|
||||
t.Error("expected `fail` with invalid auth header")
|
||||
t.Error("expected `fail`, with invalid Authorization header.")
|
||||
}
|
||||
|
||||
// Invalid scheme
|
||||
@ -62,13 +69,7 @@ func TestBasicAuth(t *testing.T) {
|
||||
req.Header.Set(echo.Authorization, auth)
|
||||
ba = BasicAuth(fn)
|
||||
if ba(c) == nil {
|
||||
t.Error("expected `fail` with invalid scheme")
|
||||
t.Error("expected `fail`, with invalid scheme.")
|
||||
}
|
||||
|
||||
// Empty auth header
|
||||
req.Header.Set(echo.Authorization, "")
|
||||
ba = BasicAuth(fn)
|
||||
if ba(c) == nil {
|
||||
t.Error("expected `fail` with empty auth header")
|
||||
}
|
||||
}
|
||||
|
@ -19,25 +19,21 @@ func (g gzipWriter) Write(b []byte) (int, error) {
|
||||
return g.Writer.Write(b)
|
||||
}
|
||||
|
||||
// Gzip compresses HTTP response using gzip compression scheme.
|
||||
// Gzip returns a middleware which compresses HTTP response using gzip compression
|
||||
// scheme.
|
||||
func Gzip() echo.MiddlewareFunc {
|
||||
scheme := "gzip"
|
||||
|
||||
return func(h echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c *echo.Context) *echo.HTTPError {
|
||||
if !strings.Contains(c.Request.Header.Get(echo.AcceptEncoding), scheme) {
|
||||
return nil
|
||||
if strings.Contains(c.Request.Header.Get(echo.AcceptEncoding), scheme) {
|
||||
w := gzip.NewWriter(c.Response.Writer)
|
||||
defer w.Close()
|
||||
gw := gzipWriter{Writer: w, Response: c.Response}
|
||||
c.Response.Header().Set(echo.ContentEncoding, scheme)
|
||||
c.Response = &echo.Response{Writer: gw}
|
||||
}
|
||||
|
||||
w := gzip.NewWriter(c.Response.Writer)
|
||||
defer w.Close()
|
||||
gw := gzipWriter{Writer: w, Response: c.Response}
|
||||
c.Response.Header().Set(echo.ContentEncoding, scheme)
|
||||
c.Response = &echo.Response{Writer: gw}
|
||||
if he := h(c); he != nil {
|
||||
c.Error(he)
|
||||
}
|
||||
return nil
|
||||
return h(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,42 +1,52 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"compress/gzip"
|
||||
"github.com/labstack/echo"
|
||||
"io/ioutil"
|
||||
)
|
||||
|
||||
func TestGzip(t *testing.T) {
|
||||
// Empty Accept-Encoding header
|
||||
req, _ := http.NewRequest(echo.GET, "/", nil)
|
||||
req.Header.Set(echo.AcceptEncoding, "gzip")
|
||||
w := httptest.NewRecorder()
|
||||
res := &echo.Response{Writer: w}
|
||||
c := echo.NewContext(req, res, echo.New())
|
||||
Gzip()(func(c *echo.Context) *echo.HTTPError {
|
||||
h := func(c *echo.Context) *echo.HTTPError {
|
||||
return c.String(http.StatusOK, "test")
|
||||
})(c)
|
||||
|
||||
if w.Header().Get(echo.ContentEncoding) != "gzip" {
|
||||
t.Errorf("expected Content-Encoding header `gzip`, got %d.", w.Header().Get(echo.ContentEncoding))
|
||||
}
|
||||
Gzip()(h)(c)
|
||||
s := w.Body.String()
|
||||
if s != "test" {
|
||||
t.Errorf("expected `test`, with empty Accept-Encoding header, got %s.", s)
|
||||
}
|
||||
|
||||
// Content-Encoding header
|
||||
req.Header.Set(echo.AcceptEncoding, "gzip")
|
||||
w = httptest.NewRecorder()
|
||||
c.Response = &echo.Response{Writer: w}
|
||||
Gzip()(h)(c)
|
||||
ce := w.Header().Get(echo.ContentEncoding)
|
||||
if ce != "gzip" {
|
||||
t.Errorf("expected Content-Encoding header `gzip`, got %d.", ce)
|
||||
}
|
||||
|
||||
// Body
|
||||
r, err := gzip.NewReader(w.Body)
|
||||
defer r.Close()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
b, err := ioutil.ReadAll(r)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
s := string(b)
|
||||
|
||||
s = string(b)
|
||||
if s != "test" {
|
||||
t.Errorf("expected `test`, got %s.", s)
|
||||
t.Errorf("expected body `test`, got %s.", s)
|
||||
}
|
||||
}
|
||||
|
30
middleware/recover.go
Normal file
30
middleware/recover.go
Normal file
@ -0,0 +1,30 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"runtime"
|
||||
|
||||
"github.com/labstack/echo"
|
||||
)
|
||||
|
||||
// Recover returns a middleware which recovers from panics anywhere in the chain
|
||||
// and handles the control to centralized HTTPErrorHandler.
|
||||
func Recover() echo.MiddlewareFunc {
|
||||
// TODO: Provide better stack trace `https://github.com/go-errors/errors` `https://github.com/docker/libcontainer/tree/master/stacktrace`
|
||||
return func(h echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c *echo.Context) *echo.HTTPError {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
trace := make([]byte, 1<<16)
|
||||
n := runtime.Stack(trace, true)
|
||||
c.Error(&echo.HTTPError{
|
||||
Error: fmt.Errorf("echo ⇒ panic recover\n %v\n stack trace %d bytes\n %s",
|
||||
err, n, trace[:n]),
|
||||
})
|
||||
}
|
||||
}()
|
||||
return h(c)
|
||||
}
|
||||
}
|
||||
}
|
33
middleware/recover_test.go
Normal file
33
middleware/recover_test.go
Normal file
@ -0,0 +1,33 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/labstack/echo"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRecover(t *testing.T) {
|
||||
e := echo.New()
|
||||
e.Debug(true)
|
||||
req, _ := http.NewRequest(echo.GET, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
res := &echo.Response{Writer: w}
|
||||
c := echo.NewContext(req, res, e)
|
||||
h := func(c *echo.Context) *echo.HTTPError {
|
||||
panic("test")
|
||||
}
|
||||
|
||||
// Status
|
||||
Recover()(h)(c)
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected status `500`, got %d.", w.Code)
|
||||
}
|
||||
|
||||
// Body
|
||||
s := w.Body.String()
|
||||
if !strings.Contains(s, "panic recover") {
|
||||
t.Error("expected body contains `panice recover`.")
|
||||
}
|
||||
}
|
@ -11,7 +11,8 @@ type (
|
||||
}
|
||||
)
|
||||
|
||||
// StripTrailingSlash removes trailing slash from request path.
|
||||
// StripTrailingSlash returns a middleware which removes trailing slash from request
|
||||
// path.
|
||||
func StripTrailingSlash() echo.HandlerFunc {
|
||||
return func(c *echo.Context) *echo.HTTPError {
|
||||
p := c.Request.URL.Path
|
||||
@ -23,8 +24,8 @@ func StripTrailingSlash() echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// RedirectToSlash redirects requests without trailing slash path to trailing slash
|
||||
// path, with .
|
||||
// RedirectToSlash returns a middleware which redirects requests without trailing
|
||||
// slash path to trailing slash path.
|
||||
func RedirectToSlash(opts ...RedirectToSlashOptions) echo.HandlerFunc {
|
||||
code := http.StatusMovedPermanently
|
||||
|
||||
|
@ -23,7 +23,7 @@ func (r *Response) Header() http.Header {
|
||||
func (r *Response) WriteHeader(code int) {
|
||||
if r.committed {
|
||||
// TODO: Warning
|
||||
log.Printf("echo: %s", color.Yellow("response already committed"))
|
||||
log.Printf("echo ⇒ %s", color.Yellow("response already committed"))
|
||||
return
|
||||
}
|
||||
r.status = code
|
||||
|
@ -308,11 +308,11 @@ func (r *router) Find(method, path string, ctx *Context) (h HandlerFunc, echo *E
|
||||
func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
c := r.echo.pool.Get().(*Context)
|
||||
h, _ := r.Find(req.Method, req.URL.Path, c)
|
||||
c.reset(w, req, nil)
|
||||
if h != nil {
|
||||
h(c)
|
||||
c.reset(w, req, r.echo)
|
||||
if h == nil {
|
||||
c.Error(&HTTPError{Code: http.StatusNotFound})
|
||||
} else {
|
||||
r.echo.notFoundHandler(c)
|
||||
h(c)
|
||||
}
|
||||
r.echo.pool.Put(c)
|
||||
}
|
||||
|
@ -35,15 +35,6 @@ Sets the maximum number of path parameters allowed for the application.
|
||||
Default value is **5**, [good enough](https://github.com/interagent/http-api-design#minimize-path-nesting)
|
||||
for many use cases. Restricting path parameters allows us to use memory efficiently.
|
||||
|
||||
### Not found handler
|
||||
|
||||
`echo.NotFoundHandler(h Handler)`
|
||||
|
||||
Registers a custom NotFound handler. This handler is called in case router doesn't
|
||||
find a matching route for the HTTP request.
|
||||
|
||||
Default handler sends 404 "Not Found" response.
|
||||
|
||||
### HTTP error handler
|
||||
|
||||
`echo.HTTPErrorHandler(h HTTPErrorHandler)`
|
||||
@ -53,7 +44,7 @@ Registers a custom centralized HTTP error handler `func(*HTTPError, *Context)`.
|
||||
Default handler sends `HTTPError.Message` HTTP response with `HTTPError.Code` status
|
||||
code.
|
||||
|
||||
- If HTTPError.Code is not specified it uses 500 "Internal Server Error".
|
||||
- If HTTPError.Code is not specified it uses "500 - Internal Server Error".
|
||||
- If HTTPError.Message is not specified it uses HTTPError.Error.Error() or the status
|
||||
code text.
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user