From 7b843e66c502c5b13809ad6e87665edd6794a05d Mon Sep 17 00:00:00 2001 From: Vishal Rana Date: Mon, 8 Feb 2016 22:17:20 -0800 Subject: [PATCH] Wrappers for handler and middleware Signed-off-by: Vishal Rana --- echo.go | 157 +++++++++++++++++++----------------- echo_test.go | 8 +- engine/fasthttp/header.go | 17 ++-- engine/fasthttp/request.go | 18 +++-- engine/fasthttp/response.go | 25 +++++- engine/fasthttp/server.go | 12 +-- engine/fasthttp/url.go | 8 +- middleware/compress.go | 2 +- middleware/log.go | 2 +- middleware/recover.go | 2 +- test/http.go | 5 ++ 11 files changed, 147 insertions(+), 109 deletions(-) diff --git a/echo.go b/echo.go index 5fd1d7d2..0e21ce38 100644 --- a/echo.go +++ b/echo.go @@ -61,7 +61,10 @@ type ( MiddlewareFunc func(HandlerFunc) HandlerFunc - // Handler interface{} + Handler interface { + Handle(Context) error + } + HandlerFunc func(Context) error // HTTPErrorHandler is a centralized HTTP error handler. @@ -181,13 +184,13 @@ var ( // Error handlers //---------------- - notFoundHandler = func(c Context) error { + notFoundHandler = HandlerFunc(func(c Context) error { return NewHTTPError(http.StatusNotFound) - } + }) - methodNotAllowedHandler = func(c Context) error { + methodNotAllowedHandler = HandlerFunc(func(c Context) error { return NewHTTPError(http.StatusMethodNotAllowed) - } + }) ) // New creates an instance of Echo. @@ -204,22 +207,7 @@ func New() (e *Echo) { //---------- e.HTTP2(true) - e.defaultHTTPErrorHandler = func(err error, c Context) { - code := http.StatusInternalServerError - msg := http.StatusText(code) - if he, ok := err.(*HTTPError); ok { - code = he.code - msg = he.message - } - if e.debug { - msg = err.Error() - } - if !c.Response().Committed() { - c.String(code, msg) - } - e.logger.Error(err) - } - e.SetHTTPErrorHandler(e.defaultHTTPErrorHandler) + e.SetHTTPErrorHandler(e.DefaultHTTPErrorHandler) e.SetBinder(&binder{}) // Logger @@ -232,6 +220,10 @@ func (f MiddlewareFunc) Process(h HandlerFunc) HandlerFunc { return f(h) } +func (f HandlerFunc) Handle(c Context) error { + return f(c) +} + // Router returns router. func (e *Echo) Router() *Router { return e.router @@ -254,7 +246,19 @@ func (e *Echo) HTTP2(on bool) { // DefaultHTTPErrorHandler invokes the default HTTP error handler. func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) { - e.defaultHTTPErrorHandler(err, c) + code := http.StatusInternalServerError + msg := http.StatusText(code) + if he, ok := err.(*HTTPError); ok { + code = he.code + msg = he.message + } + if e.debug { + msg = err.Error() + } + if !c.Response().Committed() { + c.String(code, msg) + } + e.logger.Error(err) } // SetHTTPErrorHandler registers a custom Echo.HTTPErrorHandler. @@ -295,75 +299,75 @@ func (e *Echo) Hook(h engine.HandlerFunc) { } // Use adds handler to the middleware chain. -func (e *Echo) Use(m ...MiddlewareFunc) { - for _, h := range m { - e.middleware = append(e.middleware, h) +func (e *Echo) Use(middleware ...interface{}) { + for _, m := range middleware { + e.middleware = append(e.middleware, wrapMiddleware(m)) } } // Connect adds a CONNECT route > handler to the router. -func (e *Echo) Connect(path string, h HandlerFunc) { - e.add(CONNECT, path, h) +func (e *Echo) Connect(path string, handler interface{}) { + e.add(CONNECT, path, handler) } // Delete adds a DELETE route > handler to the router. -func (e *Echo) Delete(path string, h HandlerFunc) { - e.add(DELETE, path, h) +func (e *Echo) Delete(path string, handler interface{}) { + e.add(DELETE, path, handler) } // Get adds a GET route > handler to the router. -func (e *Echo) Get(path string, h HandlerFunc) { - e.add(GET, path, h) +func (e *Echo) Get(path string, handler interface{}) { + e.add(GET, path, handler) } // Head adds a HEAD route > handler to the router. -func (e *Echo) Head(path string, h HandlerFunc) { - e.add(HEAD, path, h) +func (e *Echo) Head(path string, handler interface{}) { + e.add(HEAD, path, handler) } // Options adds an OPTIONS route > handler to the router. -func (e *Echo) Options(path string, h HandlerFunc) { - e.add(OPTIONS, path, h) +func (e *Echo) Options(path string, handler interface{}) { + e.add(OPTIONS, path, handler) } // Patch adds a PATCH route > handler to the router. -func (e *Echo) Patch(path string, h HandlerFunc) { - e.add(PATCH, path, h) +func (e *Echo) Patch(path string, handler interface{}) { + e.add(PATCH, path, handler) } // Post adds a POST route > handler to the router. -func (e *Echo) Post(path string, h HandlerFunc) { - e.add(POST, path, h) +func (e *Echo) Post(path string, handler interface{}) { + e.add(POST, path, handler) } // Put adds a PUT route > handler to the router. -func (e *Echo) Put(path string, h HandlerFunc) { - e.add(PUT, path, h) +func (e *Echo) Put(path string, handler interface{}) { + e.add(PUT, path, handler) } // Trace adds a TRACE route > handler to the router. -func (e *Echo) Trace(path string, h HandlerFunc) { - e.add(TRACE, path, h) +func (e *Echo) Trace(path string, handler interface{}) { + e.add(TRACE, path, handler) } // Any adds a route > handler to the router for all HTTP methods. -func (e *Echo) Any(path string, h HandlerFunc) { +func (e *Echo) Any(path string, handler interface{}) { for _, m := range methods { - e.add(m, path, h) + e.add(m, path, handler) } } // Match adds a route > handler to the router for multiple HTTP methods provided. -func (e *Echo) Match(methods []string, path string, h HandlerFunc) { +func (e *Echo) Match(methods []string, path string, handler interface{}) { for _, m := range methods { - e.add(m, path, h) + e.add(m, path, handler) } } // NOTE: v2 -func (e *Echo) add(method, path string, h HandlerFunc) { +func (e *Echo) add(method, path string, h interface{}) { path = e.prefix + path - e.router.Add(method, path, h, e) + e.router.Add(method, path, wrapHandler(h), e) r := Route{ Method: method, Path: path, @@ -511,8 +515,7 @@ func (e *Echo) Routes() []Route { return e.router.routes } -// ServeHTTP serves HTTP requests. -func (e *Echo) ServeHTTP(req engine.Request, res engine.Response) { +func (e *Echo) handle(req engine.Request, res engine.Response) { if e.hook != nil { e.hook(req, res) } @@ -566,33 +569,11 @@ func (e *Echo) RunTLS(addr, certfile, keyfile string) { // RunConfig runs a server with engine configuration. func (e *Echo) RunConfig(config *engine.Config) { - handler := func(req engine.Request, res engine.Response) { - if e.hook != nil { - e.hook(req, res) - } - - c := e.pool.Get().(*context) - h, e := e.router.Find(req.Method(), req.URL().Path(), c) - c.reset(req, res, e) - - // Chain middleware with handler in the end - for i := len(e.middleware) - 1; i >= 0; i-- { - h = e.middleware[i](h) - } - - // Execute chain - if err := h(c); err != nil { - e.httpErrorHandler(err, c) - } - - e.pool.Put(c) - } - switch e.engineType { case engine.FastHTTP: - e.engine = fasthttp.NewServer(config, handler, e.logger) + e.engine = fasthttp.NewServer(config, e.handle, e.logger) default: - e.engine = standard.NewServer(config, handler, e.logger) + e.engine = standard.NewServer(config, e.handle, e.logger) } e.engine.Start() } @@ -635,3 +616,29 @@ func (binder) Bind(r engine.Request, i interface{}) (err error) { } return } + +func wrapMiddleware(m interface{}) MiddlewareFunc { + switch m := m.(type) { + case Middleware: + return m.Process + case MiddlewareFunc: + return m + case func(HandlerFunc) HandlerFunc: + return m + default: + panic("invalid middleware") + } +} + +func wrapHandler(h interface{}) HandlerFunc { + switch h := h.(type) { + case Handler: + return h.Handle + case HandlerFunc: + return h + case func(Context) error: + return h + default: + panic("invalid handler") + } +} diff --git a/echo_test.go b/echo_test.go index 002a6a2e..ea6a33f9 100644 --- a/echo_test.go +++ b/echo_test.go @@ -308,7 +308,7 @@ func TestEchoNotFound(t *testing.T) { e := New() req := test.NewRequest(GET, "/files", nil) rec := test.NewResponseRecorder() - e.ServeHTTP(req, rec) + e.handle(req, rec) assert.Equal(t, http.StatusNotFound, rec.Status()) } @@ -319,7 +319,7 @@ func TestEchoMethodNotAllowed(t *testing.T) { }) req := test.NewRequest(POST, "/", nil) rec := test.NewResponseRecorder() - e.ServeHTTP(req, rec) + e.handle(req, rec) assert.Equal(t, http.StatusMethodNotAllowed, rec.Status()) } @@ -350,7 +350,7 @@ func TestEchoHook(t *testing.T) { }) req := test.NewRequest(GET, "/test/", nil) rec := test.NewResponseRecorder() - e.ServeHTTP(req, rec) + e.handle(req, rec) assert.Equal(t, req.URL().Path(), "/test") } @@ -371,6 +371,6 @@ func testMethod(t *testing.T, method, path string, e *Echo) { func request(method, path string, e *Echo) (int, string) { req := test.NewRequest(method, path, nil) rec := test.NewResponseRecorder() - e.ServeHTTP(req, rec) + e.handle(req, rec) return rec.Status(), rec.Body.String() } diff --git a/engine/fasthttp/header.go b/engine/fasthttp/header.go index e5d6a0be..9ae592cb 100644 --- a/engine/fasthttp/header.go +++ b/engine/fasthttp/header.go @@ -4,11 +4,11 @@ import "github.com/valyala/fasthttp" type ( RequestHeader struct { - fasthttp.RequestHeader + header fasthttp.RequestHeader } ResponseHeader struct { - fasthttp.ResponseHeader + header fasthttp.ResponseHeader } ) @@ -17,23 +17,24 @@ func (h *RequestHeader) Add(key, val string) { } func (h *RequestHeader) Del(key string) { - h.RequestHeader.Del(key) + h.header.Del(key) } func (h *RequestHeader) Get(key string) string { - return string(h.RequestHeader.Peek(key)) + // return h.header.Peek(key) + return "" } func (h *RequestHeader) Set(key, val string) { - h.RequestHeader.Set(key, val) + h.header.Set(key, val) } func (h *ResponseHeader) Add(key, val string) { - // h.ResponseHeader.Add(key, val) + // h.header.Add(key, val) } func (h *ResponseHeader) Del(key string) { - h.ResponseHeader.Del(key) + h.header.Del(key) } func (h *ResponseHeader) Get(key string) string { @@ -42,5 +43,5 @@ func (h *ResponseHeader) Get(key string) string { } func (h *ResponseHeader) Set(key, val string) { - h.ResponseHeader.Set(key, val) + h.header.Set(key, val) } diff --git a/engine/fasthttp/request.go b/engine/fasthttp/request.go index 60c62cb1..cfcac8a6 100644 --- a/engine/fasthttp/request.go +++ b/engine/fasthttp/request.go @@ -9,18 +9,26 @@ import ( type ( Request struct { - request *fasthttp.RequestCtx + context *fasthttp.RequestCtx url engine.URL header engine.Header } ) +func NewRequest(c *fasthttp.RequestCtx) *Request { + return &Request{ + context: c, + url: &URL{url: c.URI()}, + header: &RequestHeader{c.Request.Header}, + } +} + func (r *Request) Object() interface{} { - return r.request + return r.context } func (r *Request) URI() string { - return string(r.request.RequestURI()) + return string(r.context.RequestURI()) } func (r *Request) URL() engine.URL { @@ -32,11 +40,11 @@ func (r *Request) Header() engine.Header { } func (r *Request) RemoteAddress() string { - return r.request.RemoteAddr().String() + return r.context.RemoteAddr().String() } func (r *Request) Method() string { - return string(r.request.Method()) + return string(r.context.Method()) } func (r *Request) Body() io.ReadCloser { diff --git a/engine/fasthttp/response.go b/engine/fasthttp/response.go index 9f4a22c3..08cae03f 100644 --- a/engine/fasthttp/response.go +++ b/engine/fasthttp/response.go @@ -4,22 +4,33 @@ import ( "io" "github.com/labstack/echo/engine" + "github.com/labstack/echo/logger" "github.com/valyala/fasthttp" ) type ( Response struct { - response *fasthttp.RequestCtx + context *fasthttp.RequestCtx header engine.Header status int size int64 committed bool writer io.Writer + logger logger.Logger } ) +func NewResponse(c *fasthttp.RequestCtx, l logger.Logger) *Response { + return &Response{ + context: c, + header: &ResponseHeader{c.Response.Header}, + writer: c, + logger: l, + } +} + func (r *Response) Object() interface{} { - return r.response + return r.context } func (r *Response) Header() engine.Header { @@ -27,11 +38,17 @@ func (r *Response) Header() engine.Header { } func (r *Response) WriteHeader(code int) { - r.response.SetStatusCode(code) + if r.committed { + r.logger.Warn("response already committed") + return + } + r.status = code + r.context.SetStatusCode(code) + r.committed = true } func (r *Response) Write(b []byte) (int, error) { - return r.response.Write(b) + return r.context.Write(b) } func (r *Response) Status() int { diff --git a/engine/fasthttp/server.go b/engine/fasthttp/server.go index 222d6fcd..960d5a70 100644 --- a/engine/fasthttp/server.go +++ b/engine/fasthttp/server.go @@ -27,15 +27,15 @@ func NewServer(c *engine.Config, h engine.HandlerFunc, l logger.Logger) *Server } func (s *Server) Start() { - fasthttp.ListenAndServe(s.config.Address, func(ctx *fasthttp.RequestCtx) { + fasthttp.ListenAndServe(s.config.Address, func(c *fasthttp.RequestCtx) { req := &Request{ - request: ctx, - url: &URL{ctx.URI()}, - header: &RequestHeader{ctx.Request.Header}, + context: c, + url: &URL{c.URI()}, + header: &RequestHeader{c.Request.Header}, } res := &Response{ - response: ctx, - header: &ResponseHeader{ctx.Response.Header}, + context: c, + header: &ResponseHeader{c.Response.Header}, } s.handler(req, res) }) diff --git a/engine/fasthttp/url.go b/engine/fasthttp/url.go index 0fa5ae99..f66090df 100644 --- a/engine/fasthttp/url.go +++ b/engine/fasthttp/url.go @@ -4,16 +4,16 @@ import "github.com/valyala/fasthttp" type ( URL struct { - *fasthttp.URI + url *fasthttp.URI } ) func (u *URL) Scheme() string { - return string(u.URI.Scheme()) + return string(u.url.Scheme()) } func (u *URL) Host() string { - return string(u.URI.Host()) + return string(u.url.Host()) } func (u *URL) SetPath(path string) { @@ -21,7 +21,7 @@ func (u *URL) SetPath(path string) { } func (u *URL) Path() string { - return string(u.URI.Path()) + return string(u.url.Path()) } func (u *URL) QueryValue(name string) string { diff --git a/middleware/compress.go b/middleware/compress.go index 42d1fc66..3d8306b8 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -65,7 +65,7 @@ func Gzip() echo.MiddlewareFunc { c.Response().Header().Set(echo.ContentEncoding, scheme) c.Response().SetWriter(gw) } - if err := h(c); err != nil { + if err := h.Handle(c); err != nil { c.Error(err) } return nil diff --git a/middleware/log.go b/middleware/log.go index 37dee998..2c5f94ce 100644 --- a/middleware/log.go +++ b/middleware/log.go @@ -25,7 +25,7 @@ func Log() echo.MiddlewareFunc { } start := time.Now() - if err := h(c); err != nil { + if err := h.Handle(c); err != nil { c.Error(err) } stop := time.Now() diff --git a/middleware/recover.go b/middleware/recover.go index f641f276..5a7bf5a4 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -22,7 +22,7 @@ func Recover() echo.MiddlewareFunc { err, n, trace[:n])) } }() - return h(c) + return h.Handle(c) } } } diff --git a/test/http.go b/test/http.go index 2053abfc..aa70c904 100644 --- a/test/http.go +++ b/test/http.go @@ -19,8 +19,13 @@ type ( ) func NewRequest(method, url string, body io.Reader) engine.Request { + // switch t { + // case engine.Standard: r, _ := http.NewRequest(method, url, body) return standard.NewRequest(r) + // default: + // panic("invalid engine") + // } } func NewResponseRecorder() *ResponseRecorder {