diff --git a/echo.go b/echo.go index ffea2e32..c61d9688 100644 --- a/echo.go +++ b/echo.go @@ -23,6 +23,7 @@ type ( Echo struct { prefix string middleware []Middleware + head Handler http2 bool maxParam *int notFoundHandler HandlerFunc @@ -193,7 +194,7 @@ func New() (e *Echo) { return NewContext(nil, nil, e) } e.router = NewRouter(e) - e.middleware = []Middleware{e.router} + e.head = e.router.Handle(nil) //---------- // Defaults @@ -284,80 +285,85 @@ func (e *Echo) Debug() bool { } // Use adds handler to the middleware chain. -func (e *Echo) Use(middleware ...interface{}) { - for _, m := range middleware { - e.middleware = append(e.middleware, wrapMiddleware(m)) +func (e *Echo) Use(middleware ...Middleware) { + e.middleware = append(e.middleware, middleware...) + m := append(e.middleware, e.router) + + // Chain middleware + for i := len(m) - 1; i >= 0; i-- { + e.head = m[i].Handle(e.head) } } // Connect adds a CONNECT route > handler to the router. -func (e *Echo) Connect(path string, handler interface{}, middleware ...interface{}) { - e.add(CONNECT, path, handler, middleware...) +func (e *Echo) Connect(path string, h Handler, m ...Middleware) { + e.add(CONNECT, path, h, m...) } // Delete adds a DELETE route > handler to the router. -func (e *Echo) Delete(path string, handler interface{}, middleware ...interface{}) { - e.add(DELETE, path, handler, middleware...) +func (e *Echo) Delete(path string, h Handler, m ...Middleware) { + e.add(DELETE, path, h, m...) } // Get adds a GET route > handler to the router. -func (e *Echo) Get(path string, handler interface{}, middleware ...interface{}) { - e.add(GET, path, handler, middleware...) +func (e *Echo) Get(path string, h Handler, m ...Middleware) { + e.add(GET, path, h, m...) } // Head adds a HEAD route > handler to the router. -func (e *Echo) Head(path string, handler interface{}, middleware ...interface{}) { - e.add(HEAD, path, handler, middleware...) +func (e *Echo) Head(path string, h Handler, m ...Middleware) { + e.add(HEAD, path, h, m...) } // Options adds an OPTIONS route > handler to the router. -func (e *Echo) Options(path string, handler interface{}, middleware ...interface{}) { - e.add(OPTIONS, path, handler, middleware...) +func (e *Echo) Options(path string, h Handler, m ...Middleware) { + e.add(OPTIONS, path, h, m...) } // Patch adds a PATCH route > handler to the router. -func (e *Echo) Patch(path string, handler interface{}, middleware ...interface{}) { - e.add(PATCH, path, handler, middleware...) +func (e *Echo) Patch(path string, h Handler, m ...Middleware) { + e.add(PATCH, path, h, m...) } // Post adds a POST route > handler to the router. -func (e *Echo) Post(path string, handler interface{}, middleware ...interface{}) { - e.add(POST, path, handler, middleware...) +func (e *Echo) Post(path string, h Handler, m ...Middleware) { + e.add(POST, path, h, m...) } // Put adds a PUT route > handler to the router. -func (e *Echo) Put(path string, handler interface{}, middleware ...interface{}) { - e.add(PUT, path, handler, middleware...) +func (e *Echo) Put(path string, h Handler, m ...Middleware) { + e.add(PUT, path, h, m...) } // Trace adds a TRACE route > handler to the router. -func (e *Echo) Trace(path string, handler interface{}, middleware ...interface{}) { - e.add(TRACE, path, handler, middleware...) +func (e *Echo) Trace(path string, h Handler, m ...Middleware) { + e.add(TRACE, path, h, m...) } // Any adds a route > handler to the router for all HTTP methods. -func (e *Echo) Any(path string, handler interface{}, middleware ...interface{}) { +func (e *Echo) Any(path string, handler Handler, middleware ...Middleware) { for _, m := range methods { e.add(m, path, handler, middleware...) } } // Match adds a route > handler to the router for multiple HTTP methods provided. -func (e *Echo) Match(methods []string, path string, handler interface{}, middleware ...interface{}) { +func (e *Echo) Match(methods []string, path string, handler Handler, middleware ...Middleware) { for _, m := range methods { e.add(m, path, handler, middleware...) } } -// NOTE: v2 -func (e *Echo) add(method, path string, handler interface{}, middleware ...interface{}) { - h := wrapHandler(handler) +func (e *Echo) add(method, path string, handler Handler, middleware ...Middleware) { name := handlerName(handler) + // middleware = append(e.middleware, middleware...) + // e.router.Add(method, path, handler, e) + e.router.Add(method, path, HandlerFunc(func(c Context) error { for _, m := range middleware { - h = wrapMiddleware(m).Handle(h) + handler = m.Handle(handler) } - return h.Handle(c) + return handler.Handle(c) }), e) r := Route{ Method: method, @@ -368,14 +374,14 @@ func (e *Echo) add(method, path string, handler interface{}, middleware ...inter } // Group creates a new sub-router with prefix. -func (e *Echo) Group(prefix string, middleware ...interface{}) (g *Group) { +func (e *Echo) Group(prefix string, m ...Middleware) (g *Group) { g = &Group{prefix: prefix, echo: e} - g.Use(middleware...) + g.Use(m...) return } // URI generates a URI from handler. -func (e *Echo) URI(handler interface{}, params ...interface{}) string { +func (e *Echo) URI(handler Handler, params ...interface{}) string { uri := new(bytes.Buffer) ln := len(params) n := 0 @@ -400,8 +406,8 @@ func (e *Echo) URI(handler interface{}, params ...interface{}) string { } // URL is an alias for `URI` function. -func (e *Echo) URL(handler interface{}, params ...interface{}) string { - return e.URI(handler, params...) +func (e *Echo) URL(h Handler, params ...interface{}) string { + return e.URI(h, params...) } // Routes returns the registered routes. @@ -412,15 +418,9 @@ func (e *Echo) Routes() []Route { func (e *Echo) ServeHTTP(req engine.Request, res engine.Response) { c := e.pool.Get().(*context) c.reset(req, res) - h := Handler(c) - - // Chain middleware with handler in the end - for i := len(e.middleware) - 1; i >= 0; i-- { - h = e.middleware[i].Handle(h) - } // Execute chain - if err := h.Handle(c); err != nil { + if err := e.head.Handle(c); err != nil { e.httpErrorHandler(err, c) } @@ -473,40 +473,6 @@ func (binder) Bind(r engine.Request, i interface{}) (err error) { return } -func wrapMiddleware(m interface{}) Middleware { - switch m := m.(type) { - case Middleware: - return m - case MiddlewareFunc: - return m - case func(Handler) Handler: - return MiddlewareFunc(m) - default: - panic("invalid middleware") - } -} - -func wrapHandler(h interface{}) Handler { - switch h := h.(type) { - case Handler: - return h - case HandlerFunc: - return h - case func(Context) error: - return HandlerFunc(h) - default: - panic("echo => invalid handler") - } -} - -func handlerName(h interface{}) string { - switch h := h.(type) { - case Handler: - t := reflect.TypeOf(h) - return fmt.Sprintf("%s ยป %s", t.PkgPath(), t.Name()) - case HandlerFunc, func(Context) error: - return runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name() - default: - panic("echo => invalid handler") - } +func handlerName(h Handler) string { + return runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name() } diff --git a/echo_test.go b/echo_test.go index e84d053c..47da96e1 100644 --- a/echo_test.go +++ b/echo_test.go @@ -44,31 +44,31 @@ func TestEchoMiddleware(t *testing.T) { e := New() buf := new(bytes.Buffer) - e.Use(func(h Handler) Handler { + e.Use(MiddlewareFunc(func(h Handler) Handler { return HandlerFunc(func(c Context) error { buf.WriteString("a") return h.Handle(c) }) - }) + })) - e.Use(func(h Handler) Handler { + e.Use(MiddlewareFunc(func(h Handler) Handler { return HandlerFunc(func(c Context) error { buf.WriteString("b") return h.Handle(c) }) - }) + })) - e.Use(func(h Handler) Handler { + e.Use(MiddlewareFunc(func(h Handler) Handler { return HandlerFunc(func(c Context) error { buf.WriteString("c") return h.Handle(c) }) - }) + })) // Route - e.Get("/", func(c Context) error { + e.Get("/", HandlerFunc(func(c Context) error { return c.String(http.StatusOK, "OK") - }) + })) c, b := request(GET, "/", e) assert.Equal(t, "abc", buf.String()) @@ -76,11 +76,11 @@ func TestEchoMiddleware(t *testing.T) { assert.Equal(t, "OK", b) // Error - e.Use(func(Handler) Handler { + e.Use(MiddlewareFunc(func(Handler) Handler { return HandlerFunc(func(c Context) error { return errors.New("error") }) - }) + })) c, b = request(GET, "/", e) assert.Equal(t, http.StatusInternalServerError, c) } @@ -89,9 +89,9 @@ func TestEchoHandler(t *testing.T) { e := New() // HandlerFunc - e.Get("/ok", func(c Context) error { + e.Get("/ok", HandlerFunc(func(c Context) error { return c.String(http.StatusOK, "OK") - }) + })) c, b := request(GET, "/ok", e) assert.Equal(t, http.StatusOK, c) @@ -145,23 +145,23 @@ func TestEchoTrace(t *testing.T) { func TestEchoAny(t *testing.T) { // JFC e := New() - e.Any("/", func(c Context) error { + e.Any("/", HandlerFunc(func(c Context) error { return c.String(http.StatusOK, "Any") - }) + })) } func TestEchoMatch(t *testing.T) { // JFC e := New() - e.Match([]string{GET, POST}, "/", func(c Context) error { + e.Match([]string{GET, POST}, "/", HandlerFunc(func(c Context) error { return c.String(http.StatusOK, "Match") - }) + })) } func TestEchoURL(t *testing.T) { e := New() - static := func(Context) error { return nil } - getUser := func(Context) error { return nil } - getFile := func(Context) error { return nil } + static := HandlerFunc(func(Context) error { return nil }) + getUser := HandlerFunc(func(Context) error { return nil }) + getFile := HandlerFunc(func(Context) error { return nil }) e.Get("/static/file", static) e.Get("/users/:id", getUser) @@ -184,9 +184,9 @@ func TestEchoRoutes(t *testing.T) { {POST, "/repos/:owner/:repo/git/tags", ""}, } for _, r := range routes { - e.add(r.Method, r.Path, func(c Context) error { + e.add(r.Method, r.Path, HandlerFunc(func(c Context) error { return c.String(http.StatusOK, "OK") - }) + })) } for i, r := range e.Routes() { @@ -198,15 +198,15 @@ func TestEchoRoutes(t *testing.T) { func TestEchoGroup(t *testing.T) { e := New() buf := new(bytes.Buffer) - e.Use(func(h Handler) Handler { + e.Use(MiddlewareFunc(func(h Handler) Handler { return HandlerFunc(func(c Context) error { buf.WriteString("0") return h.Handle(c) }) - }) - h := func(c Context) error { + })) + h := HandlerFunc(func(c Context) error { return c.NoContent(http.StatusOK) - } + }) //-------- // Routes @@ -216,12 +216,12 @@ func TestEchoGroup(t *testing.T) { // Group g1 := e.Group("/group1") - g1.Use(func(h Handler) Handler { + g1.Use(MiddlewareFunc(func(h Handler) Handler { return HandlerFunc(func(c Context) error { buf.WriteString("1") return h.Handle(c) }) - }) + })) g1.Get("/", h) // Nested groups @@ -251,9 +251,9 @@ func TestEchoNotFound(t *testing.T) { func TestEchoMethodNotAllowed(t *testing.T) { e := New() - e.Get("/", func(c Context) error { + e.Get("/", HandlerFunc(func(c Context) error { return c.String(http.StatusOK, "Echo!") - }) + })) req := test.NewRequest(POST, "/", nil) rec := test.NewResponseRecorder() e.ServeHTTP(req, rec) @@ -270,9 +270,9 @@ func TestEchoHTTPError(t *testing.T) { func testMethod(t *testing.T, method, path string, e *Echo) { m := fmt.Sprintf("%c%s", method[0], strings.ToLower(method[1:])) p := reflect.ValueOf(path) - h := reflect.ValueOf(func(c Context) error { + h := reflect.ValueOf(HandlerFunc(func(c Context) error { return c.String(http.StatusOK, method) - }) + })) i := interface{}(e) reflect.ValueOf(i).MethodByName(m).Call([]reflect.Value{p, h}) _, body := request(method, path, e) diff --git a/group.go b/group.go index 5f24b684..48da42ac 100644 --- a/group.go +++ b/group.go @@ -8,56 +8,53 @@ type ( } ) -func (g *Group) Use(middleware ...interface{}) { - for _, m := range middleware { - g.middleware = append(g.middleware, wrapMiddleware(m)) - } +func (g *Group) Use(m ...Middleware) { + g.middleware = append(g.middleware, m...) } -func (g *Group) Connect(path string, handler interface{}) { - g.add(CONNECT, path, handler) +func (g *Group) Connect(path string, h Handler) { + g.add(CONNECT, path, h) } -func (g *Group) Delete(path string, handler interface{}) { - g.add(DELETE, path, handler) +func (g *Group) Delete(path string, h Handler) { + g.add(DELETE, path, h) } -func (g *Group) Get(path string, handler interface{}) { - g.add(GET, path, handler) +func (g *Group) Get(path string, h Handler) { + g.add(GET, path, h) } -func (g *Group) Head(path string, handler interface{}) { - g.add(HEAD, path, handler) +func (g *Group) Head(path string, h Handler) { + g.add(HEAD, path, h) } -func (g *Group) Options(path string, handler interface{}) { - g.add(OPTIONS, path, handler) +func (g *Group) Options(path string, h Handler) { + g.add(OPTIONS, path, h) } -func (g *Group) Patch(path string, handler interface{}) { - g.add(PATCH, path, handler) +func (g *Group) Patch(path string, h Handler) { + g.add(PATCH, path, h) } -func (g *Group) Post(path string, handler interface{}) { - g.add(POST, path, handler) +func (g *Group) Post(path string, h Handler) { + g.add(POST, path, h) } -func (g *Group) Put(path string, handler interface{}) { - g.add(PUT, path, handler) +func (g *Group) Put(path string, h Handler) { + g.add(PUT, path, h) } -func (g *Group) Trace(path string, handler interface{}) { - g.add(TRACE, path, handler) +func (g *Group) Trace(path string, h Handler) { + g.add(TRACE, path, h) } -func (g *Group) Group(prefix string, middleware ...interface{}) *Group { - return g.echo.Group(prefix, middleware...) +func (g *Group) Group(prefix string, m ...Middleware) *Group { + return g.echo.Group(prefix, m...) } -func (g *Group) add(method, path string, handler interface{}) { +func (g *Group) add(method, path string, h Handler) { path = g.prefix + path - h := wrapHandler(handler) - name := handlerName(handler) + name := handlerName(h) g.echo.router.Add(method, path, HandlerFunc(func(c Context) error { for i := len(g.middleware) - 1; i >= 0; i-- { h = g.middleware[i].Handle(h) diff --git a/group_test.go b/group_test.go index 1d02abc7..6c33e309 100644 --- a/group_test.go +++ b/group_test.go @@ -4,7 +4,7 @@ import "testing" func TestGroup(t *testing.T) { g := New().Group("/group") - h := func(Context) error { return nil } + h := HandlerFunc(func(Context) error { return nil }) g.Connect("/", h) g.Delete("/", h) g.Get("/", h) diff --git a/router.go b/router.go index 385a063f..0be1c3bb 100644 --- a/router.go +++ b/router.go @@ -52,7 +52,7 @@ func (r *Router) Handle(h Handler) Handler { method := c.Request().Method() path := c.Request().URL().Path() r.Find(method, path, c) - return h.Handle(c) + return c.Handle(c) }) }