diff --git a/engine/fasthttp/server.go b/engine/fasthttp/server.go index b165724e..b1b9ccee 100644 --- a/engine/fasthttp/server.go +++ b/engine/fasthttp/server.go @@ -123,11 +123,24 @@ func (s *Server) Start() { } } -// WrapHandler wraps `fasthttp.RequestHandler` into `echo.Handler`. -func WrapHandler(h fasthttp.RequestHandler) echo.Handler { - return echo.HandlerFunc(func(c echo.Context) error { +// WrapHandler wraps `fasthttp.RequestHandler` into `echo.HandlerFunc`. +func WrapHandler(h fasthttp.RequestHandler) echo.HandlerFunc { + return func(c echo.Context) error { ctx := c.Request().Object().(*fasthttp.RequestCtx) h(ctx) return nil - }) + } +} + +// WrapMiddleware wraps `fasthttp.RequestHandler` into `echo.MiddlewareFunc` +func WrapMiddleware(m fasthttp.RequestHandler) echo.MiddlewareFunc { + return func(next echo.Handler) echo.Handler { + return echo.HandlerFunc(func(c echo.Context) error { + ctx := c.Request().Object().(*fasthttp.RequestCtx) + if !c.Response().Committed() { + m(ctx) + } + return next.Handle(c) + }) + } } diff --git a/engine/standard/server.go b/engine/standard/server.go index 0e5033f4..6480b448 100644 --- a/engine/standard/server.go +++ b/engine/standard/server.go @@ -118,12 +118,26 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.pool.header.Put(resHdr) } -// WrapHandler wraps `http.Handler` into `echo.Handler`. -func WrapHandler(h http.Handler) echo.Handler { - return echo.HandlerFunc(func(c echo.Context) error { +// WrapHandler wraps `http.Handler` into `echo.HandlerFunc`. +func WrapHandler(h http.Handler) echo.HandlerFunc { + return func(c echo.Context) error { w := c.Response().Object().(http.ResponseWriter) r := c.Request().Object().(*http.Request) h.ServeHTTP(w, r) return nil - }) + } +} + +// WrapMiddleware wraps `http.Handler` into `echo.MiddlewareFunc` +func WrapMiddleware(m http.Handler) echo.MiddlewareFunc { + return func(next echo.Handler) echo.Handler { + return echo.HandlerFunc(func(c echo.Context) error { + w := c.Response().Object().(http.ResponseWriter) + r := c.Request().Object().(*http.Request) + if !c.Response().Committed() { + m.ServeHTTP(w, r) + } + return next.Handle(c) + }) + } }