diff --git a/README.md b/README.md index c0a9fad7..0ff0b62b 100644 --- a/README.md +++ b/README.md @@ -90,14 +90,21 @@ func main() { // Echo instance e := echo.New() + //------------ // Middleware + //------------ + + // Recover + e.Use(mw.Recover()) + + // Logger e.Use(mw.Logger()) // Routes e.Get("/", hello) // Start server - e.Run(":1323) + e.Run(":1323") } ``` diff --git a/context.go b/context.go index 7cd57156..2d51b65d 100644 --- a/context.go +++ b/context.go @@ -7,7 +7,7 @@ import ( type ( // Context represents context for the current request. It holds request and - // response references, path parameters, data and registered handler. + // response objects, path parameters, data and registered handler. Context struct { Request *http.Request Response *Response diff --git a/echo.go b/echo.go index 38bad2d6..7fa9417e 100644 --- a/echo.go +++ b/echo.go @@ -22,12 +22,12 @@ type ( prefix string middleware []MiddlewareFunc maxParam byte - notFoundHandler HandlerFunc httpErrorHandler HTTPErrorHandler binder BindFunc renderer Renderer uris map[Handler]string pool sync.Pool + debug bool } HTTPError struct { Code int @@ -115,8 +115,8 @@ var ( // Errors //-------- - UnsupportedMediaType = errors.New("echo: unsupported media type") - RendererNotRegistered = errors.New("echo: renderer not registered") + UnsupportedMediaType = errors.New("echo ⇒ unsupported media type") + RendererNotRegistered = errors.New("echo ⇒ renderer not registered") ) // New creates an Echo instance. @@ -134,19 +134,14 @@ func New() (e *Echo) { //---------- e.MaxParam(5) - e.NotFoundHandler(func(c *Context) *HTTPError { - http.Error(c.Response, http.StatusText(http.StatusNotFound), http.StatusNotFound) - return nil - }) e.HTTPErrorHandler(func(he *HTTPError, c *Context) { if he.Code == 0 { he.Code = http.StatusInternalServerError } if he.Message == "" { - if he.Error != nil { + he.Message = http.StatusText(he.Code) + if e.debug { he.Message = he.Error.Error() - } else { - he.Message = http.StatusText(he.Code) } } http.Error(c.Response, he.Message, he.Code) @@ -185,12 +180,6 @@ func (e *Echo) MaxParam(n uint8) { e.maxParam = n } -// NotFoundHandler registers a custom NotFound handler used by router in case it -// doesn't find any registered handler for HTTP method and path. -func (e *Echo) NotFoundHandler(h Handler) { - e.notFoundHandler = wrapHandler(h) -} - // HTTPErrorHandler registers an HTTP error handler. func (e *Echo) HTTPErrorHandler(h HTTPErrorHandler) { e.httpErrorHandler = h @@ -207,6 +196,11 @@ func (e *Echo) Renderer(r Renderer) { e.renderer = r } +// Debug runs the application in debug mode. +func (e *Echo) Debug(on bool) { + e.debug = on +} + // Use adds handler to the middleware chain. func (e *Echo) Use(m ...Middleware) { for _, h := range m { @@ -325,21 +319,20 @@ func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) { if echo != nil { e = echo } - if h == nil { - h = e.notFoundHandler - } c.reset(w, r, e) + if h == nil { + c.Error(&HTTPError{Code: http.StatusNotFound}) + } else { + // Chain middleware with handler in the end + for i := len(e.middleware) - 1; i >= 0; i-- { + h = e.middleware[i](h) + } - // Chain middleware with handler in the end - for i := len(e.middleware) - 1; i >= 0; i-- { - h = e.middleware[i](h) + // Execute chain + if he := h(c); he != nil { + e.httpErrorHandler(he, c) + } } - - // Execute chain - if he := h(c); he != nil { - e.httpErrorHandler(he, c) - } - e.pool.Put(c) } @@ -394,7 +387,7 @@ func wrapMiddleware(m Middleware) MiddlewareFunc { case func(http.ResponseWriter, *http.Request): return wrapHTTPHandlerFuncMW(m) default: - panic("echo: unknown middleware") + panic("echo ⇒ unknown middleware") } } @@ -440,7 +433,7 @@ func wrapHandler(h Handler) HandlerFunc { return nil } default: - panic("echo: unknown handler") + panic("echo ⇒ unknown handler") } } diff --git a/echo_test.go b/echo_test.go index 763d60ad..3119eeb9 100644 --- a/echo_test.go +++ b/echo_test.go @@ -285,16 +285,6 @@ func TestEchoNotFound(t *testing.T) { if w.Code != http.StatusNotFound { t.Errorf("status code should be 404, found %d", w.Code) } - - // Customized NotFound handler - e.NotFoundHandler(func(c *Context) *HTTPError { - return c.String(http.StatusNotFound, "not found") - }) - w = httptest.NewRecorder() - e.ServeHTTP(w, r) - if w.Body.String() != "not found" { - t.Errorf("body should be `not found`") - } } func verifyUser(u2 *user, t *testing.T) { diff --git a/examples/crud/server.go b/examples/crud/server.go index 3c26dbff..79d6f57b 100644 --- a/examples/crud/server.go +++ b/examples/crud/server.go @@ -61,6 +61,7 @@ func main() { e := echo.New() // Middleware + e.Use(mw.Recover()) e.Use(mw.Logger()) // Routes diff --git a/examples/hello/server.go b/examples/hello/server.go index 15410f54..6a6a23cd 100644 --- a/examples/hello/server.go +++ b/examples/hello/server.go @@ -16,7 +16,14 @@ func main() { // Echo instance e := echo.New() + //------------ // Middleware + //------------ + + // Recover + e.Use(mw.Recover()) + + // Logger e.Use(mw.Logger()) // Routes diff --git a/examples/middleware/server.go b/examples/middleware/server.go index af51e095..92e22b05 100644 --- a/examples/middleware/server.go +++ b/examples/middleware/server.go @@ -16,10 +16,16 @@ func main() { // Echo instance e := echo.New() + // Debug mode + e.Debug(true) + //------------ // Middleware //------------ + // Recover + e.Use(mw.Recover()) + // Logger e.Use(mw.Logger()) diff --git a/examples/web/server.go b/examples/web/server.go index 0331d058..5f40dcae 100644 --- a/examples/web/server.go +++ b/examples/web/server.go @@ -65,6 +65,7 @@ func main() { e := echo.New() // Middleware + e.Use(mw.Recover()) e.Use(mw.Logger()) //------------------------ diff --git a/middleware/auth.go b/middleware/auth.go index 3d3f6697..5fae391e 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -14,7 +14,7 @@ const ( Basic = "Basic" ) -// BasicAuth provides HTTP basic authentication. +// BasicAuth returns an HTTP basic authentication middleware. func BasicAuth(fn AuthFunc) echo.HandlerFunc { return func(c *echo.Context) (he *echo.HTTPError) { auth := c.Request.Header.Get(echo.Authorization) diff --git a/middleware/auth_test.go b/middleware/auth_test.go index 4961cf3c..24ceed00 100644 --- a/middleware/auth_test.go +++ b/middleware/auth_test.go @@ -2,15 +2,15 @@ package middleware import ( "encoding/base64" - "github.com/labstack/echo" "net/http" - "net/http/httptest" "testing" + + "github.com/labstack/echo" ) func TestBasicAuth(t *testing.T) { req, _ := http.NewRequest(echo.POST, "/", nil) - res := &echo.Response{Writer: httptest.NewRecorder()} + res := &echo.Response{} c := echo.NewContext(req, res, echo.New()) fn := func(u, p string) bool { if u == "joe" && p == "secret" { @@ -34,7 +34,7 @@ func TestBasicAuth(t *testing.T) { auth = "basic " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) req.Header.Set(echo.Authorization, auth) if ba(c) != nil { - t.Error("expected `pass` with case insensitive header") + t.Error("expected `pass`, with case insensitive header.") } //--------------------- @@ -46,15 +46,22 @@ func TestBasicAuth(t *testing.T) { req.Header.Set(echo.Authorization, auth) ba = BasicAuth(fn) if ba(c) == nil { - t.Error("expected `fail` with incorrect password") + t.Error("expected `fail`, with incorrect password.") } - // Invalid header + // Empty Authorization header + req.Header.Set(echo.Authorization, "") + ba = BasicAuth(fn) + if ba(c) == nil { + t.Error("expected `fail`, with empty Authorization header.") + } + + // Invalid Authorization header auth = base64.StdEncoding.EncodeToString([]byte(" :secret")) req.Header.Set(echo.Authorization, auth) ba = BasicAuth(fn) if ba(c) == nil { - t.Error("expected `fail` with invalid auth header") + t.Error("expected `fail`, with invalid Authorization header.") } // Invalid scheme @@ -62,13 +69,7 @@ func TestBasicAuth(t *testing.T) { req.Header.Set(echo.Authorization, auth) ba = BasicAuth(fn) if ba(c) == nil { - t.Error("expected `fail` with invalid scheme") + t.Error("expected `fail`, with invalid scheme.") } - // Empty auth header - req.Header.Set(echo.Authorization, "") - ba = BasicAuth(fn) - if ba(c) == nil { - t.Error("expected `fail` with empty auth header") - } } diff --git a/middleware/compress.go b/middleware/compress.go index 987e710e..102bd303 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -19,25 +19,21 @@ func (g gzipWriter) Write(b []byte) (int, error) { return g.Writer.Write(b) } -// Gzip compresses HTTP response using gzip compression scheme. +// Gzip returns a middleware which compresses HTTP response using gzip compression +// scheme. func Gzip() echo.MiddlewareFunc { scheme := "gzip" return func(h echo.HandlerFunc) echo.HandlerFunc { return func(c *echo.Context) *echo.HTTPError { - if !strings.Contains(c.Request.Header.Get(echo.AcceptEncoding), scheme) { - return nil + if strings.Contains(c.Request.Header.Get(echo.AcceptEncoding), scheme) { + w := gzip.NewWriter(c.Response.Writer) + defer w.Close() + gw := gzipWriter{Writer: w, Response: c.Response} + c.Response.Header().Set(echo.ContentEncoding, scheme) + c.Response = &echo.Response{Writer: gw} } - - w := gzip.NewWriter(c.Response.Writer) - defer w.Close() - gw := gzipWriter{Writer: w, Response: c.Response} - c.Response.Header().Set(echo.ContentEncoding, scheme) - c.Response = &echo.Response{Writer: gw} - if he := h(c); he != nil { - c.Error(he) - } - return nil + return h(c) } } } diff --git a/middleware/compress_test.go b/middleware/compress_test.go index 867f1ea4..625c2c9a 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -1,42 +1,52 @@ package middleware import ( + "compress/gzip" + "io/ioutil" "net/http" "net/http/httptest" "testing" - "compress/gzip" "github.com/labstack/echo" - "io/ioutil" ) func TestGzip(t *testing.T) { + // Empty Accept-Encoding header req, _ := http.NewRequest(echo.GET, "/", nil) - req.Header.Set(echo.AcceptEncoding, "gzip") w := httptest.NewRecorder() res := &echo.Response{Writer: w} c := echo.NewContext(req, res, echo.New()) - Gzip()(func(c *echo.Context) *echo.HTTPError { + h := func(c *echo.Context) *echo.HTTPError { return c.String(http.StatusOK, "test") - })(c) - - if w.Header().Get(echo.ContentEncoding) != "gzip" { - t.Errorf("expected Content-Encoding header `gzip`, got %d.", w.Header().Get(echo.ContentEncoding)) + } + Gzip()(h)(c) + s := w.Body.String() + if s != "test" { + t.Errorf("expected `test`, with empty Accept-Encoding header, got %s.", s) } + // Content-Encoding header + req.Header.Set(echo.AcceptEncoding, "gzip") + w = httptest.NewRecorder() + c.Response = &echo.Response{Writer: w} + Gzip()(h)(c) + ce := w.Header().Get(echo.ContentEncoding) + if ce != "gzip" { + t.Errorf("expected Content-Encoding header `gzip`, got %d.", ce) + } + + // Body r, err := gzip.NewReader(w.Body) defer r.Close() if err != nil { t.Error(err) } - b, err := ioutil.ReadAll(r) if err != nil { t.Error(err) } - s := string(b) - + s = string(b) if s != "test" { - t.Errorf("expected `test`, got %s.", s) + t.Errorf("expected body `test`, got %s.", s) } } diff --git a/middleware/recover.go b/middleware/recover.go new file mode 100644 index 00000000..01a5bbbf --- /dev/null +++ b/middleware/recover.go @@ -0,0 +1,30 @@ +package middleware + +import ( + "fmt" + + "runtime" + + "github.com/labstack/echo" +) + +// Recover returns a middleware which recovers from panics anywhere in the chain +// and handles the control to centralized HTTPErrorHandler. +func Recover() echo.MiddlewareFunc { + // TODO: Provide better stack trace `https://github.com/go-errors/errors` `https://github.com/docker/libcontainer/tree/master/stacktrace` + return func(h echo.HandlerFunc) echo.HandlerFunc { + return func(c *echo.Context) *echo.HTTPError { + defer func() { + if err := recover(); err != nil { + trace := make([]byte, 1<<16) + n := runtime.Stack(trace, true) + c.Error(&echo.HTTPError{ + Error: fmt.Errorf("echo ⇒ panic recover\n %v\n stack trace %d bytes\n %s", + err, n, trace[:n]), + }) + } + }() + return h(c) + } + } +} diff --git a/middleware/recover_test.go b/middleware/recover_test.go new file mode 100644 index 00000000..3f16fe14 --- /dev/null +++ b/middleware/recover_test.go @@ -0,0 +1,33 @@ +package middleware + +import ( + "github.com/labstack/echo" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestRecover(t *testing.T) { + e := echo.New() + e.Debug(true) + req, _ := http.NewRequest(echo.GET, "/", nil) + w := httptest.NewRecorder() + res := &echo.Response{Writer: w} + c := echo.NewContext(req, res, e) + h := func(c *echo.Context) *echo.HTTPError { + panic("test") + } + + // Status + Recover()(h)(c) + if w.Code != http.StatusInternalServerError { + t.Errorf("expected status `500`, got %d.", w.Code) + } + + // Body + s := w.Body.String() + if !strings.Contains(s, "panic recover") { + t.Error("expected body contains `panice recover`.") + } +} diff --git a/middleware/slash.go b/middleware/slash.go index 70b0eed8..3efef310 100644 --- a/middleware/slash.go +++ b/middleware/slash.go @@ -11,7 +11,8 @@ type ( } ) -// StripTrailingSlash removes trailing slash from request path. +// StripTrailingSlash returns a middleware which removes trailing slash from request +// path. func StripTrailingSlash() echo.HandlerFunc { return func(c *echo.Context) *echo.HTTPError { p := c.Request.URL.Path @@ -23,8 +24,8 @@ func StripTrailingSlash() echo.HandlerFunc { } } -// RedirectToSlash redirects requests without trailing slash path to trailing slash -// path, with . +// RedirectToSlash returns a middleware which redirects requests without trailing +// slash path to trailing slash path. func RedirectToSlash(opts ...RedirectToSlashOptions) echo.HandlerFunc { code := http.StatusMovedPermanently diff --git a/response.go b/response.go index bec7ddd7..756ce58c 100644 --- a/response.go +++ b/response.go @@ -23,7 +23,7 @@ func (r *Response) Header() http.Header { func (r *Response) WriteHeader(code int) { if r.committed { // TODO: Warning - log.Printf("echo: %s", color.Yellow("response already committed")) + log.Printf("echo ⇒ %s", color.Yellow("response already committed")) return } r.status = code diff --git a/router.go b/router.go index 75f10e47..15956464 100644 --- a/router.go +++ b/router.go @@ -308,11 +308,11 @@ func (r *router) Find(method, path string, ctx *Context) (h HandlerFunc, echo *E func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) { c := r.echo.pool.Get().(*Context) h, _ := r.Find(req.Method, req.URL.Path, c) - c.reset(w, req, nil) - if h != nil { - h(c) + c.reset(w, req, r.echo) + if h == nil { + c.Error(&HTTPError{Code: http.StatusNotFound}) } else { - r.echo.notFoundHandler(c) + h(c) } r.echo.pool.Put(c) } diff --git a/website/docs/guide.md b/website/docs/guide.md index db92b4ae..8f8f7e2f 100644 --- a/website/docs/guide.md +++ b/website/docs/guide.md @@ -35,15 +35,6 @@ Sets the maximum number of path parameters allowed for the application. Default value is **5**, [good enough](https://github.com/interagent/http-api-design#minimize-path-nesting) for many use cases. Restricting path parameters allows us to use memory efficiently. -### Not found handler - -`echo.NotFoundHandler(h Handler)` - -Registers a custom NotFound handler. This handler is called in case router doesn't -find a matching route for the HTTP request. - -Default handler sends 404 "Not Found" response. - ### HTTP error handler `echo.HTTPErrorHandler(h HTTPErrorHandler)` @@ -53,7 +44,7 @@ Registers a custom centralized HTTP error handler `func(*HTTPError, *Context)`. Default handler sends `HTTPError.Message` HTTP response with `HTTPError.Code` status code. -- If HTTPError.Code is not specified it uses 500 "Internal Server Error". +- If HTTPError.Code is not specified it uses "500 - Internal Server Error". - If HTTPError.Message is not specified it uses HTTPError.Error.Error() or the status code text.