1
0
mirror of https://github.com/labstack/echo.git synced 2025-11-27 22:38:25 +02:00

Wrappers for handler and middleware

Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
Vishal Rana
2016-02-08 22:17:20 -08:00
parent 3cd1d5be65
commit 7b843e66c5
11 changed files with 147 additions and 109 deletions

157
echo.go
View File

@@ -61,7 +61,10 @@ type (
MiddlewareFunc func(HandlerFunc) HandlerFunc MiddlewareFunc func(HandlerFunc) HandlerFunc
// Handler interface{} Handler interface {
Handle(Context) error
}
HandlerFunc func(Context) error HandlerFunc func(Context) error
// HTTPErrorHandler is a centralized HTTP error handler. // HTTPErrorHandler is a centralized HTTP error handler.
@@ -181,13 +184,13 @@ var (
// Error handlers // Error handlers
//---------------- //----------------
notFoundHandler = func(c Context) error { notFoundHandler = HandlerFunc(func(c Context) error {
return NewHTTPError(http.StatusNotFound) return NewHTTPError(http.StatusNotFound)
} })
methodNotAllowedHandler = func(c Context) error { methodNotAllowedHandler = HandlerFunc(func(c Context) error {
return NewHTTPError(http.StatusMethodNotAllowed) return NewHTTPError(http.StatusMethodNotAllowed)
} })
) )
// New creates an instance of Echo. // New creates an instance of Echo.
@@ -204,22 +207,7 @@ func New() (e *Echo) {
//---------- //----------
e.HTTP2(true) e.HTTP2(true)
e.defaultHTTPErrorHandler = func(err error, c Context) { e.SetHTTPErrorHandler(e.DefaultHTTPErrorHandler)
code := http.StatusInternalServerError
msg := http.StatusText(code)
if he, ok := err.(*HTTPError); ok {
code = he.code
msg = he.message
}
if e.debug {
msg = err.Error()
}
if !c.Response().Committed() {
c.String(code, msg)
}
e.logger.Error(err)
}
e.SetHTTPErrorHandler(e.defaultHTTPErrorHandler)
e.SetBinder(&binder{}) e.SetBinder(&binder{})
// Logger // Logger
@@ -232,6 +220,10 @@ func (f MiddlewareFunc) Process(h HandlerFunc) HandlerFunc {
return f(h) return f(h)
} }
func (f HandlerFunc) Handle(c Context) error {
return f(c)
}
// Router returns router. // Router returns router.
func (e *Echo) Router() *Router { func (e *Echo) Router() *Router {
return e.router return e.router
@@ -254,7 +246,19 @@ func (e *Echo) HTTP2(on bool) {
// DefaultHTTPErrorHandler invokes the default HTTP error handler. // DefaultHTTPErrorHandler invokes the default HTTP error handler.
func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) { func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) {
e.defaultHTTPErrorHandler(err, c) code := http.StatusInternalServerError
msg := http.StatusText(code)
if he, ok := err.(*HTTPError); ok {
code = he.code
msg = he.message
}
if e.debug {
msg = err.Error()
}
if !c.Response().Committed() {
c.String(code, msg)
}
e.logger.Error(err)
} }
// SetHTTPErrorHandler registers a custom Echo.HTTPErrorHandler. // SetHTTPErrorHandler registers a custom Echo.HTTPErrorHandler.
@@ -295,75 +299,75 @@ func (e *Echo) Hook(h engine.HandlerFunc) {
} }
// Use adds handler to the middleware chain. // Use adds handler to the middleware chain.
func (e *Echo) Use(m ...MiddlewareFunc) { func (e *Echo) Use(middleware ...interface{}) {
for _, h := range m { for _, m := range middleware {
e.middleware = append(e.middleware, h) e.middleware = append(e.middleware, wrapMiddleware(m))
} }
} }
// Connect adds a CONNECT route > handler to the router. // Connect adds a CONNECT route > handler to the router.
func (e *Echo) Connect(path string, h HandlerFunc) { func (e *Echo) Connect(path string, handler interface{}) {
e.add(CONNECT, path, h) e.add(CONNECT, path, handler)
} }
// Delete adds a DELETE route > handler to the router. // Delete adds a DELETE route > handler to the router.
func (e *Echo) Delete(path string, h HandlerFunc) { func (e *Echo) Delete(path string, handler interface{}) {
e.add(DELETE, path, h) e.add(DELETE, path, handler)
} }
// Get adds a GET route > handler to the router. // Get adds a GET route > handler to the router.
func (e *Echo) Get(path string, h HandlerFunc) { func (e *Echo) Get(path string, handler interface{}) {
e.add(GET, path, h) e.add(GET, path, handler)
} }
// Head adds a HEAD route > handler to the router. // Head adds a HEAD route > handler to the router.
func (e *Echo) Head(path string, h HandlerFunc) { func (e *Echo) Head(path string, handler interface{}) {
e.add(HEAD, path, h) e.add(HEAD, path, handler)
} }
// Options adds an OPTIONS route > handler to the router. // Options adds an OPTIONS route > handler to the router.
func (e *Echo) Options(path string, h HandlerFunc) { func (e *Echo) Options(path string, handler interface{}) {
e.add(OPTIONS, path, h) e.add(OPTIONS, path, handler)
} }
// Patch adds a PATCH route > handler to the router. // Patch adds a PATCH route > handler to the router.
func (e *Echo) Patch(path string, h HandlerFunc) { func (e *Echo) Patch(path string, handler interface{}) {
e.add(PATCH, path, h) e.add(PATCH, path, handler)
} }
// Post adds a POST route > handler to the router. // Post adds a POST route > handler to the router.
func (e *Echo) Post(path string, h HandlerFunc) { func (e *Echo) Post(path string, handler interface{}) {
e.add(POST, path, h) e.add(POST, path, handler)
} }
// Put adds a PUT route > handler to the router. // Put adds a PUT route > handler to the router.
func (e *Echo) Put(path string, h HandlerFunc) { func (e *Echo) Put(path string, handler interface{}) {
e.add(PUT, path, h) e.add(PUT, path, handler)
} }
// Trace adds a TRACE route > handler to the router. // Trace adds a TRACE route > handler to the router.
func (e *Echo) Trace(path string, h HandlerFunc) { func (e *Echo) Trace(path string, handler interface{}) {
e.add(TRACE, path, h) e.add(TRACE, path, handler)
} }
// Any adds a route > handler to the router for all HTTP methods. // Any adds a route > handler to the router for all HTTP methods.
func (e *Echo) Any(path string, h HandlerFunc) { func (e *Echo) Any(path string, handler interface{}) {
for _, m := range methods { for _, m := range methods {
e.add(m, path, h) e.add(m, path, handler)
} }
} }
// Match adds a route > handler to the router for multiple HTTP methods provided. // Match adds a route > handler to the router for multiple HTTP methods provided.
func (e *Echo) Match(methods []string, path string, h HandlerFunc) { func (e *Echo) Match(methods []string, path string, handler interface{}) {
for _, m := range methods { for _, m := range methods {
e.add(m, path, h) e.add(m, path, handler)
} }
} }
// NOTE: v2 // NOTE: v2
func (e *Echo) add(method, path string, h HandlerFunc) { func (e *Echo) add(method, path string, h interface{}) {
path = e.prefix + path path = e.prefix + path
e.router.Add(method, path, h, e) e.router.Add(method, path, wrapHandler(h), e)
r := Route{ r := Route{
Method: method, Method: method,
Path: path, Path: path,
@@ -511,8 +515,7 @@ func (e *Echo) Routes() []Route {
return e.router.routes return e.router.routes
} }
// ServeHTTP serves HTTP requests. func (e *Echo) handle(req engine.Request, res engine.Response) {
func (e *Echo) ServeHTTP(req engine.Request, res engine.Response) {
if e.hook != nil { if e.hook != nil {
e.hook(req, res) e.hook(req, res)
} }
@@ -566,33 +569,11 @@ func (e *Echo) RunTLS(addr, certfile, keyfile string) {
// RunConfig runs a server with engine configuration. // RunConfig runs a server with engine configuration.
func (e *Echo) RunConfig(config *engine.Config) { func (e *Echo) RunConfig(config *engine.Config) {
handler := func(req engine.Request, res engine.Response) {
if e.hook != nil {
e.hook(req, res)
}
c := e.pool.Get().(*context)
h, e := e.router.Find(req.Method(), req.URL().Path(), c)
c.reset(req, res, e)
// Chain middleware with handler in the end
for i := len(e.middleware) - 1; i >= 0; i-- {
h = e.middleware[i](h)
}
// Execute chain
if err := h(c); err != nil {
e.httpErrorHandler(err, c)
}
e.pool.Put(c)
}
switch e.engineType { switch e.engineType {
case engine.FastHTTP: case engine.FastHTTP:
e.engine = fasthttp.NewServer(config, handler, e.logger) e.engine = fasthttp.NewServer(config, e.handle, e.logger)
default: default:
e.engine = standard.NewServer(config, handler, e.logger) e.engine = standard.NewServer(config, e.handle, e.logger)
} }
e.engine.Start() e.engine.Start()
} }
@@ -635,3 +616,29 @@ func (binder) Bind(r engine.Request, i interface{}) (err error) {
} }
return return
} }
func wrapMiddleware(m interface{}) MiddlewareFunc {
switch m := m.(type) {
case Middleware:
return m.Process
case MiddlewareFunc:
return m
case func(HandlerFunc) HandlerFunc:
return m
default:
panic("invalid middleware")
}
}
func wrapHandler(h interface{}) HandlerFunc {
switch h := h.(type) {
case Handler:
return h.Handle
case HandlerFunc:
return h
case func(Context) error:
return h
default:
panic("invalid handler")
}
}

View File

@@ -308,7 +308,7 @@ func TestEchoNotFound(t *testing.T) {
e := New() e := New()
req := test.NewRequest(GET, "/files", nil) req := test.NewRequest(GET, "/files", nil)
rec := test.NewResponseRecorder() rec := test.NewResponseRecorder()
e.ServeHTTP(req, rec) e.handle(req, rec)
assert.Equal(t, http.StatusNotFound, rec.Status()) assert.Equal(t, http.StatusNotFound, rec.Status())
} }
@@ -319,7 +319,7 @@ func TestEchoMethodNotAllowed(t *testing.T) {
}) })
req := test.NewRequest(POST, "/", nil) req := test.NewRequest(POST, "/", nil)
rec := test.NewResponseRecorder() rec := test.NewResponseRecorder()
e.ServeHTTP(req, rec) e.handle(req, rec)
assert.Equal(t, http.StatusMethodNotAllowed, rec.Status()) assert.Equal(t, http.StatusMethodNotAllowed, rec.Status())
} }
@@ -350,7 +350,7 @@ func TestEchoHook(t *testing.T) {
}) })
req := test.NewRequest(GET, "/test/", nil) req := test.NewRequest(GET, "/test/", nil)
rec := test.NewResponseRecorder() rec := test.NewResponseRecorder()
e.ServeHTTP(req, rec) e.handle(req, rec)
assert.Equal(t, req.URL().Path(), "/test") assert.Equal(t, req.URL().Path(), "/test")
} }
@@ -371,6 +371,6 @@ func testMethod(t *testing.T, method, path string, e *Echo) {
func request(method, path string, e *Echo) (int, string) { func request(method, path string, e *Echo) (int, string) {
req := test.NewRequest(method, path, nil) req := test.NewRequest(method, path, nil)
rec := test.NewResponseRecorder() rec := test.NewResponseRecorder()
e.ServeHTTP(req, rec) e.handle(req, rec)
return rec.Status(), rec.Body.String() return rec.Status(), rec.Body.String()
} }

View File

@@ -4,11 +4,11 @@ import "github.com/valyala/fasthttp"
type ( type (
RequestHeader struct { RequestHeader struct {
fasthttp.RequestHeader header fasthttp.RequestHeader
} }
ResponseHeader struct { ResponseHeader struct {
fasthttp.ResponseHeader header fasthttp.ResponseHeader
} }
) )
@@ -17,23 +17,24 @@ func (h *RequestHeader) Add(key, val string) {
} }
func (h *RequestHeader) Del(key string) { func (h *RequestHeader) Del(key string) {
h.RequestHeader.Del(key) h.header.Del(key)
} }
func (h *RequestHeader) Get(key string) string { func (h *RequestHeader) Get(key string) string {
return string(h.RequestHeader.Peek(key)) // return h.header.Peek(key)
return ""
} }
func (h *RequestHeader) Set(key, val string) { func (h *RequestHeader) Set(key, val string) {
h.RequestHeader.Set(key, val) h.header.Set(key, val)
} }
func (h *ResponseHeader) Add(key, val string) { func (h *ResponseHeader) Add(key, val string) {
// h.ResponseHeader.Add(key, val) // h.header.Add(key, val)
} }
func (h *ResponseHeader) Del(key string) { func (h *ResponseHeader) Del(key string) {
h.ResponseHeader.Del(key) h.header.Del(key)
} }
func (h *ResponseHeader) Get(key string) string { func (h *ResponseHeader) Get(key string) string {
@@ -42,5 +43,5 @@ func (h *ResponseHeader) Get(key string) string {
} }
func (h *ResponseHeader) Set(key, val string) { func (h *ResponseHeader) Set(key, val string) {
h.ResponseHeader.Set(key, val) h.header.Set(key, val)
} }

View File

@@ -9,18 +9,26 @@ import (
type ( type (
Request struct { Request struct {
request *fasthttp.RequestCtx context *fasthttp.RequestCtx
url engine.URL url engine.URL
header engine.Header header engine.Header
} }
) )
func NewRequest(c *fasthttp.RequestCtx) *Request {
return &Request{
context: c,
url: &URL{url: c.URI()},
header: &RequestHeader{c.Request.Header},
}
}
func (r *Request) Object() interface{} { func (r *Request) Object() interface{} {
return r.request return r.context
} }
func (r *Request) URI() string { func (r *Request) URI() string {
return string(r.request.RequestURI()) return string(r.context.RequestURI())
} }
func (r *Request) URL() engine.URL { func (r *Request) URL() engine.URL {
@@ -32,11 +40,11 @@ func (r *Request) Header() engine.Header {
} }
func (r *Request) RemoteAddress() string { func (r *Request) RemoteAddress() string {
return r.request.RemoteAddr().String() return r.context.RemoteAddr().String()
} }
func (r *Request) Method() string { func (r *Request) Method() string {
return string(r.request.Method()) return string(r.context.Method())
} }
func (r *Request) Body() io.ReadCloser { func (r *Request) Body() io.ReadCloser {

View File

@@ -4,22 +4,33 @@ import (
"io" "io"
"github.com/labstack/echo/engine" "github.com/labstack/echo/engine"
"github.com/labstack/echo/logger"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
type ( type (
Response struct { Response struct {
response *fasthttp.RequestCtx context *fasthttp.RequestCtx
header engine.Header header engine.Header
status int status int
size int64 size int64
committed bool committed bool
writer io.Writer writer io.Writer
logger logger.Logger
} }
) )
func NewResponse(c *fasthttp.RequestCtx, l logger.Logger) *Response {
return &Response{
context: c,
header: &ResponseHeader{c.Response.Header},
writer: c,
logger: l,
}
}
func (r *Response) Object() interface{} { func (r *Response) Object() interface{} {
return r.response return r.context
} }
func (r *Response) Header() engine.Header { func (r *Response) Header() engine.Header {
@@ -27,11 +38,17 @@ func (r *Response) Header() engine.Header {
} }
func (r *Response) WriteHeader(code int) { func (r *Response) WriteHeader(code int) {
r.response.SetStatusCode(code) if r.committed {
r.logger.Warn("response already committed")
return
}
r.status = code
r.context.SetStatusCode(code)
r.committed = true
} }
func (r *Response) Write(b []byte) (int, error) { func (r *Response) Write(b []byte) (int, error) {
return r.response.Write(b) return r.context.Write(b)
} }
func (r *Response) Status() int { func (r *Response) Status() int {

View File

@@ -27,15 +27,15 @@ func NewServer(c *engine.Config, h engine.HandlerFunc, l logger.Logger) *Server
} }
func (s *Server) Start() { func (s *Server) Start() {
fasthttp.ListenAndServe(s.config.Address, func(ctx *fasthttp.RequestCtx) { fasthttp.ListenAndServe(s.config.Address, func(c *fasthttp.RequestCtx) {
req := &Request{ req := &Request{
request: ctx, context: c,
url: &URL{ctx.URI()}, url: &URL{c.URI()},
header: &RequestHeader{ctx.Request.Header}, header: &RequestHeader{c.Request.Header},
} }
res := &Response{ res := &Response{
response: ctx, context: c,
header: &ResponseHeader{ctx.Response.Header}, header: &ResponseHeader{c.Response.Header},
} }
s.handler(req, res) s.handler(req, res)
}) })

View File

@@ -4,16 +4,16 @@ import "github.com/valyala/fasthttp"
type ( type (
URL struct { URL struct {
*fasthttp.URI url *fasthttp.URI
} }
) )
func (u *URL) Scheme() string { func (u *URL) Scheme() string {
return string(u.URI.Scheme()) return string(u.url.Scheme())
} }
func (u *URL) Host() string { func (u *URL) Host() string {
return string(u.URI.Host()) return string(u.url.Host())
} }
func (u *URL) SetPath(path string) { func (u *URL) SetPath(path string) {
@@ -21,7 +21,7 @@ func (u *URL) SetPath(path string) {
} }
func (u *URL) Path() string { func (u *URL) Path() string {
return string(u.URI.Path()) return string(u.url.Path())
} }
func (u *URL) QueryValue(name string) string { func (u *URL) QueryValue(name string) string {

View File

@@ -65,7 +65,7 @@ func Gzip() echo.MiddlewareFunc {
c.Response().Header().Set(echo.ContentEncoding, scheme) c.Response().Header().Set(echo.ContentEncoding, scheme)
c.Response().SetWriter(gw) c.Response().SetWriter(gw)
} }
if err := h(c); err != nil { if err := h.Handle(c); err != nil {
c.Error(err) c.Error(err)
} }
return nil return nil

View File

@@ -25,7 +25,7 @@ func Log() echo.MiddlewareFunc {
} }
start := time.Now() start := time.Now()
if err := h(c); err != nil { if err := h.Handle(c); err != nil {
c.Error(err) c.Error(err)
} }
stop := time.Now() stop := time.Now()

View File

@@ -22,7 +22,7 @@ func Recover() echo.MiddlewareFunc {
err, n, trace[:n])) err, n, trace[:n]))
} }
}() }()
return h(c) return h.Handle(c)
} }
} }
} }

View File

@@ -19,8 +19,13 @@ type (
) )
func NewRequest(method, url string, body io.Reader) engine.Request { func NewRequest(method, url string, body io.Reader) engine.Request {
// switch t {
// case engine.Standard:
r, _ := http.NewRequest(method, url, body) r, _ := http.NewRequest(method, url, body)
return standard.NewRequest(r) return standard.NewRequest(r)
// default:
// panic("invalid engine")
// }
} }
func NewResponseRecorder() *ResponseRecorder { func NewResponseRecorder() *ResponseRecorder {