1
0
mirror of https://github.com/labstack/echo.git synced 2024-12-24 20:14:31 +02:00

Added panic recover middleware

Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
Vishal Rana 2015-05-17 22:54:29 -07:00
parent 609879bf39
commit 73fa05f826
18 changed files with 167 additions and 100 deletions

View File

@ -90,14 +90,21 @@ func main() {
// Echo instance // Echo instance
e := echo.New() e := echo.New()
//------------
// Middleware // Middleware
//------------
// Recover
e.Use(mw.Recover())
// Logger
e.Use(mw.Logger()) e.Use(mw.Logger())
// Routes // Routes
e.Get("/", hello) e.Get("/", hello)
// Start server // Start server
e.Run(":1323) e.Run(":1323")
} }
``` ```

View File

@ -7,7 +7,7 @@ import (
type ( type (
// Context represents context for the current request. It holds request and // Context represents context for the current request. It holds request and
// response references, path parameters, data and registered handler. // response objects, path parameters, data and registered handler.
Context struct { Context struct {
Request *http.Request Request *http.Request
Response *Response Response *Response

53
echo.go
View File

@ -22,12 +22,12 @@ type (
prefix string prefix string
middleware []MiddlewareFunc middleware []MiddlewareFunc
maxParam byte maxParam byte
notFoundHandler HandlerFunc
httpErrorHandler HTTPErrorHandler httpErrorHandler HTTPErrorHandler
binder BindFunc binder BindFunc
renderer Renderer renderer Renderer
uris map[Handler]string uris map[Handler]string
pool sync.Pool pool sync.Pool
debug bool
} }
HTTPError struct { HTTPError struct {
Code int Code int
@ -115,8 +115,8 @@ var (
// Errors // Errors
//-------- //--------
UnsupportedMediaType = errors.New("echo: unsupported media type") UnsupportedMediaType = errors.New("echo unsupported media type")
RendererNotRegistered = errors.New("echo: renderer not registered") RendererNotRegistered = errors.New("echo renderer not registered")
) )
// New creates an Echo instance. // New creates an Echo instance.
@ -134,19 +134,14 @@ func New() (e *Echo) {
//---------- //----------
e.MaxParam(5) e.MaxParam(5)
e.NotFoundHandler(func(c *Context) *HTTPError {
http.Error(c.Response, http.StatusText(http.StatusNotFound), http.StatusNotFound)
return nil
})
e.HTTPErrorHandler(func(he *HTTPError, c *Context) { e.HTTPErrorHandler(func(he *HTTPError, c *Context) {
if he.Code == 0 { if he.Code == 0 {
he.Code = http.StatusInternalServerError he.Code = http.StatusInternalServerError
} }
if he.Message == "" { if he.Message == "" {
if he.Error != nil { he.Message = http.StatusText(he.Code)
if e.debug {
he.Message = he.Error.Error() he.Message = he.Error.Error()
} else {
he.Message = http.StatusText(he.Code)
} }
} }
http.Error(c.Response, he.Message, he.Code) http.Error(c.Response, he.Message, he.Code)
@ -185,12 +180,6 @@ func (e *Echo) MaxParam(n uint8) {
e.maxParam = n e.maxParam = n
} }
// NotFoundHandler registers a custom NotFound handler used by router in case it
// doesn't find any registered handler for HTTP method and path.
func (e *Echo) NotFoundHandler(h Handler) {
e.notFoundHandler = wrapHandler(h)
}
// HTTPErrorHandler registers an HTTP error handler. // HTTPErrorHandler registers an HTTP error handler.
func (e *Echo) HTTPErrorHandler(h HTTPErrorHandler) { func (e *Echo) HTTPErrorHandler(h HTTPErrorHandler) {
e.httpErrorHandler = h e.httpErrorHandler = h
@ -207,6 +196,11 @@ func (e *Echo) Renderer(r Renderer) {
e.renderer = r e.renderer = r
} }
// Debug runs the application in debug mode.
func (e *Echo) Debug(on bool) {
e.debug = on
}
// Use adds handler to the middleware chain. // Use adds handler to the middleware chain.
func (e *Echo) Use(m ...Middleware) { func (e *Echo) Use(m ...Middleware) {
for _, h := range m { for _, h := range m {
@ -325,21 +319,20 @@ func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if echo != nil { if echo != nil {
e = echo e = echo
} }
if h == nil {
h = e.notFoundHandler
}
c.reset(w, r, e) c.reset(w, r, e)
if h == nil {
c.Error(&HTTPError{Code: http.StatusNotFound})
} else {
// Chain middleware with handler in the end
for i := len(e.middleware) - 1; i >= 0; i-- {
h = e.middleware[i](h)
}
// Chain middleware with handler in the end // Execute chain
for i := len(e.middleware) - 1; i >= 0; i-- { if he := h(c); he != nil {
h = e.middleware[i](h) e.httpErrorHandler(he, c)
}
} }
// Execute chain
if he := h(c); he != nil {
e.httpErrorHandler(he, c)
}
e.pool.Put(c) e.pool.Put(c)
} }
@ -394,7 +387,7 @@ func wrapMiddleware(m Middleware) MiddlewareFunc {
case func(http.ResponseWriter, *http.Request): case func(http.ResponseWriter, *http.Request):
return wrapHTTPHandlerFuncMW(m) return wrapHTTPHandlerFuncMW(m)
default: default:
panic("echo: unknown middleware") panic("echo unknown middleware")
} }
} }
@ -440,7 +433,7 @@ func wrapHandler(h Handler) HandlerFunc {
return nil return nil
} }
default: default:
panic("echo: unknown handler") panic("echo unknown handler")
} }
} }

View File

@ -285,16 +285,6 @@ func TestEchoNotFound(t *testing.T) {
if w.Code != http.StatusNotFound { if w.Code != http.StatusNotFound {
t.Errorf("status code should be 404, found %d", w.Code) t.Errorf("status code should be 404, found %d", w.Code)
} }
// Customized NotFound handler
e.NotFoundHandler(func(c *Context) *HTTPError {
return c.String(http.StatusNotFound, "not found")
})
w = httptest.NewRecorder()
e.ServeHTTP(w, r)
if w.Body.String() != "not found" {
t.Errorf("body should be `not found`")
}
} }
func verifyUser(u2 *user, t *testing.T) { func verifyUser(u2 *user, t *testing.T) {

View File

@ -61,6 +61,7 @@ func main() {
e := echo.New() e := echo.New()
// Middleware // Middleware
e.Use(mw.Recover())
e.Use(mw.Logger()) e.Use(mw.Logger())
// Routes // Routes

View File

@ -16,7 +16,14 @@ func main() {
// Echo instance // Echo instance
e := echo.New() e := echo.New()
//------------
// Middleware // Middleware
//------------
// Recover
e.Use(mw.Recover())
// Logger
e.Use(mw.Logger()) e.Use(mw.Logger())
// Routes // Routes

View File

@ -16,10 +16,16 @@ func main() {
// Echo instance // Echo instance
e := echo.New() e := echo.New()
// Debug mode
e.Debug(true)
//------------ //------------
// Middleware // Middleware
//------------ //------------
// Recover
e.Use(mw.Recover())
// Logger // Logger
e.Use(mw.Logger()) e.Use(mw.Logger())

View File

@ -65,6 +65,7 @@ func main() {
e := echo.New() e := echo.New()
// Middleware // Middleware
e.Use(mw.Recover())
e.Use(mw.Logger()) e.Use(mw.Logger())
//------------------------ //------------------------

View File

@ -14,7 +14,7 @@ const (
Basic = "Basic" Basic = "Basic"
) )
// BasicAuth provides HTTP basic authentication. // BasicAuth returns an HTTP basic authentication middleware.
func BasicAuth(fn AuthFunc) echo.HandlerFunc { func BasicAuth(fn AuthFunc) echo.HandlerFunc {
return func(c *echo.Context) (he *echo.HTTPError) { return func(c *echo.Context) (he *echo.HTTPError) {
auth := c.Request.Header.Get(echo.Authorization) auth := c.Request.Header.Get(echo.Authorization)

View File

@ -2,15 +2,15 @@ package middleware
import ( import (
"encoding/base64" "encoding/base64"
"github.com/labstack/echo"
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
"github.com/labstack/echo"
) )
func TestBasicAuth(t *testing.T) { func TestBasicAuth(t *testing.T) {
req, _ := http.NewRequest(echo.POST, "/", nil) req, _ := http.NewRequest(echo.POST, "/", nil)
res := &echo.Response{Writer: httptest.NewRecorder()} res := &echo.Response{}
c := echo.NewContext(req, res, echo.New()) c := echo.NewContext(req, res, echo.New())
fn := func(u, p string) bool { fn := func(u, p string) bool {
if u == "joe" && p == "secret" { if u == "joe" && p == "secret" {
@ -34,7 +34,7 @@ func TestBasicAuth(t *testing.T) {
auth = "basic " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) auth = "basic " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.Authorization, auth) req.Header.Set(echo.Authorization, auth)
if ba(c) != nil { if ba(c) != nil {
t.Error("expected `pass` with case insensitive header") t.Error("expected `pass`, with case insensitive header.")
} }
//--------------------- //---------------------
@ -46,15 +46,22 @@ func TestBasicAuth(t *testing.T) {
req.Header.Set(echo.Authorization, auth) req.Header.Set(echo.Authorization, auth)
ba = BasicAuth(fn) ba = BasicAuth(fn)
if ba(c) == nil { if ba(c) == nil {
t.Error("expected `fail` with incorrect password") t.Error("expected `fail`, with incorrect password.")
} }
// Invalid header // Empty Authorization header
req.Header.Set(echo.Authorization, "")
ba = BasicAuth(fn)
if ba(c) == nil {
t.Error("expected `fail`, with empty Authorization header.")
}
// Invalid Authorization header
auth = base64.StdEncoding.EncodeToString([]byte(" :secret")) auth = base64.StdEncoding.EncodeToString([]byte(" :secret"))
req.Header.Set(echo.Authorization, auth) req.Header.Set(echo.Authorization, auth)
ba = BasicAuth(fn) ba = BasicAuth(fn)
if ba(c) == nil { if ba(c) == nil {
t.Error("expected `fail` with invalid auth header") t.Error("expected `fail`, with invalid Authorization header.")
} }
// Invalid scheme // Invalid scheme
@ -62,13 +69,7 @@ func TestBasicAuth(t *testing.T) {
req.Header.Set(echo.Authorization, auth) req.Header.Set(echo.Authorization, auth)
ba = BasicAuth(fn) ba = BasicAuth(fn)
if ba(c) == nil { if ba(c) == nil {
t.Error("expected `fail` with invalid scheme") t.Error("expected `fail`, with invalid scheme.")
} }
// Empty auth header
req.Header.Set(echo.Authorization, "")
ba = BasicAuth(fn)
if ba(c) == nil {
t.Error("expected `fail` with empty auth header")
}
} }

View File

@ -19,25 +19,21 @@ func (g gzipWriter) Write(b []byte) (int, error) {
return g.Writer.Write(b) return g.Writer.Write(b)
} }
// Gzip compresses HTTP response using gzip compression scheme. // Gzip returns a middleware which compresses HTTP response using gzip compression
// scheme.
func Gzip() echo.MiddlewareFunc { func Gzip() echo.MiddlewareFunc {
scheme := "gzip" scheme := "gzip"
return func(h echo.HandlerFunc) echo.HandlerFunc { return func(h echo.HandlerFunc) echo.HandlerFunc {
return func(c *echo.Context) *echo.HTTPError { return func(c *echo.Context) *echo.HTTPError {
if !strings.Contains(c.Request.Header.Get(echo.AcceptEncoding), scheme) { if strings.Contains(c.Request.Header.Get(echo.AcceptEncoding), scheme) {
return nil w := gzip.NewWriter(c.Response.Writer)
defer w.Close()
gw := gzipWriter{Writer: w, Response: c.Response}
c.Response.Header().Set(echo.ContentEncoding, scheme)
c.Response = &echo.Response{Writer: gw}
} }
return h(c)
w := gzip.NewWriter(c.Response.Writer)
defer w.Close()
gw := gzipWriter{Writer: w, Response: c.Response}
c.Response.Header().Set(echo.ContentEncoding, scheme)
c.Response = &echo.Response{Writer: gw}
if he := h(c); he != nil {
c.Error(he)
}
return nil
} }
} }
} }

View File

@ -1,42 +1,52 @@
package middleware package middleware
import ( import (
"compress/gzip"
"io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"compress/gzip"
"github.com/labstack/echo" "github.com/labstack/echo"
"io/ioutil"
) )
func TestGzip(t *testing.T) { func TestGzip(t *testing.T) {
// Empty Accept-Encoding header
req, _ := http.NewRequest(echo.GET, "/", nil) req, _ := http.NewRequest(echo.GET, "/", nil)
req.Header.Set(echo.AcceptEncoding, "gzip")
w := httptest.NewRecorder() w := httptest.NewRecorder()
res := &echo.Response{Writer: w} res := &echo.Response{Writer: w}
c := echo.NewContext(req, res, echo.New()) c := echo.NewContext(req, res, echo.New())
Gzip()(func(c *echo.Context) *echo.HTTPError { h := func(c *echo.Context) *echo.HTTPError {
return c.String(http.StatusOK, "test") return c.String(http.StatusOK, "test")
})(c) }
Gzip()(h)(c)
if w.Header().Get(echo.ContentEncoding) != "gzip" { s := w.Body.String()
t.Errorf("expected Content-Encoding header `gzip`, got %d.", w.Header().Get(echo.ContentEncoding)) if s != "test" {
t.Errorf("expected `test`, with empty Accept-Encoding header, got %s.", s)
} }
// Content-Encoding header
req.Header.Set(echo.AcceptEncoding, "gzip")
w = httptest.NewRecorder()
c.Response = &echo.Response{Writer: w}
Gzip()(h)(c)
ce := w.Header().Get(echo.ContentEncoding)
if ce != "gzip" {
t.Errorf("expected Content-Encoding header `gzip`, got %d.", ce)
}
// Body
r, err := gzip.NewReader(w.Body) r, err := gzip.NewReader(w.Body)
defer r.Close() defer r.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
b, err := ioutil.ReadAll(r) b, err := ioutil.ReadAll(r)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
s := string(b) s = string(b)
if s != "test" { if s != "test" {
t.Errorf("expected `test`, got %s.", s) t.Errorf("expected body `test`, got %s.", s)
} }
} }

30
middleware/recover.go Normal file
View File

@ -0,0 +1,30 @@
package middleware
import (
"fmt"
"runtime"
"github.com/labstack/echo"
)
// Recover returns a middleware which recovers from panics anywhere in the chain
// and handles the control to centralized HTTPErrorHandler.
func Recover() echo.MiddlewareFunc {
// TODO: Provide better stack trace `https://github.com/go-errors/errors` `https://github.com/docker/libcontainer/tree/master/stacktrace`
return func(h echo.HandlerFunc) echo.HandlerFunc {
return func(c *echo.Context) *echo.HTTPError {
defer func() {
if err := recover(); err != nil {
trace := make([]byte, 1<<16)
n := runtime.Stack(trace, true)
c.Error(&echo.HTTPError{
Error: fmt.Errorf("echo ⇒ panic recover\n %v\n stack trace %d bytes\n %s",
err, n, trace[:n]),
})
}
}()
return h(c)
}
}
}

View File

@ -0,0 +1,33 @@
package middleware
import (
"github.com/labstack/echo"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestRecover(t *testing.T) {
e := echo.New()
e.Debug(true)
req, _ := http.NewRequest(echo.GET, "/", nil)
w := httptest.NewRecorder()
res := &echo.Response{Writer: w}
c := echo.NewContext(req, res, e)
h := func(c *echo.Context) *echo.HTTPError {
panic("test")
}
// Status
Recover()(h)(c)
if w.Code != http.StatusInternalServerError {
t.Errorf("expected status `500`, got %d.", w.Code)
}
// Body
s := w.Body.String()
if !strings.Contains(s, "panic recover") {
t.Error("expected body contains `panice recover`.")
}
}

View File

@ -11,7 +11,8 @@ type (
} }
) )
// StripTrailingSlash removes trailing slash from request path. // StripTrailingSlash returns a middleware which removes trailing slash from request
// path.
func StripTrailingSlash() echo.HandlerFunc { func StripTrailingSlash() echo.HandlerFunc {
return func(c *echo.Context) *echo.HTTPError { return func(c *echo.Context) *echo.HTTPError {
p := c.Request.URL.Path p := c.Request.URL.Path
@ -23,8 +24,8 @@ func StripTrailingSlash() echo.HandlerFunc {
} }
} }
// RedirectToSlash redirects requests without trailing slash path to trailing slash // RedirectToSlash returns a middleware which redirects requests without trailing
// path, with . // slash path to trailing slash path.
func RedirectToSlash(opts ...RedirectToSlashOptions) echo.HandlerFunc { func RedirectToSlash(opts ...RedirectToSlashOptions) echo.HandlerFunc {
code := http.StatusMovedPermanently code := http.StatusMovedPermanently

View File

@ -23,7 +23,7 @@ func (r *Response) Header() http.Header {
func (r *Response) WriteHeader(code int) { func (r *Response) WriteHeader(code int) {
if r.committed { if r.committed {
// TODO: Warning // TODO: Warning
log.Printf("echo: %s", color.Yellow("response already committed")) log.Printf("echo %s", color.Yellow("response already committed"))
return return
} }
r.status = code r.status = code

View File

@ -308,11 +308,11 @@ func (r *router) Find(method, path string, ctx *Context) (h HandlerFunc, echo *E
func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) { func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
c := r.echo.pool.Get().(*Context) c := r.echo.pool.Get().(*Context)
h, _ := r.Find(req.Method, req.URL.Path, c) h, _ := r.Find(req.Method, req.URL.Path, c)
c.reset(w, req, nil) c.reset(w, req, r.echo)
if h != nil { if h == nil {
h(c) c.Error(&HTTPError{Code: http.StatusNotFound})
} else { } else {
r.echo.notFoundHandler(c) h(c)
} }
r.echo.pool.Put(c) r.echo.pool.Put(c)
} }

View File

@ -35,15 +35,6 @@ Sets the maximum number of path parameters allowed for the application.
Default value is **5**, [good enough](https://github.com/interagent/http-api-design#minimize-path-nesting) Default value is **5**, [good enough](https://github.com/interagent/http-api-design#minimize-path-nesting)
for many use cases. Restricting path parameters allows us to use memory efficiently. for many use cases. Restricting path parameters allows us to use memory efficiently.
### Not found handler
`echo.NotFoundHandler(h Handler)`
Registers a custom NotFound handler. This handler is called in case router doesn't
find a matching route for the HTTP request.
Default handler sends 404 "Not Found" response.
### HTTP error handler ### HTTP error handler
`echo.HTTPErrorHandler(h HTTPErrorHandler)` `echo.HTTPErrorHandler(h HTTPErrorHandler)`
@ -53,7 +44,7 @@ Registers a custom centralized HTTP error handler `func(*HTTPError, *Context)`.
Default handler sends `HTTPError.Message` HTTP response with `HTTPError.Code` status Default handler sends `HTTPError.Message` HTTP response with `HTTPError.Code` status
code. code.
- If HTTPError.Code is not specified it uses 500 "Internal Server Error". - If HTTPError.Code is not specified it uses "500 - Internal Server Error".
- If HTTPError.Message is not specified it uses HTTPError.Error.Error() or the status - If HTTPError.Message is not specified it uses HTTPError.Error.Error() or the status
code text. code text.