1
0
mirror of https://github.com/labstack/echo.git synced 2025-01-24 03:16:14 +02:00

Added panic recover middleware

Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
Vishal Rana 2015-05-17 22:54:29 -07:00
parent 609879bf39
commit 73fa05f826
18 changed files with 167 additions and 100 deletions

View File

@ -90,14 +90,21 @@ func main() {
// Echo instance
e := echo.New()
//------------
// Middleware
//------------
// Recover
e.Use(mw.Recover())
// Logger
e.Use(mw.Logger())
// Routes
e.Get("/", hello)
// Start server
e.Run(":1323)
e.Run(":1323")
}
```

View File

@ -7,7 +7,7 @@ import (
type (
// Context represents context for the current request. It holds request and
// response references, path parameters, data and registered handler.
// response objects, path parameters, data and registered handler.
Context struct {
Request *http.Request
Response *Response

53
echo.go
View File

@ -22,12 +22,12 @@ type (
prefix string
middleware []MiddlewareFunc
maxParam byte
notFoundHandler HandlerFunc
httpErrorHandler HTTPErrorHandler
binder BindFunc
renderer Renderer
uris map[Handler]string
pool sync.Pool
debug bool
}
HTTPError struct {
Code int
@ -115,8 +115,8 @@ var (
// Errors
//--------
UnsupportedMediaType = errors.New("echo: unsupported media type")
RendererNotRegistered = errors.New("echo: renderer not registered")
UnsupportedMediaType = errors.New("echo unsupported media type")
RendererNotRegistered = errors.New("echo renderer not registered")
)
// New creates an Echo instance.
@ -134,19 +134,14 @@ func New() (e *Echo) {
//----------
e.MaxParam(5)
e.NotFoundHandler(func(c *Context) *HTTPError {
http.Error(c.Response, http.StatusText(http.StatusNotFound), http.StatusNotFound)
return nil
})
e.HTTPErrorHandler(func(he *HTTPError, c *Context) {
if he.Code == 0 {
he.Code = http.StatusInternalServerError
}
if he.Message == "" {
if he.Error != nil {
he.Message = http.StatusText(he.Code)
if e.debug {
he.Message = he.Error.Error()
} else {
he.Message = http.StatusText(he.Code)
}
}
http.Error(c.Response, he.Message, he.Code)
@ -185,12 +180,6 @@ func (e *Echo) MaxParam(n uint8) {
e.maxParam = n
}
// NotFoundHandler registers a custom NotFound handler used by router in case it
// doesn't find any registered handler for HTTP method and path.
func (e *Echo) NotFoundHandler(h Handler) {
e.notFoundHandler = wrapHandler(h)
}
// HTTPErrorHandler registers an HTTP error handler.
func (e *Echo) HTTPErrorHandler(h HTTPErrorHandler) {
e.httpErrorHandler = h
@ -207,6 +196,11 @@ func (e *Echo) Renderer(r Renderer) {
e.renderer = r
}
// Debug runs the application in debug mode.
func (e *Echo) Debug(on bool) {
e.debug = on
}
// Use adds handler to the middleware chain.
func (e *Echo) Use(m ...Middleware) {
for _, h := range m {
@ -325,21 +319,20 @@ func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if echo != nil {
e = echo
}
if h == nil {
h = e.notFoundHandler
}
c.reset(w, r, e)
if h == nil {
c.Error(&HTTPError{Code: http.StatusNotFound})
} else {
// Chain middleware with handler in the end
for i := len(e.middleware) - 1; i >= 0; i-- {
h = e.middleware[i](h)
}
// Chain middleware with handler in the end
for i := len(e.middleware) - 1; i >= 0; i-- {
h = e.middleware[i](h)
// Execute chain
if he := h(c); he != nil {
e.httpErrorHandler(he, c)
}
}
// Execute chain
if he := h(c); he != nil {
e.httpErrorHandler(he, c)
}
e.pool.Put(c)
}
@ -394,7 +387,7 @@ func wrapMiddleware(m Middleware) MiddlewareFunc {
case func(http.ResponseWriter, *http.Request):
return wrapHTTPHandlerFuncMW(m)
default:
panic("echo: unknown middleware")
panic("echo unknown middleware")
}
}
@ -440,7 +433,7 @@ func wrapHandler(h Handler) HandlerFunc {
return nil
}
default:
panic("echo: unknown handler")
panic("echo unknown handler")
}
}

View File

@ -285,16 +285,6 @@ func TestEchoNotFound(t *testing.T) {
if w.Code != http.StatusNotFound {
t.Errorf("status code should be 404, found %d", w.Code)
}
// Customized NotFound handler
e.NotFoundHandler(func(c *Context) *HTTPError {
return c.String(http.StatusNotFound, "not found")
})
w = httptest.NewRecorder()
e.ServeHTTP(w, r)
if w.Body.String() != "not found" {
t.Errorf("body should be `not found`")
}
}
func verifyUser(u2 *user, t *testing.T) {

View File

@ -61,6 +61,7 @@ func main() {
e := echo.New()
// Middleware
e.Use(mw.Recover())
e.Use(mw.Logger())
// Routes

View File

@ -16,7 +16,14 @@ func main() {
// Echo instance
e := echo.New()
//------------
// Middleware
//------------
// Recover
e.Use(mw.Recover())
// Logger
e.Use(mw.Logger())
// Routes

View File

@ -16,10 +16,16 @@ func main() {
// Echo instance
e := echo.New()
// Debug mode
e.Debug(true)
//------------
// Middleware
//------------
// Recover
e.Use(mw.Recover())
// Logger
e.Use(mw.Logger())

View File

@ -65,6 +65,7 @@ func main() {
e := echo.New()
// Middleware
e.Use(mw.Recover())
e.Use(mw.Logger())
//------------------------

View File

@ -14,7 +14,7 @@ const (
Basic = "Basic"
)
// BasicAuth provides HTTP basic authentication.
// BasicAuth returns an HTTP basic authentication middleware.
func BasicAuth(fn AuthFunc) echo.HandlerFunc {
return func(c *echo.Context) (he *echo.HTTPError) {
auth := c.Request.Header.Get(echo.Authorization)

View File

@ -2,15 +2,15 @@ package middleware
import (
"encoding/base64"
"github.com/labstack/echo"
"net/http"
"net/http/httptest"
"testing"
"github.com/labstack/echo"
)
func TestBasicAuth(t *testing.T) {
req, _ := http.NewRequest(echo.POST, "/", nil)
res := &echo.Response{Writer: httptest.NewRecorder()}
res := &echo.Response{}
c := echo.NewContext(req, res, echo.New())
fn := func(u, p string) bool {
if u == "joe" && p == "secret" {
@ -34,7 +34,7 @@ func TestBasicAuth(t *testing.T) {
auth = "basic " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.Authorization, auth)
if ba(c) != nil {
t.Error("expected `pass` with case insensitive header")
t.Error("expected `pass`, with case insensitive header.")
}
//---------------------
@ -46,15 +46,22 @@ func TestBasicAuth(t *testing.T) {
req.Header.Set(echo.Authorization, auth)
ba = BasicAuth(fn)
if ba(c) == nil {
t.Error("expected `fail` with incorrect password")
t.Error("expected `fail`, with incorrect password.")
}
// Invalid header
// Empty Authorization header
req.Header.Set(echo.Authorization, "")
ba = BasicAuth(fn)
if ba(c) == nil {
t.Error("expected `fail`, with empty Authorization header.")
}
// Invalid Authorization header
auth = base64.StdEncoding.EncodeToString([]byte(" :secret"))
req.Header.Set(echo.Authorization, auth)
ba = BasicAuth(fn)
if ba(c) == nil {
t.Error("expected `fail` with invalid auth header")
t.Error("expected `fail`, with invalid Authorization header.")
}
// Invalid scheme
@ -62,13 +69,7 @@ func TestBasicAuth(t *testing.T) {
req.Header.Set(echo.Authorization, auth)
ba = BasicAuth(fn)
if ba(c) == nil {
t.Error("expected `fail` with invalid scheme")
t.Error("expected `fail`, with invalid scheme.")
}
// Empty auth header
req.Header.Set(echo.Authorization, "")
ba = BasicAuth(fn)
if ba(c) == nil {
t.Error("expected `fail` with empty auth header")
}
}

View File

@ -19,25 +19,21 @@ func (g gzipWriter) Write(b []byte) (int, error) {
return g.Writer.Write(b)
}
// Gzip compresses HTTP response using gzip compression scheme.
// Gzip returns a middleware which compresses HTTP response using gzip compression
// scheme.
func Gzip() echo.MiddlewareFunc {
scheme := "gzip"
return func(h echo.HandlerFunc) echo.HandlerFunc {
return func(c *echo.Context) *echo.HTTPError {
if !strings.Contains(c.Request.Header.Get(echo.AcceptEncoding), scheme) {
return nil
if strings.Contains(c.Request.Header.Get(echo.AcceptEncoding), scheme) {
w := gzip.NewWriter(c.Response.Writer)
defer w.Close()
gw := gzipWriter{Writer: w, Response: c.Response}
c.Response.Header().Set(echo.ContentEncoding, scheme)
c.Response = &echo.Response{Writer: gw}
}
w := gzip.NewWriter(c.Response.Writer)
defer w.Close()
gw := gzipWriter{Writer: w, Response: c.Response}
c.Response.Header().Set(echo.ContentEncoding, scheme)
c.Response = &echo.Response{Writer: gw}
if he := h(c); he != nil {
c.Error(he)
}
return nil
return h(c)
}
}
}

View File

@ -1,42 +1,52 @@
package middleware
import (
"compress/gzip"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"compress/gzip"
"github.com/labstack/echo"
"io/ioutil"
)
func TestGzip(t *testing.T) {
// Empty Accept-Encoding header
req, _ := http.NewRequest(echo.GET, "/", nil)
req.Header.Set(echo.AcceptEncoding, "gzip")
w := httptest.NewRecorder()
res := &echo.Response{Writer: w}
c := echo.NewContext(req, res, echo.New())
Gzip()(func(c *echo.Context) *echo.HTTPError {
h := func(c *echo.Context) *echo.HTTPError {
return c.String(http.StatusOK, "test")
})(c)
if w.Header().Get(echo.ContentEncoding) != "gzip" {
t.Errorf("expected Content-Encoding header `gzip`, got %d.", w.Header().Get(echo.ContentEncoding))
}
Gzip()(h)(c)
s := w.Body.String()
if s != "test" {
t.Errorf("expected `test`, with empty Accept-Encoding header, got %s.", s)
}
// Content-Encoding header
req.Header.Set(echo.AcceptEncoding, "gzip")
w = httptest.NewRecorder()
c.Response = &echo.Response{Writer: w}
Gzip()(h)(c)
ce := w.Header().Get(echo.ContentEncoding)
if ce != "gzip" {
t.Errorf("expected Content-Encoding header `gzip`, got %d.", ce)
}
// Body
r, err := gzip.NewReader(w.Body)
defer r.Close()
if err != nil {
t.Error(err)
}
b, err := ioutil.ReadAll(r)
if err != nil {
t.Error(err)
}
s := string(b)
s = string(b)
if s != "test" {
t.Errorf("expected `test`, got %s.", s)
t.Errorf("expected body `test`, got %s.", s)
}
}

30
middleware/recover.go Normal file
View File

@ -0,0 +1,30 @@
package middleware
import (
"fmt"
"runtime"
"github.com/labstack/echo"
)
// Recover returns a middleware which recovers from panics anywhere in the chain
// and handles the control to centralized HTTPErrorHandler.
func Recover() echo.MiddlewareFunc {
// TODO: Provide better stack trace `https://github.com/go-errors/errors` `https://github.com/docker/libcontainer/tree/master/stacktrace`
return func(h echo.HandlerFunc) echo.HandlerFunc {
return func(c *echo.Context) *echo.HTTPError {
defer func() {
if err := recover(); err != nil {
trace := make([]byte, 1<<16)
n := runtime.Stack(trace, true)
c.Error(&echo.HTTPError{
Error: fmt.Errorf("echo ⇒ panic recover\n %v\n stack trace %d bytes\n %s",
err, n, trace[:n]),
})
}
}()
return h(c)
}
}
}

View File

@ -0,0 +1,33 @@
package middleware
import (
"github.com/labstack/echo"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestRecover(t *testing.T) {
e := echo.New()
e.Debug(true)
req, _ := http.NewRequest(echo.GET, "/", nil)
w := httptest.NewRecorder()
res := &echo.Response{Writer: w}
c := echo.NewContext(req, res, e)
h := func(c *echo.Context) *echo.HTTPError {
panic("test")
}
// Status
Recover()(h)(c)
if w.Code != http.StatusInternalServerError {
t.Errorf("expected status `500`, got %d.", w.Code)
}
// Body
s := w.Body.String()
if !strings.Contains(s, "panic recover") {
t.Error("expected body contains `panice recover`.")
}
}

View File

@ -11,7 +11,8 @@ type (
}
)
// StripTrailingSlash removes trailing slash from request path.
// StripTrailingSlash returns a middleware which removes trailing slash from request
// path.
func StripTrailingSlash() echo.HandlerFunc {
return func(c *echo.Context) *echo.HTTPError {
p := c.Request.URL.Path
@ -23,8 +24,8 @@ func StripTrailingSlash() echo.HandlerFunc {
}
}
// RedirectToSlash redirects requests without trailing slash path to trailing slash
// path, with .
// RedirectToSlash returns a middleware which redirects requests without trailing
// slash path to trailing slash path.
func RedirectToSlash(opts ...RedirectToSlashOptions) echo.HandlerFunc {
code := http.StatusMovedPermanently

View File

@ -23,7 +23,7 @@ func (r *Response) Header() http.Header {
func (r *Response) WriteHeader(code int) {
if r.committed {
// TODO: Warning
log.Printf("echo: %s", color.Yellow("response already committed"))
log.Printf("echo %s", color.Yellow("response already committed"))
return
}
r.status = code

View File

@ -308,11 +308,11 @@ func (r *router) Find(method, path string, ctx *Context) (h HandlerFunc, echo *E
func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
c := r.echo.pool.Get().(*Context)
h, _ := r.Find(req.Method, req.URL.Path, c)
c.reset(w, req, nil)
if h != nil {
h(c)
c.reset(w, req, r.echo)
if h == nil {
c.Error(&HTTPError{Code: http.StatusNotFound})
} else {
r.echo.notFoundHandler(c)
h(c)
}
r.echo.pool.Put(c)
}

View File

@ -35,15 +35,6 @@ Sets the maximum number of path parameters allowed for the application.
Default value is **5**, [good enough](https://github.com/interagent/http-api-design#minimize-path-nesting)
for many use cases. Restricting path parameters allows us to use memory efficiently.
### Not found handler
`echo.NotFoundHandler(h Handler)`
Registers a custom NotFound handler. This handler is called in case router doesn't
find a matching route for the HTTP request.
Default handler sends 404 "Not Found" response.
### HTTP error handler
`echo.HTTPErrorHandler(h HTTPErrorHandler)`
@ -53,7 +44,7 @@ Registers a custom centralized HTTP error handler `func(*HTTPError, *Context)`.
Default handler sends `HTTPError.Message` HTTP response with `HTTPError.Code` status
code.
- If HTTPError.Code is not specified it uses 500 "Internal Server Error".
- If HTTPError.Code is not specified it uses "500 - Internal Server Error".
- If HTTPError.Message is not specified it uses HTTPError.Error.Error() or the status
code text.