1
0
mirror of https://github.com/labstack/echo.git synced 2024-12-24 20:14:31 +02:00

Wrappers for handler and middleware

Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
Vishal Rana 2016-02-08 22:17:20 -08:00
parent 3cd1d5be65
commit 7b843e66c5
11 changed files with 147 additions and 109 deletions

157
echo.go
View File

@ -61,7 +61,10 @@ type (
MiddlewareFunc func(HandlerFunc) HandlerFunc
// Handler interface{}
Handler interface {
Handle(Context) error
}
HandlerFunc func(Context) error
// HTTPErrorHandler is a centralized HTTP error handler.
@ -181,13 +184,13 @@ var (
// Error handlers
//----------------
notFoundHandler = func(c Context) error {
notFoundHandler = HandlerFunc(func(c Context) error {
return NewHTTPError(http.StatusNotFound)
}
})
methodNotAllowedHandler = func(c Context) error {
methodNotAllowedHandler = HandlerFunc(func(c Context) error {
return NewHTTPError(http.StatusMethodNotAllowed)
}
})
)
// New creates an instance of Echo.
@ -204,22 +207,7 @@ func New() (e *Echo) {
//----------
e.HTTP2(true)
e.defaultHTTPErrorHandler = func(err error, c Context) {
code := http.StatusInternalServerError
msg := http.StatusText(code)
if he, ok := err.(*HTTPError); ok {
code = he.code
msg = he.message
}
if e.debug {
msg = err.Error()
}
if !c.Response().Committed() {
c.String(code, msg)
}
e.logger.Error(err)
}
e.SetHTTPErrorHandler(e.defaultHTTPErrorHandler)
e.SetHTTPErrorHandler(e.DefaultHTTPErrorHandler)
e.SetBinder(&binder{})
// Logger
@ -232,6 +220,10 @@ func (f MiddlewareFunc) Process(h HandlerFunc) HandlerFunc {
return f(h)
}
func (f HandlerFunc) Handle(c Context) error {
return f(c)
}
// Router returns router.
func (e *Echo) Router() *Router {
return e.router
@ -254,7 +246,19 @@ func (e *Echo) HTTP2(on bool) {
// DefaultHTTPErrorHandler invokes the default HTTP error handler.
func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) {
e.defaultHTTPErrorHandler(err, c)
code := http.StatusInternalServerError
msg := http.StatusText(code)
if he, ok := err.(*HTTPError); ok {
code = he.code
msg = he.message
}
if e.debug {
msg = err.Error()
}
if !c.Response().Committed() {
c.String(code, msg)
}
e.logger.Error(err)
}
// SetHTTPErrorHandler registers a custom Echo.HTTPErrorHandler.
@ -295,75 +299,75 @@ func (e *Echo) Hook(h engine.HandlerFunc) {
}
// Use adds handler to the middleware chain.
func (e *Echo) Use(m ...MiddlewareFunc) {
for _, h := range m {
e.middleware = append(e.middleware, h)
func (e *Echo) Use(middleware ...interface{}) {
for _, m := range middleware {
e.middleware = append(e.middleware, wrapMiddleware(m))
}
}
// Connect adds a CONNECT route > handler to the router.
func (e *Echo) Connect(path string, h HandlerFunc) {
e.add(CONNECT, path, h)
func (e *Echo) Connect(path string, handler interface{}) {
e.add(CONNECT, path, handler)
}
// Delete adds a DELETE route > handler to the router.
func (e *Echo) Delete(path string, h HandlerFunc) {
e.add(DELETE, path, h)
func (e *Echo) Delete(path string, handler interface{}) {
e.add(DELETE, path, handler)
}
// Get adds a GET route > handler to the router.
func (e *Echo) Get(path string, h HandlerFunc) {
e.add(GET, path, h)
func (e *Echo) Get(path string, handler interface{}) {
e.add(GET, path, handler)
}
// Head adds a HEAD route > handler to the router.
func (e *Echo) Head(path string, h HandlerFunc) {
e.add(HEAD, path, h)
func (e *Echo) Head(path string, handler interface{}) {
e.add(HEAD, path, handler)
}
// Options adds an OPTIONS route > handler to the router.
func (e *Echo) Options(path string, h HandlerFunc) {
e.add(OPTIONS, path, h)
func (e *Echo) Options(path string, handler interface{}) {
e.add(OPTIONS, path, handler)
}
// Patch adds a PATCH route > handler to the router.
func (e *Echo) Patch(path string, h HandlerFunc) {
e.add(PATCH, path, h)
func (e *Echo) Patch(path string, handler interface{}) {
e.add(PATCH, path, handler)
}
// Post adds a POST route > handler to the router.
func (e *Echo) Post(path string, h HandlerFunc) {
e.add(POST, path, h)
func (e *Echo) Post(path string, handler interface{}) {
e.add(POST, path, handler)
}
// Put adds a PUT route > handler to the router.
func (e *Echo) Put(path string, h HandlerFunc) {
e.add(PUT, path, h)
func (e *Echo) Put(path string, handler interface{}) {
e.add(PUT, path, handler)
}
// Trace adds a TRACE route > handler to the router.
func (e *Echo) Trace(path string, h HandlerFunc) {
e.add(TRACE, path, h)
func (e *Echo) Trace(path string, handler interface{}) {
e.add(TRACE, path, handler)
}
// Any adds a route > handler to the router for all HTTP methods.
func (e *Echo) Any(path string, h HandlerFunc) {
func (e *Echo) Any(path string, handler interface{}) {
for _, m := range methods {
e.add(m, path, h)
e.add(m, path, handler)
}
}
// Match adds a route > handler to the router for multiple HTTP methods provided.
func (e *Echo) Match(methods []string, path string, h HandlerFunc) {
func (e *Echo) Match(methods []string, path string, handler interface{}) {
for _, m := range methods {
e.add(m, path, h)
e.add(m, path, handler)
}
}
// NOTE: v2
func (e *Echo) add(method, path string, h HandlerFunc) {
func (e *Echo) add(method, path string, h interface{}) {
path = e.prefix + path
e.router.Add(method, path, h, e)
e.router.Add(method, path, wrapHandler(h), e)
r := Route{
Method: method,
Path: path,
@ -511,8 +515,7 @@ func (e *Echo) Routes() []Route {
return e.router.routes
}
// ServeHTTP serves HTTP requests.
func (e *Echo) ServeHTTP(req engine.Request, res engine.Response) {
func (e *Echo) handle(req engine.Request, res engine.Response) {
if e.hook != nil {
e.hook(req, res)
}
@ -566,33 +569,11 @@ func (e *Echo) RunTLS(addr, certfile, keyfile string) {
// RunConfig runs a server with engine configuration.
func (e *Echo) RunConfig(config *engine.Config) {
handler := func(req engine.Request, res engine.Response) {
if e.hook != nil {
e.hook(req, res)
}
c := e.pool.Get().(*context)
h, e := e.router.Find(req.Method(), req.URL().Path(), c)
c.reset(req, res, e)
// Chain middleware with handler in the end
for i := len(e.middleware) - 1; i >= 0; i-- {
h = e.middleware[i](h)
}
// Execute chain
if err := h(c); err != nil {
e.httpErrorHandler(err, c)
}
e.pool.Put(c)
}
switch e.engineType {
case engine.FastHTTP:
e.engine = fasthttp.NewServer(config, handler, e.logger)
e.engine = fasthttp.NewServer(config, e.handle, e.logger)
default:
e.engine = standard.NewServer(config, handler, e.logger)
e.engine = standard.NewServer(config, e.handle, e.logger)
}
e.engine.Start()
}
@ -635,3 +616,29 @@ func (binder) Bind(r engine.Request, i interface{}) (err error) {
}
return
}
func wrapMiddleware(m interface{}) MiddlewareFunc {
switch m := m.(type) {
case Middleware:
return m.Process
case MiddlewareFunc:
return m
case func(HandlerFunc) HandlerFunc:
return m
default:
panic("invalid middleware")
}
}
func wrapHandler(h interface{}) HandlerFunc {
switch h := h.(type) {
case Handler:
return h.Handle
case HandlerFunc:
return h
case func(Context) error:
return h
default:
panic("invalid handler")
}
}

View File

@ -308,7 +308,7 @@ func TestEchoNotFound(t *testing.T) {
e := New()
req := test.NewRequest(GET, "/files", nil)
rec := test.NewResponseRecorder()
e.ServeHTTP(req, rec)
e.handle(req, rec)
assert.Equal(t, http.StatusNotFound, rec.Status())
}
@ -319,7 +319,7 @@ func TestEchoMethodNotAllowed(t *testing.T) {
})
req := test.NewRequest(POST, "/", nil)
rec := test.NewResponseRecorder()
e.ServeHTTP(req, rec)
e.handle(req, rec)
assert.Equal(t, http.StatusMethodNotAllowed, rec.Status())
}
@ -350,7 +350,7 @@ func TestEchoHook(t *testing.T) {
})
req := test.NewRequest(GET, "/test/", nil)
rec := test.NewResponseRecorder()
e.ServeHTTP(req, rec)
e.handle(req, rec)
assert.Equal(t, req.URL().Path(), "/test")
}
@ -371,6 +371,6 @@ func testMethod(t *testing.T, method, path string, e *Echo) {
func request(method, path string, e *Echo) (int, string) {
req := test.NewRequest(method, path, nil)
rec := test.NewResponseRecorder()
e.ServeHTTP(req, rec)
e.handle(req, rec)
return rec.Status(), rec.Body.String()
}

View File

@ -4,11 +4,11 @@ import "github.com/valyala/fasthttp"
type (
RequestHeader struct {
fasthttp.RequestHeader
header fasthttp.RequestHeader
}
ResponseHeader struct {
fasthttp.ResponseHeader
header fasthttp.ResponseHeader
}
)
@ -17,23 +17,24 @@ func (h *RequestHeader) Add(key, val string) {
}
func (h *RequestHeader) Del(key string) {
h.RequestHeader.Del(key)
h.header.Del(key)
}
func (h *RequestHeader) Get(key string) string {
return string(h.RequestHeader.Peek(key))
// return h.header.Peek(key)
return ""
}
func (h *RequestHeader) Set(key, val string) {
h.RequestHeader.Set(key, val)
h.header.Set(key, val)
}
func (h *ResponseHeader) Add(key, val string) {
// h.ResponseHeader.Add(key, val)
// h.header.Add(key, val)
}
func (h *ResponseHeader) Del(key string) {
h.ResponseHeader.Del(key)
h.header.Del(key)
}
func (h *ResponseHeader) Get(key string) string {
@ -42,5 +43,5 @@ func (h *ResponseHeader) Get(key string) string {
}
func (h *ResponseHeader) Set(key, val string) {
h.ResponseHeader.Set(key, val)
h.header.Set(key, val)
}

View File

@ -9,18 +9,26 @@ import (
type (
Request struct {
request *fasthttp.RequestCtx
context *fasthttp.RequestCtx
url engine.URL
header engine.Header
}
)
func NewRequest(c *fasthttp.RequestCtx) *Request {
return &Request{
context: c,
url: &URL{url: c.URI()},
header: &RequestHeader{c.Request.Header},
}
}
func (r *Request) Object() interface{} {
return r.request
return r.context
}
func (r *Request) URI() string {
return string(r.request.RequestURI())
return string(r.context.RequestURI())
}
func (r *Request) URL() engine.URL {
@ -32,11 +40,11 @@ func (r *Request) Header() engine.Header {
}
func (r *Request) RemoteAddress() string {
return r.request.RemoteAddr().String()
return r.context.RemoteAddr().String()
}
func (r *Request) Method() string {
return string(r.request.Method())
return string(r.context.Method())
}
func (r *Request) Body() io.ReadCloser {

View File

@ -4,22 +4,33 @@ import (
"io"
"github.com/labstack/echo/engine"
"github.com/labstack/echo/logger"
"github.com/valyala/fasthttp"
)
type (
Response struct {
response *fasthttp.RequestCtx
context *fasthttp.RequestCtx
header engine.Header
status int
size int64
committed bool
writer io.Writer
logger logger.Logger
}
)
func NewResponse(c *fasthttp.RequestCtx, l logger.Logger) *Response {
return &Response{
context: c,
header: &ResponseHeader{c.Response.Header},
writer: c,
logger: l,
}
}
func (r *Response) Object() interface{} {
return r.response
return r.context
}
func (r *Response) Header() engine.Header {
@ -27,11 +38,17 @@ func (r *Response) Header() engine.Header {
}
func (r *Response) WriteHeader(code int) {
r.response.SetStatusCode(code)
if r.committed {
r.logger.Warn("response already committed")
return
}
r.status = code
r.context.SetStatusCode(code)
r.committed = true
}
func (r *Response) Write(b []byte) (int, error) {
return r.response.Write(b)
return r.context.Write(b)
}
func (r *Response) Status() int {

View File

@ -27,15 +27,15 @@ func NewServer(c *engine.Config, h engine.HandlerFunc, l logger.Logger) *Server
}
func (s *Server) Start() {
fasthttp.ListenAndServe(s.config.Address, func(ctx *fasthttp.RequestCtx) {
fasthttp.ListenAndServe(s.config.Address, func(c *fasthttp.RequestCtx) {
req := &Request{
request: ctx,
url: &URL{ctx.URI()},
header: &RequestHeader{ctx.Request.Header},
context: c,
url: &URL{c.URI()},
header: &RequestHeader{c.Request.Header},
}
res := &Response{
response: ctx,
header: &ResponseHeader{ctx.Response.Header},
context: c,
header: &ResponseHeader{c.Response.Header},
}
s.handler(req, res)
})

View File

@ -4,16 +4,16 @@ import "github.com/valyala/fasthttp"
type (
URL struct {
*fasthttp.URI
url *fasthttp.URI
}
)
func (u *URL) Scheme() string {
return string(u.URI.Scheme())
return string(u.url.Scheme())
}
func (u *URL) Host() string {
return string(u.URI.Host())
return string(u.url.Host())
}
func (u *URL) SetPath(path string) {
@ -21,7 +21,7 @@ func (u *URL) SetPath(path string) {
}
func (u *URL) Path() string {
return string(u.URI.Path())
return string(u.url.Path())
}
func (u *URL) QueryValue(name string) string {

View File

@ -65,7 +65,7 @@ func Gzip() echo.MiddlewareFunc {
c.Response().Header().Set(echo.ContentEncoding, scheme)
c.Response().SetWriter(gw)
}
if err := h(c); err != nil {
if err := h.Handle(c); err != nil {
c.Error(err)
}
return nil

View File

@ -25,7 +25,7 @@ func Log() echo.MiddlewareFunc {
}
start := time.Now()
if err := h(c); err != nil {
if err := h.Handle(c); err != nil {
c.Error(err)
}
stop := time.Now()

View File

@ -22,7 +22,7 @@ func Recover() echo.MiddlewareFunc {
err, n, trace[:n]))
}
}()
return h(c)
return h.Handle(c)
}
}
}

View File

@ -19,8 +19,13 @@ type (
)
func NewRequest(method, url string, body io.Reader) engine.Request {
// switch t {
// case engine.Standard:
r, _ := http.NewRequest(method, url, body)
return standard.NewRequest(r)
// default:
// panic("invalid engine")
// }
}
func NewResponseRecorder() *ResponseRecorder {