diff --git a/context.go b/context.go index 27057139..5daea043 100644 --- a/context.go +++ b/context.go @@ -3,7 +3,10 @@ package echo import ( "encoding/json" "encoding/xml" + "io" + "mime" "net/http" + "os" "path/filepath" "time" @@ -42,10 +45,11 @@ type ( JSONP(int, string, interface{}) error XML(int, interface{}) error XMLBlob(int, []byte) error - File(string, string, bool) error + Attachment(string) error NoContent(int) error Redirect(int, string) error Error(err error) + Handle(Context) error Logger() logger.Logger Object() *context } @@ -59,12 +63,17 @@ type ( pvalues []string query url.Values store store + handler Handler echo *Echo } store map[string]interface{} ) +const ( + indexPage = "index.html" +) + // NewContext creates a Context object. func NewContext(req engine.Request, res engine.Response, e *Echo) Context { return &context{ @@ -73,9 +82,14 @@ func NewContext(req engine.Request, res engine.Response, e *Echo) Context { echo: e, pvalues: make([]string, *e.maxParam), store: make(store), + handler: notFoundHandler, } } +func (c *context) Handle(ctx Context) error { + return c.handler.Handle(ctx) +} + func (c *context) Deadline() (deadline time.Time, ok bool) { return } @@ -166,7 +180,7 @@ func (c *context) Bind(i interface{}) error { // code. Templates can be registered using `Echo.SetRenderer()`. func (c *context) Render(code int, name string, data interface{}) (err error) { if c.echo.renderer == nil { - return RendererNotRegistered + return ErrRendererNotRegistered } buf := new(bytes.Buffer) if err = c.echo.renderer.Render(buf, name, data); err != nil { @@ -250,17 +264,17 @@ func (c *context) XMLBlob(code int, b []byte) (err error) { return } -// File sends a response with the content of the file. If `attachment` is set -// to true, the client is prompted to save the file with provided `name`, -// name can be empty, in that case name of the file is used. -func (c *context) File(path, name string, attachment bool) (err error) { - dir, file := filepath.Split(path) - if attachment { - c.response.Header().Set(ContentDisposition, "attachment; filename="+name) - } - if err = c.echo.serveFile(dir, file, c); err != nil { - c.response.Header().Del(ContentDisposition) +// Attachment sends specified file as an attachment to the client. +func (c *context) Attachment(file string) (err error) { + f, err := os.Open(file) + if err != nil { + return } + _, name := filepath.Split(file) + c.response.Header().Set(ContentDisposition, "attachment; filename="+name) + c.response.Header().Set(ContentType, c.detectContentType(file)) + c.response.WriteHeader(http.StatusOK) + _, err = io.Copy(c.response, f) return } @@ -273,7 +287,7 @@ func (c *context) NoContent(code int) error { // Redirect redirects the request using http.Redirect with status code. func (c *context) Redirect(code int, url string) error { if code < http.StatusMultipleChoices || code > http.StatusTemporaryRedirect { - return InvalidRedirectCode + return ErrInvalidRedirectCode } // TODO: v2 // http.Redirect(c.response, c.request, url, code) @@ -295,10 +309,16 @@ func (c *context) Object() *context { return c } -func (c *context) reset(req engine.Request, res engine.Response, e *Echo) { +func (c *context) detectContentType(name string) (t string) { + if t = mime.TypeByExtension(filepath.Ext(name)); t == "" { + t = OctetStream + } + return +} + +func (c *context) reset(req engine.Request, res engine.Response) { c.request = req c.response = res c.query = nil c.store = nil - c.echo = e } diff --git a/context_test.go b/context_test.go index 778b7c01..6622ad8b 100644 --- a/context_test.go +++ b/context_test.go @@ -157,22 +157,13 @@ func TestContext(t *testing.T) { assert.Equal(t, "Hello, World!", rec.Body.String()) } - // File + // Attachment rec = test.NewResponseRecorder() c = NewContext(req, rec, e) - err = c.File("_fixture/images/walle.png", "", false) + err = c.Attachment("_fixture/images/walle.png") if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Status()) - assert.Equal(t, 219885, rec.Body.Len()) - } - - // File as attachment - rec = test.NewResponseRecorder() - c = NewContext(req, rec, e) - err = c.File("_fixture/images/walle.png", "WALLE.PNG", true) - if assert.NoError(t, err) { - assert.Equal(t, http.StatusOK, rec.Status()) - assert.Equal(t, rec.Header().Get(ContentDisposition), "attachment; filename=WALLE.PNG") + assert.Equal(t, rec.Header().Get(ContentDisposition), "attachment; filename=walle.png") assert.Equal(t, 219885, rec.Body.Len()) } @@ -194,7 +185,7 @@ func TestContext(t *testing.T) { assert.Equal(t, http.StatusInternalServerError, c.Response().Status()) // reset - c.Object().reset(req, test.NewResponseRecorder(), e) + c.Object().reset(req, test.NewResponseRecorder()) } func TestContextPath(t *testing.T) { @@ -263,7 +254,7 @@ func testBindError(t *testing.T, c Context, ct string) { } default: if assert.IsType(t, new(HTTPError), err) { - assert.Equal(t, UnsupportedMediaType, err) + assert.Equal(t, ErrUnsupportedMediaType, err) } } diff --git a/echo.go b/echo.go index 8aea62bc..ffea2e32 100644 --- a/echo.go +++ b/echo.go @@ -7,8 +7,6 @@ import ( "fmt" "io" "net/http" - "path" - "path/filepath" "reflect" "runtime" "strings" @@ -24,7 +22,7 @@ import ( type ( Echo struct { prefix string - middleware []MiddlewareFunc + middleware []Middleware http2 bool maxParam *int notFoundHandler HandlerFunc @@ -33,8 +31,6 @@ type ( renderer Renderer pool sync.Pool debug bool - hook engine.HandlerFunc - autoIndex bool router *Router logger logger.Logger } @@ -51,10 +47,11 @@ type ( } Middleware interface { - Process(HandlerFunc) HandlerFunc + Handle(Handler) Handler + Priority() int } - MiddlewareFunc func(HandlerFunc) HandlerFunc + MiddlewareFunc func(Handler) Handler Handler interface { Handle(Context) error @@ -122,6 +119,7 @@ const ( TextPlain = "text/plain" TextPlainCharsetUTF8 = TextPlain + "; " + CharsetUTF8 MultipartForm = "multipart/form-data" + OctetStream = "application/octet-stream" //--------- // Charset @@ -150,8 +148,6 @@ const ( //----------- WebSocket = "websocket" - - indexPage = "index.html" ) var ( @@ -171,9 +167,10 @@ var ( // Errors //-------- - UnsupportedMediaType = NewHTTPError(http.StatusUnsupportedMediaType) - RendererNotRegistered = errors.New("renderer not registered") - InvalidRedirectCode = errors.New("invalid redirect status code") + ErrUnsupportedMediaType = NewHTTPError(http.StatusUnsupportedMediaType) + ErrNotFound = NewHTTPError(http.StatusNotFound) + ErrRendererNotRegistered = errors.New("renderer not registered") + ErrInvalidRedirectCode = errors.New("invalid redirect status code") //---------------- // Error handlers @@ -196,6 +193,7 @@ func New() (e *Echo) { return NewContext(nil, nil, e) } e.router = NewRouter(e) + e.middleware = []Middleware{e.router} //---------- // Defaults @@ -211,10 +209,14 @@ func New() (e *Echo) { return } -func (f MiddlewareFunc) Process(h HandlerFunc) HandlerFunc { +func (f MiddlewareFunc) Handle(h Handler) Handler { return f(h) } +func (f MiddlewareFunc) Priority() int { + return 1 +} + func (f HandlerFunc) Handle(c Context) error { return f(c) } @@ -281,18 +283,6 @@ func (e *Echo) Debug() bool { return e.debug } -// AutoIndex enable/disable automatically creating an index page for the directory. -func (e *Echo) AutoIndex(on bool) { - e.autoIndex = on -} - -// Hook registers a callback which is invoked from `Echo#ServerHTTP` as the first -// statement. Hook is useful if you want to modify response/response objects even -// before it hits the router or any middleware. -func (e *Echo) Hook(h engine.HandlerFunc) { - e.hook = h -} - // Use adds handler to the middleware chain. func (e *Echo) Use(middleware ...interface{}) { for _, m := range middleware { @@ -301,190 +291,99 @@ func (e *Echo) Use(middleware ...interface{}) { } // Connect adds a CONNECT route > handler to the router. -func (e *Echo) Connect(path string, handler interface{}) { - e.add(CONNECT, path, handler) +func (e *Echo) Connect(path string, handler interface{}, middleware ...interface{}) { + e.add(CONNECT, path, handler, middleware...) } // Delete adds a DELETE route > handler to the router. -func (e *Echo) Delete(path string, handler interface{}) { - e.add(DELETE, path, handler) +func (e *Echo) Delete(path string, handler interface{}, middleware ...interface{}) { + e.add(DELETE, path, handler, middleware...) } // Get adds a GET route > handler to the router. -func (e *Echo) Get(path string, handler interface{}) { - e.add(GET, path, handler) +func (e *Echo) Get(path string, handler interface{}, middleware ...interface{}) { + e.add(GET, path, handler, middleware...) } // Head adds a HEAD route > handler to the router. -func (e *Echo) Head(path string, handler interface{}) { - e.add(HEAD, path, handler) +func (e *Echo) Head(path string, handler interface{}, middleware ...interface{}) { + e.add(HEAD, path, handler, middleware...) } // Options adds an OPTIONS route > handler to the router. -func (e *Echo) Options(path string, handler interface{}) { - e.add(OPTIONS, path, handler) +func (e *Echo) Options(path string, handler interface{}, middleware ...interface{}) { + e.add(OPTIONS, path, handler, middleware...) } // Patch adds a PATCH route > handler to the router. -func (e *Echo) Patch(path string, handler interface{}) { - e.add(PATCH, path, handler) +func (e *Echo) Patch(path string, handler interface{}, middleware ...interface{}) { + e.add(PATCH, path, handler, middleware...) } // Post adds a POST route > handler to the router. -func (e *Echo) Post(path string, handler interface{}) { - e.add(POST, path, handler) +func (e *Echo) Post(path string, handler interface{}, middleware ...interface{}) { + e.add(POST, path, handler, middleware...) } // Put adds a PUT route > handler to the router. -func (e *Echo) Put(path string, handler interface{}) { - e.add(PUT, path, handler) +func (e *Echo) Put(path string, handler interface{}, middleware ...interface{}) { + e.add(PUT, path, handler, middleware...) } // Trace adds a TRACE route > handler to the router. -func (e *Echo) Trace(path string, handler interface{}) { - e.add(TRACE, path, handler) +func (e *Echo) Trace(path string, handler interface{}, middleware ...interface{}) { + e.add(TRACE, path, handler, middleware...) } // Any adds a route > handler to the router for all HTTP methods. -func (e *Echo) Any(path string, handler interface{}) { +func (e *Echo) Any(path string, handler interface{}, middleware ...interface{}) { for _, m := range methods { - e.add(m, path, handler) + e.add(m, path, handler, middleware...) } } // Match adds a route > handler to the router for multiple HTTP methods provided. -func (e *Echo) Match(methods []string, path string, handler interface{}) { +func (e *Echo) Match(methods []string, path string, handler interface{}, middleware ...interface{}) { for _, m := range methods { - e.add(m, path, handler) + e.add(m, path, handler, middleware...) } } // NOTE: v2 -func (e *Echo) add(method, path string, h interface{}) { - path = e.prefix + path - e.router.Add(method, path, wrapHandler(h), e) +func (e *Echo) add(method, path string, handler interface{}, middleware ...interface{}) { + h := wrapHandler(handler) + name := handlerName(handler) + e.router.Add(method, path, HandlerFunc(func(c Context) error { + for _, m := range middleware { + h = wrapMiddleware(m).Handle(h) + } + return h.Handle(c) + }), e) r := Route{ Method: method, Path: path, - Handler: runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name(), + Handler: name, } e.router.routes = append(e.router.routes, r) } -// Index serves index file. -func (e *Echo) Index(file string) { - e.ServeFile("/", file) -} - -// Favicon serves the default favicon - GET /favicon.ico. -func (e *Echo) Favicon(file string) { - e.ServeFile("/favicon.ico", file) -} - -// Static serves static files from a directory. It's an alias for `Echo.ServeDir` -func (e *Echo) Static(path, dir string) { - e.ServeDir(path, dir) -} - -// ServeDir serves files from a directory. -func (e *Echo) ServeDir(path, dir string) { - e.Get(path+"*", func(c Context) error { - return e.serveFile(dir, c.P(0), c) // Param `_*` - }) -} - -// ServeFile serves a file. -func (e *Echo) ServeFile(path, file string) { - e.Get(path, func(c Context) error { - dir, file := filepath.Split(file) - return e.serveFile(dir, file, c) - }) -} - -func (e *Echo) serveFile(dir, file string, c Context) (err error) { - fs := http.Dir(dir) - f, err := fs.Open(file) - if err != nil { - return NewHTTPError(http.StatusNotFound) - } - defer f.Close() - - fi, _ := f.Stat() - if fi.IsDir() { - /* NOTE: - Not checking the Last-Modified header as it caches the response `304` when - changing differnt directories for the same path. - */ - d := f - - // Index file - file = path.Join(file, indexPage) - f, err = fs.Open(file) - if err != nil { - if e.autoIndex { - // Auto index - return listDir(d, c) - } - return NewHTTPError(http.StatusForbidden) - } - fi, _ = f.Stat() // Index file stat - } - c.Response().WriteHeader(http.StatusOK) - io.Copy(c.Response(), f) - // TODO: - // http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), f) +// Group creates a new sub-router with prefix. +func (e *Echo) Group(prefix string, middleware ...interface{}) (g *Group) { + g = &Group{prefix: prefix, echo: e} + g.Use(middleware...) return } -func listDir(d http.File, c Context) (err error) { - dirs, err := d.Readdir(-1) - if err != nil { - return err - } - - // Create directory index - w := c.Response() - w.Header().Set(ContentType, TextHTMLCharsetUTF8) - fmt.Fprintf(w, "
\n") - for _, d := range dirs { - name := d.Name() - color := "#212121" - if d.IsDir() { - color = "#e91e63" - name += "/" - } - fmt.Fprintf(w, "%s\n", name, color, name) - } - fmt.Fprintf(w, "\n") - return -} - -// Group creates a new sub router with prefix. It inherits all properties from -// the parent. Passing middleware overrides parent middleware. -func (e *Echo) Group(prefix string, m ...MiddlewareFunc) *Group { - g := &Group{*e} - g.echo.prefix += prefix - if len(m) == 0 { - mw := make([]MiddlewareFunc, len(g.echo.middleware)) - copy(mw, g.echo.middleware) - g.echo.middleware = mw - } else { - g.echo.middleware = nil - g.Use(m...) - } - return g -} - // URI generates a URI from handler. -func (e *Echo) URI(h HandlerFunc, params ...interface{}) string { +func (e *Echo) URI(handler interface{}, params ...interface{}) string { uri := new(bytes.Buffer) - pl := len(params) + ln := len(params) n := 0 - hn := runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name() + name := handlerName(handler) for _, r := range e.router.routes { - if r.Handler == hn { + if r.Handler == name { for i, l := 0, len(r.Path); i < l; i++ { - if r.Path[i] == ':' && n < pl { + if r.Path[i] == ':' && n < ln { for ; i < l && r.Path[i] != '/'; i++ { } uri.WriteString(fmt.Sprintf("%v", params[n])) @@ -501,8 +400,8 @@ func (e *Echo) URI(h HandlerFunc, params ...interface{}) string { } // URL is an alias for `URI` function. -func (e *Echo) URL(h HandlerFunc, params ...interface{}) string { - return e.URI(h, params...) +func (e *Echo) URL(handler interface{}, params ...interface{}) string { + return e.URI(handler, params...) } // Routes returns the registered routes. @@ -511,37 +410,23 @@ func (e *Echo) Routes() []Route { } func (e *Echo) ServeHTTP(req engine.Request, res engine.Response) { - if e.hook != nil { - e.hook(req, res) - } - c := e.pool.Get().(*context) - h, e := e.router.Find(req.Method(), req.URL().Path(), c) - c.reset(req, res, e) + c.reset(req, res) + h := Handler(c) // Chain middleware with handler in the end for i := len(e.middleware) - 1; i >= 0; i-- { - h = e.middleware[i](h) + h = e.middleware[i].Handle(h) } // Execute chain - if err := h(c); err != nil { + if err := h.Handle(c); err != nil { e.httpErrorHandler(err, c) } e.pool.Put(c) } -// Server returns the internal *http.Server. -// func (e *Echo) Server(addr string) *http.Server { -// s := &http.Server{Addr: addr, Handler: e} -// // TODO: Remove in Go 1.6+ -// if e.http2 { -// http2.ConfigureServer(s, nil) -// } -// return s -// } - // Run starts the HTTP engine. func (e *Echo) Run(eng engine.Engine) { eng.SetHandler(e.ServeHTTP) @@ -575,7 +460,7 @@ func (e *HTTPError) Error() string { func (binder) Bind(r engine.Request, i interface{}) (err error) { ct := r.Header().Get(ContentType) - err = UnsupportedMediaType + err = ErrUnsupportedMediaType if strings.HasPrefix(ct, ApplicationJSON) { if err = json.NewDecoder(r.Body()).Decode(i); err != nil { err = NewHTTPError(http.StatusBadRequest, err.Error()) @@ -588,28 +473,40 @@ func (binder) Bind(r engine.Request, i interface{}) (err error) { return } -func wrapMiddleware(m interface{}) MiddlewareFunc { +func wrapMiddleware(m interface{}) Middleware { switch m := m.(type) { case Middleware: - return m.Process + return m case MiddlewareFunc: return m - case func(HandlerFunc) HandlerFunc: - return m + case func(Handler) Handler: + return MiddlewareFunc(m) default: panic("invalid middleware") } } -func wrapHandler(h interface{}) HandlerFunc { +func wrapHandler(h interface{}) Handler { switch h := h.(type) { case Handler: - return h.Handle + return h case HandlerFunc: return h case func(Context) error: - return h + return HandlerFunc(h) default: - panic("invalid handler") + panic("echo => invalid handler") + } +} + +func handlerName(h interface{}) string { + switch h := h.(type) { + case Handler: + t := reflect.TypeOf(h) + return fmt.Sprintf("%s » %s", t.PkgPath(), t.Name()) + case HandlerFunc, func(Context) error: + return runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name() + default: + panic("echo => invalid handler") } } diff --git a/echo_test.go b/echo_test.go index 002a6a2e..e84d053c 100644 --- a/echo_test.go +++ b/echo_test.go @@ -11,7 +11,6 @@ import ( "errors" - "github.com/labstack/echo/engine" "github.com/labstack/echo/test" "github.com/stretchr/testify/assert" ) @@ -41,77 +40,29 @@ func TestEcho(t *testing.T) { assert.Equal(t, http.StatusInternalServerError, rec.Status()) } -func TestEchoIndex(t *testing.T) { - e := New() - e.Index("_fixture/index.html") - c, b := request(GET, "/", e) - assert.Equal(t, http.StatusOK, c) - assert.NotEmpty(t, b) -} - -func TestEchoFavicon(t *testing.T) { - e := New() - e.Favicon("_fixture/favicon.ico") - c, b := request(GET, "/favicon.ico", e) - assert.Equal(t, http.StatusOK, c) - assert.NotEmpty(t, b) -} - -func TestEchoStatic(t *testing.T) { - e := New() - - // OK - e.Static("/images", "_fixture/images") - c, b := request(GET, "/images/walle.png", e) - assert.Equal(t, http.StatusOK, c) - assert.NotEmpty(t, b) - - // No file - e.Static("/images", "_fixture/scripts") - c, _ = request(GET, "/images/bolt.png", e) - assert.Equal(t, http.StatusNotFound, c) - - // Directory - e.Static("/images", "_fixture/images") - c, _ = request(GET, "/images", e) - assert.Equal(t, http.StatusForbidden, c) - - // Directory with index.html - e.Static("/", "_fixture") - c, r := request(GET, "/", e) - assert.Equal(t, http.StatusOK, c) - assert.Equal(t, true, strings.HasPrefix(r, "")) - - // Sub-directory with index.html - c, r = request(GET, "/folder", e) - assert.Equal(t, http.StatusOK, c) - assert.Equal(t, true, strings.HasPrefix(r, "")) - // assert.Equal(t, "sub directory", r) -} - func TestEchoMiddleware(t *testing.T) { e := New() buf := new(bytes.Buffer) - e.Use(func(h HandlerFunc) HandlerFunc { - return func(c Context) error { + e.Use(func(h Handler) Handler { + return HandlerFunc(func(c Context) error { buf.WriteString("a") - return h(c) - } + return h.Handle(c) + }) }) - e.Use(func(h HandlerFunc) HandlerFunc { - return func(c Context) error { + e.Use(func(h Handler) Handler { + return HandlerFunc(func(c Context) error { buf.WriteString("b") - return h(c) - } + return h.Handle(c) + }) }) - e.Use(func(h HandlerFunc) HandlerFunc { - return func(c Context) error { + e.Use(func(h Handler) Handler { + return HandlerFunc(func(c Context) error { buf.WriteString("c") - return h(c) - } + return h.Handle(c) + }) }) // Route @@ -125,10 +76,10 @@ func TestEchoMiddleware(t *testing.T) { assert.Equal(t, "OK", b) // Error - e.Use(func(h HandlerFunc) HandlerFunc { - return func(c Context) error { + e.Use(func(Handler) Handler { + return HandlerFunc(func(c Context) error { return errors.New("error") - } + }) }) c, b = request(GET, "/", e) assert.Equal(t, http.StatusInternalServerError, c) @@ -138,9 +89,9 @@ func TestEchoHandler(t *testing.T) { e := New() // HandlerFunc - e.Get("/ok", HandlerFunc(func(c Context) error { + e.Get("/ok", func(c Context) error { return c.String(http.StatusOK, "OK") - })) + }) c, b := request(GET, "/ok", e) assert.Equal(t, http.StatusOK, c) @@ -208,7 +159,6 @@ func TestEchoMatch(t *testing.T) { // JFC func TestEchoURL(t *testing.T) { e := New() - static := func(Context) error { return nil } getUser := func(Context) error { return nil } getFile := func(Context) error { return nil } @@ -248,11 +198,11 @@ func TestEchoRoutes(t *testing.T) { func TestEchoGroup(t *testing.T) { e := New() buf := new(bytes.Buffer) - e.Use(func(h HandlerFunc) HandlerFunc { - return func(c Context) error { + e.Use(func(h Handler) Handler { + return HandlerFunc(func(c Context) error { buf.WriteString("0") - return h(c) - } + return h.Handle(c) + }) }) h := func(c Context) error { return c.NoContent(http.StatusOK) @@ -266,27 +216,18 @@ func TestEchoGroup(t *testing.T) { // Group g1 := e.Group("/group1") - g1.Use(func(h HandlerFunc) HandlerFunc { - return func(c Context) error { + g1.Use(func(h Handler) Handler { + return HandlerFunc(func(c Context) error { buf.WriteString("1") - return h(c) - } + return h.Handle(c) + }) }) g1.Get("/", h) - // Group with no parent middleware - g2 := e.Group("/group2", func(h HandlerFunc) HandlerFunc { - return func(c Context) error { - buf.WriteString("2") - return h(c) - } - }) - g2.Get("/", h) - // Nested groups - g3 := e.Group("/group3") - g4 := g3.Group("/group4") - g4.Get("/", h) + g2 := e.Group("/group2") + g3 := g2.Group("/group3") + g3.Get("/", h) request(GET, "/users", e) assert.Equal(t, "0", buf.String()) @@ -296,11 +237,7 @@ func TestEchoGroup(t *testing.T) { assert.Equal(t, "01", buf.String()) buf.Reset() - request(GET, "/group2/", e) - assert.Equal(t, "2", buf.String()) - - buf.Reset() - c, _ := request(GET, "/group3/group4/", e) + c, _ := request(GET, "/group2/group3/", e) assert.Equal(t, http.StatusOK, c) } @@ -330,30 +267,6 @@ func TestEchoHTTPError(t *testing.T) { assert.Equal(t, m, he.Error()) } -func TestEchoServer(t *testing.T) { - // e := New() - // s := e.Server(":1323") - // assert.IsType(t, &http.Server{}, s) -} - -func TestEchoHook(t *testing.T) { - e := New() - e.Get("/test", func(c Context) error { - return c.NoContent(http.StatusNoContent) - }) - e.Hook(func(req engine.Request, res engine.Response) { - path := req.URL().Path() - l := len(path) - 1 - if path != "/" && path[l] == '/' { - req.URL().SetPath(path[:l]) - } - }) - req := test.NewRequest(GET, "/test/", nil) - rec := test.NewResponseRecorder() - e.ServeHTTP(req, rec) - assert.Equal(t, req.URL().Path(), "/test") -} - func testMethod(t *testing.T, method, path string, e *Echo) { m := fmt.Sprintf("%c%s", method[0], strings.ToLower(method[1:])) p := reflect.ValueOf(path) diff --git a/group.go b/group.go index c2bc9f1b..5f24b684 100644 --- a/group.go +++ b/group.go @@ -2,64 +2,72 @@ package echo type ( Group struct { - echo Echo + prefix string + middleware []Middleware + echo *Echo } ) -func (g *Group) Use(m ...MiddlewareFunc) { - for _, h := range m { - g.echo.middleware = append(g.echo.middleware, h) +func (g *Group) Use(middleware ...interface{}) { + for _, m := range middleware { + g.middleware = append(g.middleware, wrapMiddleware(m)) } } -func (g *Group) Connect(path string, h HandlerFunc) { - g.echo.Connect(path, h) +func (g *Group) Connect(path string, handler interface{}) { + g.add(CONNECT, path, handler) } -func (g *Group) Delete(path string, h HandlerFunc) { - g.echo.Delete(path, h) +func (g *Group) Delete(path string, handler interface{}) { + g.add(DELETE, path, handler) } -func (g *Group) Get(path string, h HandlerFunc) { - g.echo.Get(path, h) +func (g *Group) Get(path string, handler interface{}) { + g.add(GET, path, handler) } -func (g *Group) Head(path string, h HandlerFunc) { - g.echo.Head(path, h) +func (g *Group) Head(path string, handler interface{}) { + g.add(HEAD, path, handler) } -func (g *Group) Options(path string, h HandlerFunc) { - g.echo.Options(path, h) +func (g *Group) Options(path string, handler interface{}) { + g.add(OPTIONS, path, handler) } -func (g *Group) Patch(path string, h HandlerFunc) { - g.echo.Patch(path, h) +func (g *Group) Patch(path string, handler interface{}) { + g.add(PATCH, path, handler) } -func (g *Group) Post(path string, h HandlerFunc) { - g.echo.Post(path, h) +func (g *Group) Post(path string, handler interface{}) { + g.add(POST, path, handler) } -func (g *Group) Put(path string, h HandlerFunc) { - g.echo.Put(path, h) +func (g *Group) Put(path string, handler interface{}) { + g.add(PUT, path, handler) } -func (g *Group) Trace(path string, h HandlerFunc) { - g.echo.Trace(path, h) +func (g *Group) Trace(path string, handler interface{}) { + g.add(TRACE, path, handler) } -func (g *Group) Static(path, root string) { - g.echo.Static(path, root) +func (g *Group) Group(prefix string, middleware ...interface{}) *Group { + return g.echo.Group(prefix, middleware...) } -func (g *Group) ServeDir(path, root string) { - g.echo.ServeDir(path, root) -} - -func (g *Group) ServeFile(path, file string) { - g.echo.ServeFile(path, file) -} - -func (g *Group) Group(prefix string, m ...MiddlewareFunc) *Group { - return g.echo.Group(prefix, m...) +func (g *Group) add(method, path string, handler interface{}) { + path = g.prefix + path + h := wrapHandler(handler) + name := handlerName(handler) + g.echo.router.Add(method, path, HandlerFunc(func(c Context) error { + for i := len(g.middleware) - 1; i >= 0; i-- { + h = g.middleware[i].Handle(h) + } + return h.Handle(c) + }), g.echo) + r := Route{ + Method: method, + Path: path, + Handler: name, + } + g.echo.router.routes = append(g.echo.router.routes, r) } diff --git a/group_test.go b/group_test.go index 6fb9369f..1d02abc7 100644 --- a/group_test.go +++ b/group_test.go @@ -14,7 +14,4 @@ func TestGroup(t *testing.T) { g.Post("/", h) g.Put("/", h) g.Trace("/", h) - g.Static("/scripts", "scripts") - g.ServeDir("/scripts", "scripts") - g.ServeFile("/scripts/main.js", "scripts/main.js") } diff --git a/handler/static.go b/handler/static.go new file mode 100644 index 00000000..6a0c4a15 --- /dev/null +++ b/handler/static.go @@ -0,0 +1,92 @@ +package handler + +import ( + "fmt" + "io" + "net/http" + "path" + + "github.com/labstack/echo" +) + +type ( + Static struct { + Root string + Browse bool + Index string + } + + FaviconOptions struct { + } +) + +func (s Static) Handle(c echo.Context) error { + fs := http.Dir(s.Root) + file := c.P(0) + f, err := fs.Open(file) + if err != nil { + return echo.ErrNotFound + } + defer f.Close() + + fi, err := f.Stat() + if err != nil { + return err + } + + if fi.IsDir() { + /* NOTE: + Not checking the Last-Modified header as it caches the response `304` when + changing differnt directories for the same path. + */ + d := f + + // Index file + file = path.Join(file, s.Index) + f, err = fs.Open(file) + if err != nil { + if s.Browse { + dirs, err := d.Readdir(-1) + if err != nil { + return err + } + + // Create a directory index + w := c.Response() + w.Header().Set(echo.ContentType, echo.TextHTMLCharsetUTF8) + if _, err = fmt.Fprintf(w, "
\n"); err != nil { + return err + } + for _, d := range dirs { + name := d.Name() + color := "#212121" + if d.IsDir() { + color = "#e91e63" + name += "/" + } + if _, err = fmt.Fprintf(w, "%s\n", name, color, name); err != nil { + return err + } + } + _, err = fmt.Fprintf(w, "\n") + return err + } + return echo.ErrNotFound + } + fi, _ = f.Stat() // Index file stat + } + c.Response().WriteHeader(http.StatusOK) + io.Copy(c.Response(), f) + return nil + // TODO: + // http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), f) +} + +// Favicon serves the default favicon - GET /favicon.ico. +func Favicon(root string, options ...FaviconOptions) echo.MiddlewareFunc { + return func(h echo.Handler) echo.Handler { + return echo.HandlerFunc(func(c echo.Context) error { + return nil + }) + } +} diff --git a/middleware/auth.go b/middleware/auth.go index 7a8a4afc..be63681e 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -20,8 +20,8 @@ const ( // For valid credentials it calls the next handler. // For invalid credentials, it sends "401 - Unauthorized" response. func BasicAuth(fn BasicValidateFunc) echo.MiddlewareFunc { - return func(h echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { + return func(h echo.Handler) echo.Handler { + return echo.HandlerFunc(func(c echo.Context) error { // Skip WebSocket if (c.Request().Header().Get(echo.Upgrade)) == echo.WebSocket { return nil @@ -46,6 +46,6 @@ func BasicAuth(fn BasicValidateFunc) echo.MiddlewareFunc { } c.Response().Header().Set(echo.WWWAuthenticate, basic+" realm=Restricted") return echo.NewHTTPError(http.StatusUnauthorized) - } + }) } } diff --git a/middleware/auth_test.go b/middleware/auth_test.go index 61720b1d..b20ab1ac 100644 --- a/middleware/auth_test.go +++ b/middleware/auth_test.go @@ -21,15 +21,14 @@ func TestBasicAuth(t *testing.T) { } return false } - h := func(c echo.Context) error { + h := BasicAuth(fn)(echo.HandlerFunc(func(c echo.Context) error { return c.String(http.StatusOK, "test") - } - mw := BasicAuth(fn)(h) + })) // Valid credentials auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) req.Header().Set(echo.Authorization, auth) - assert.NoError(t, mw(c)) + assert.NoError(t, h.Handle(c)) //--------------------- // Invalid credentials @@ -38,24 +37,24 @@ func TestBasicAuth(t *testing.T) { // Incorrect password auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password")) req.Header().Set(echo.Authorization, auth) - he := mw(c).(*echo.HTTPError) + he := h.Handle(c).(*echo.HTTPError) assert.Equal(t, http.StatusUnauthorized, he.Code()) assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.WWWAuthenticate)) // Empty Authorization header req.Header().Set(echo.Authorization, "") - he = mw(c).(*echo.HTTPError) + he = h.Handle(c).(*echo.HTTPError) assert.Equal(t, http.StatusUnauthorized, he.Code()) assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.WWWAuthenticate)) // Invalid Authorization header auth = base64.StdEncoding.EncodeToString([]byte("invalid")) req.Header().Set(echo.Authorization, auth) - he = mw(c).(*echo.HTTPError) + he = h.Handle(c).(*echo.HTTPError) assert.Equal(t, http.StatusUnauthorized, he.Code()) assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.WWWAuthenticate)) // WebSocket c.Request().Header().Set(echo.Upgrade, echo.WebSocket) - assert.NoError(t, mw(c)) + assert.NoError(t, h.Handle(c)) } diff --git a/middleware/compress.go b/middleware/compress.go index 3d8306b8..99026fd6 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -49,10 +49,9 @@ var writerPool = sync.Pool{ // Gzip returns a middleware which compresses HTTP response using gzip compression // scheme. func Gzip() echo.MiddlewareFunc { - return func(h echo.HandlerFunc) echo.HandlerFunc { + return func(h echo.Handler) echo.Handler { scheme := "gzip" - - return func(c echo.Context) error { + return echo.HandlerFunc(func(c echo.Context) error { c.Response().Header().Add(echo.Vary, echo.AcceptEncoding) if strings.Contains(c.Request().Header().Get(echo.AcceptEncoding), scheme) { w := writerPool.Get().(*gzip.Writer) @@ -69,6 +68,6 @@ func Gzip() echo.MiddlewareFunc { c.Error(err) } return nil - } + }) } } diff --git a/middleware/compress_test.go b/middleware/compress_test.go index 97b2df97..b29f6a1e 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -35,13 +35,12 @@ func TestGzip(t *testing.T) { req := test.NewRequest(echo.GET, "/", nil) rec := test.NewResponseRecorder() c := echo.NewContext(req, rec, e) - h := func(c echo.Context) error { + // Skip if no Accept-Encoding header + h := Gzip()(echo.HandlerFunc(func(c echo.Context) error { c.Response().Write([]byte("test")) // For Content-Type sniffing return nil - } - - // Skip if no Accept-Encoding header - Gzip()(h)(c) + })) + h.Handle(c) // assert.Equal(t, http.StatusOK, rec.Status()) assert.Equal(t, "test", rec.Body.String()) @@ -51,7 +50,7 @@ func TestGzip(t *testing.T) { c = echo.NewContext(req, rec, e) // Gzip - Gzip()(h)(c) + h.Handle(c) // assert.Equal(t, http.StatusOK, rec.Status()) assert.Equal(t, "gzip", rec.Header().Get(echo.ContentEncoding)) assert.Contains(t, rec.Header().Get(echo.ContentType), echo.TextPlain) diff --git a/middleware/log.go b/middleware/log.go index 2c5f94ce..bea444b4 100644 --- a/middleware/log.go +++ b/middleware/log.go @@ -9,8 +9,8 @@ import ( ) func Log() echo.MiddlewareFunc { - return func(h echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { + return func(h echo.Handler) echo.Handler { + return echo.HandlerFunc(func(c echo.Context) error { req := c.Request() res := c.Response() logger := c.Logger() @@ -49,6 +49,6 @@ func Log() echo.MiddlewareFunc { logger.Infof("%s %s %s %s %s %d", remoteAddr, method, path, code, stop.Sub(start), size) return nil - } + }) } } diff --git a/middleware/log_test.go b/middleware/log_test.go index 3853ad2b..1f5b7360 100644 --- a/middleware/log_test.go +++ b/middleware/log_test.go @@ -18,38 +18,37 @@ func TestLog(t *testing.T) { req := test.NewRequest(echo.GET, "/", nil) rec := test.NewResponseRecorder() c := echo.NewContext(req, rec, e) - h := func(c echo.Context) error { + h := Log()(echo.HandlerFunc(func(c echo.Context) error { return c.String(http.StatusOK, "test") - } - mw := Log()(h) + })) // Status 2xx - mw(c) + h.Handle(c) // Status 3xx rec = test.NewResponseRecorder() c = echo.NewContext(req, rec, e) - h = func(c echo.Context) error { + h = Log()(echo.HandlerFunc(func(c echo.Context) error { return c.String(http.StatusTemporaryRedirect, "test") - } - mw(c) + })) + h.Handle(c) // Status 4xx rec = test.NewResponseRecorder() c = echo.NewContext(req, rec, e) - h = func(c echo.Context) error { + h = Log()(echo.HandlerFunc(func(c echo.Context) error { return c.String(http.StatusNotFound, "test") - } - mw(c) + })) + h.Handle(c) // Status 5xx with empty path req = test.NewRequest(echo.GET, "", nil) rec = test.NewResponseRecorder() c = echo.NewContext(req, rec, e) - h = func(c echo.Context) error { + h = Log()(echo.HandlerFunc(func(c echo.Context) error { return errors.New("error") - } - mw(c) + })) + h.Handle(c) } func TestLogIPAddress(t *testing.T) { @@ -60,25 +59,24 @@ func TestLogIPAddress(t *testing.T) { buf := new(bytes.Buffer) e.Logger().(*log.Logger).SetOutput(buf) ip := "127.0.0.1" - h := func(c echo.Context) error { + h := Log()(echo.HandlerFunc(func(c echo.Context) error { return c.String(http.StatusOK, "test") - } - mw := Log()(h) + })) // With X-Real-IP req.Header().Add(echo.XRealIP, ip) - mw(c) + h.Handle(c) assert.Contains(t, buf.String(), ip) // With X-Forwarded-For buf.Reset() req.Header().Del(echo.XRealIP) req.Header().Add(echo.XForwardedFor, ip) - mw(c) + h.Handle(c) assert.Contains(t, buf.String(), ip) // with req.RemoteAddr buf.Reset() - mw(c) + h.Handle(c) assert.Contains(t, buf.String(), ip) } diff --git a/middleware/recover.go b/middleware/recover.go index 5a7bf5a4..b32a127f 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -11,9 +11,9 @@ import ( // Recover returns a middleware which recovers from panics anywhere in the chain // and handles the control to the centralized HTTPErrorHandler. func Recover() echo.MiddlewareFunc { - return func(h echo.HandlerFunc) echo.HandlerFunc { + return func(h echo.Handler) echo.Handler { // TODO: Provide better stack trace `https://github.com/go-errors/errors` `https://github.com/docker/libcontainer/tree/master/stacktrace` - return func(c echo.Context) error { + return echo.HandlerFunc(func(c echo.Context) error { defer func() { if err := recover(); err != nil { trace := make([]byte, 1<<16) @@ -23,6 +23,6 @@ func Recover() echo.MiddlewareFunc { } }() return h.Handle(c) - } + }) } } diff --git a/middleware/recover_test.go b/middleware/recover_test.go index b272f361..e1ee6682 100644 --- a/middleware/recover_test.go +++ b/middleware/recover_test.go @@ -15,10 +15,10 @@ func TestRecover(t *testing.T) { req := test.NewRequest(echo.GET, "/", nil) rec := test.NewResponseRecorder() c := echo.NewContext(req, rec, e) - h := func(c echo.Context) error { + h := Recover()(echo.HandlerFunc(func(c echo.Context) error { panic("test") - } - Recover()(h)(c) + })) + h.Handle(c) assert.Equal(t, http.StatusInternalServerError, rec.Status()) assert.Contains(t, rec.Body.String(), "panic recover") } diff --git a/router.go b/router.go index ae0394e6..385a063f 100644 --- a/router.go +++ b/router.go @@ -15,20 +15,19 @@ type ( ppath string pnames []string methodHandler *methodHandler - echo *Echo } kind uint8 children []*node methodHandler struct { - connect HandlerFunc - delete HandlerFunc - get HandlerFunc - head HandlerFunc - options HandlerFunc - patch HandlerFunc - post HandlerFunc - put HandlerFunc - trace HandlerFunc + connect Handler + delete Handler + get Handler + head Handler + options Handler + patch Handler + post Handler + put Handler + trace Handler } ) @@ -48,7 +47,20 @@ func NewRouter(e *Echo) *Router { } } -func (r *Router) Add(method, path string, h HandlerFunc, e *Echo) { +func (r *Router) Handle(h Handler) Handler { + return HandlerFunc(func(c Context) error { + method := c.Request().Method() + path := c.Request().URL().Path() + r.Find(method, path, c) + return h.Handle(c) + }) +} + +func (r *Router) Priority() int { + return 0 +} + +func (r *Router) Add(method, path string, h Handler, e *Echo) { ppath := path // Pristine path pnames := []string{} // Param names @@ -80,7 +92,7 @@ func (r *Router) Add(method, path string, h HandlerFunc, e *Echo) { r.insert(method, path, h, skind, ppath, pnames, e) } -func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string, pnames []string, e *Echo) { +func (r *Router) insert(method, path string, h Handler, t kind, ppath string, pnames []string, e *Echo) { // Adjust max param l := len(pnames) if *e.maxParam < l { @@ -89,7 +101,7 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string cn := r.tree // Current node as root if cn == nil { - panic("echo => invalid method") + panic("echo ⇛ invalid method") } search := path @@ -115,11 +127,10 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string cn.addHandler(method, h) cn.ppath = ppath cn.pnames = pnames - cn.echo = e } } else if l < pl { // Split node - n := newNode(cn.kind, cn.prefix[l:], cn, cn.children, cn.methodHandler, cn.ppath, cn.pnames, cn.echo) + n := newNode(cn.kind, cn.prefix[l:], cn, cn.children, cn.methodHandler, cn.ppath, cn.pnames) // Reset parent node cn.kind = skind @@ -129,7 +140,6 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string cn.methodHandler = new(methodHandler) cn.ppath = "" cn.pnames = nil - cn.echo = nil cn.addChild(n) @@ -139,10 +149,9 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string cn.addHandler(method, h) cn.ppath = ppath cn.pnames = pnames - cn.echo = e } else { // Create child node - n = newNode(t, search[l:], cn, nil, new(methodHandler), ppath, pnames, e) + n = newNode(t, search[l:], cn, nil, new(methodHandler), ppath, pnames) n.addHandler(method, h) cn.addChild(n) } @@ -155,7 +164,7 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string continue } // Create child node - n := newNode(t, search, cn, nil, new(methodHandler), ppath, pnames, e) + n := newNode(t, search, cn, nil, new(methodHandler), ppath, pnames) n.addHandler(method, h) cn.addChild(n) } else { @@ -164,14 +173,13 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string cn.addHandler(method, h) cn.ppath = path cn.pnames = pnames - cn.echo = e } } return } } -func newNode(t kind, pre string, p *node, c children, mh *methodHandler, ppath string, pnames []string, e *Echo) *node { +func newNode(t kind, pre string, p *node, c children, mh *methodHandler, ppath string, pnames []string) *node { return &node{ kind: t, label: pre[0], @@ -181,7 +189,6 @@ func newNode(t kind, pre string, p *node, c children, mh *methodHandler, ppath s ppath: ppath, pnames: pnames, methodHandler: mh, - echo: e, } } @@ -216,7 +223,7 @@ func (n *node) findChildByKind(t kind) *node { return nil } -func (n *node) addHandler(method string, h HandlerFunc) { +func (n *node) addHandler(method string, h Handler) { switch method { case GET: n.methodHandler.get = h @@ -239,7 +246,7 @@ func (n *node) addHandler(method string, h HandlerFunc) { } } -func (n *node) findHandler(method string) HandlerFunc { +func (n *node) findHandler(method string) Handler { switch method { case GET: return n.methodHandler.get @@ -273,10 +280,10 @@ func (n *node) check405() HandlerFunc { return notFoundHandler } -func (r *Router) Find(method, path string, context Context) (h HandlerFunc, e *Echo) { - x := context.Object() - h = notFoundHandler - e = r.echo +func (r *Router) Find(method, path string, context Context) { + ctx := context.Object() + // h = notFoundHandler + // e = r.echo cn := r.tree // Current node as root var ( @@ -357,7 +364,7 @@ func (r *Router) Find(method, path string, context Context) (h HandlerFunc, e *E i, l := 0, len(search) for ; i < l && search[i] != '/'; i++ { } - x.pvalues[n] = search[:i] + ctx.pvalues[n] = search[:i] n++ search = search[i:] continue @@ -370,30 +377,27 @@ func (r *Router) Find(method, path string, context Context) (h HandlerFunc, e *E // Not found return } - x.pvalues[len(cn.pnames)-1] = search + ctx.pvalues[len(cn.pnames)-1] = search goto End } End: - x.path = cn.ppath - x.pnames = cn.pnames - h = cn.findHandler(method) - if cn.echo != nil { - e = cn.echo - } + ctx.path = cn.ppath + ctx.pnames = cn.pnames + ctx.handler = cn.findHandler(method) // NOTE: Slow zone... - if h == nil { - h = cn.check405() + if ctx.handler == nil { + ctx.handler = cn.check405() // Dig further for match-any, might have an empty value for *, e.g. // serving a directory. Issue #207. if cn = cn.findChildByKind(mkind); cn == nil { return } - x.pvalues[len(cn.pnames)-1] = "" - if h = cn.findHandler(method); h == nil { - h = cn.check405() + ctx.pvalues[len(cn.pnames)-1] = "" + if ctx.handler = cn.findHandler(method); ctx.handler == nil { + ctx.handler = cn.check405() } } return diff --git a/router_test.go b/router_test.go index d228ebd6..77d311c3 100644 --- a/router_test.go +++ b/router_test.go @@ -277,44 +277,38 @@ func TestRouterStatic(t *testing.T) { e := New() r := e.router path := "/folders/a/files/echo.gif" - r.Add(GET, path, func(c Context) error { + r.Add(GET, path, HandlerFunc(func(c Context) error { c.Set("path", path) return nil - }, e) + }), e) c := NewContext(nil, nil, e) - h, _ := r.Find(GET, path, c) - if assert.NotNil(t, h) { - h(c) - assert.Equal(t, path, c.Get("path")) - } + r.Find(GET, path, c) + c.Handle(c) + assert.Equal(t, path, c.Get("path")) } func TestRouterParam(t *testing.T) { e := New() r := e.router - r.Add(GET, "/users/:id", func(c Context) error { + r.Add(GET, "/users/:id", HandlerFunc(func(c Context) error { return nil - }, e) + }), e) c := NewContext(nil, nil, e) - h, _ := r.Find(GET, "/users/1", c) - if assert.NotNil(t, h) { - assert.Equal(t, "1", c.P(0)) - } + r.Find(GET, "/users/1", c) + assert.Equal(t, "1", c.P(0)) } func TestRouterTwoParam(t *testing.T) { e := New() r := e.router - r.Add(GET, "/users/:uid/files/:fid", func(Context) error { + r.Add(GET, "/users/:uid/files/:fid", HandlerFunc(func(Context) error { return nil - }, e) + }), e) c := NewContext(nil, nil, e) - h, _ := r.Find(GET, "/users/1/files/1", c) - if assert.NotNil(t, h) { - assert.Equal(t, "1", c.P(0)) - assert.Equal(t, "1", c.P(1)) - } + r.Find(GET, "/users/1/files/1", c) + assert.Equal(t, "1", c.P(0)) + assert.Equal(t, "1", c.P(1)) } func TestRouterMatchAny(t *testing.T) { @@ -322,46 +316,38 @@ func TestRouterMatchAny(t *testing.T) { r := e.router // Routes - r.Add(GET, "/", func(Context) error { + r.Add(GET, "/", HandlerFunc(func(Context) error { return nil - }, e) - r.Add(GET, "/*", func(Context) error { + }), e) + r.Add(GET, "/*", HandlerFunc(func(Context) error { return nil - }, e) - r.Add(GET, "/users/*", func(Context) error { + }), e) + r.Add(GET, "/users/*", HandlerFunc(func(Context) error { return nil - }, e) + }), e) c := NewContext(nil, nil, e) - h, _ := r.Find(GET, "/", c) - if assert.NotNil(t, h) { - assert.Equal(t, "", c.P(0)) - } + r.Find(GET, "/", c) + assert.Equal(t, "", c.P(0)) - h, _ = r.Find(GET, "/download", c) - if assert.NotNil(t, h) { - assert.Equal(t, "download", c.P(0)) - } + r.Find(GET, "/download", c) + assert.Equal(t, "download", c.P(0)) - h, _ = r.Find(GET, "/users/joe", c) - if assert.NotNil(t, h) { - assert.Equal(t, "joe", c.P(0)) - } + r.Find(GET, "/users/joe", c) + assert.Equal(t, "joe", c.P(0)) } func TestRouterMicroParam(t *testing.T) { e := New() r := e.router - r.Add(GET, "/:a/:b/:c", func(c Context) error { + r.Add(GET, "/:a/:b/:c", HandlerFunc(func(c Context) error { return nil - }, e) + }), e) c := NewContext(nil, nil, e) - h, _ := r.Find(GET, "/1/2/3", c) - if assert.NotNil(t, h) { - assert.Equal(t, "1", c.P(0)) - assert.Equal(t, "2", c.P(1)) - assert.Equal(t, "3", c.P(2)) - } + r.Find(GET, "/1/2/3", c) + assert.Equal(t, "1", c.P(0)) + assert.Equal(t, "2", c.P(1)) + assert.Equal(t, "3", c.P(2)) } func TestRouterMixParamMatchAny(t *testing.T) { @@ -369,16 +355,14 @@ func TestRouterMixParamMatchAny(t *testing.T) { r := e.router // Route - r.Add(GET, "/users/:id/*", func(c Context) error { + r.Add(GET, "/users/:id/*", HandlerFunc(func(c Context) error { return nil - }, e) + }), e) c := NewContext(nil, nil, e) - h, _ := r.Find(GET, "/users/joe/comments", c) - if assert.NotNil(t, h) { - h(c) - assert.Equal(t, "joe", c.P(0)) - } + r.Find(GET, "/users/joe/comments", c) + c.Handle(c) + assert.Equal(t, "joe", c.P(0)) } func TestRouterMultiRoute(t *testing.T) { @@ -386,32 +370,29 @@ func TestRouterMultiRoute(t *testing.T) { r := e.router // Routes - r.Add(GET, "/users", func(c Context) error { + r.Add(GET, "/users", HandlerFunc(func(c Context) error { c.Set("path", "/users") return nil - }, e) - r.Add(GET, "/users/:id", func(c Context) error { + }), e) + r.Add(GET, "/users/:id", HandlerFunc(func(c Context) error { return nil - }, e) + }), e) c := NewContext(nil, nil, e) // Route > /users - h, _ := r.Find(GET, "/users", c) - if assert.NotNil(t, h) { - h(c) - assert.Equal(t, "/users", c.Get("path")) - } + r.Find(GET, "/users", c) + c.Handle(c) + assert.Equal(t, "/users", c.Get("path")) // Route > /users/:id - h, _ = r.Find(GET, "/users/1", c) - if assert.NotNil(t, h) { - assert.Equal(t, "1", c.P(0)) - } + r.Find(GET, "/users/1", c) + assert.Equal(t, "1", c.P(0)) // Route > /user - h, _ = r.Find(GET, "/user", c) - if assert.IsType(t, new(HTTPError), h(c)) { - he := h(c).(*HTTPError) + c = NewContext(nil, nil, e) + r.Find(GET, "/user", c) + if assert.IsType(t, new(HTTPError), c.Handle(c)) { + he := c.Handle(c).(*HTTPError) assert.Equal(t, http.StatusNotFound, he.code) } } @@ -421,85 +402,71 @@ func TestRouterPriority(t *testing.T) { r := e.router // Routes - r.Add(GET, "/users", func(c Context) error { + r.Add(GET, "/users", HandlerFunc(func(c Context) error { c.Set("a", 1) return nil - }, e) - r.Add(GET, "/users/new", func(c Context) error { + }), e) + r.Add(GET, "/users/new", HandlerFunc(func(c Context) error { c.Set("b", 2) return nil - }, e) - r.Add(GET, "/users/:id", func(c Context) error { + }), e) + r.Add(GET, "/users/:id", HandlerFunc(func(c Context) error { c.Set("c", 3) return nil - }, e) - r.Add(GET, "/users/dew", func(c Context) error { + }), e) + r.Add(GET, "/users/dew", HandlerFunc(func(c Context) error { c.Set("d", 4) return nil - }, e) - r.Add(GET, "/users/:id/files", func(c Context) error { + }), e) + r.Add(GET, "/users/:id/files", HandlerFunc(func(c Context) error { c.Set("e", 5) return nil - }, e) - r.Add(GET, "/users/newsee", func(c Context) error { + }), e) + r.Add(GET, "/users/newsee", HandlerFunc(func(c Context) error { c.Set("f", 6) return nil - }, e) - r.Add(GET, "/users/*", func(c Context) error { + }), e) + r.Add(GET, "/users/*", HandlerFunc(func(c Context) error { c.Set("g", 7) return nil - }, e) + }), e) c := NewContext(nil, nil, e) // Route > /users - h, _ := r.Find(GET, "/users", c) - if assert.NotNil(t, h) { - h(c) - assert.Equal(t, 1, c.Get("a")) - } + r.Find(GET, "/users", c) + c.Handle(c) + assert.Equal(t, 1, c.Get("a")) // Route > /users/new - h, _ = r.Find(GET, "/users/new", c) - if assert.NotNil(t, h) { - h(c) - assert.Equal(t, 2, c.Get("b")) - } + r.Find(GET, "/users/new", c) + c.Handle(c) + assert.Equal(t, 2, c.Get("b")) // Route > /users/:id - h, _ = r.Find(GET, "/users/1", c) - if assert.NotNil(t, h) { - h(c) - assert.Equal(t, 3, c.Get("c")) - } + r.Find(GET, "/users/1", c) + c.Handle(c) + assert.Equal(t, 3, c.Get("c")) // Route > /users/dew - h, _ = r.Find(GET, "/users/dew", c) - if assert.NotNil(t, h) { - h(c) - assert.Equal(t, 4, c.Get("d")) - } + r.Find(GET, "/users/dew", c) + c.Handle(c) + assert.Equal(t, 4, c.Get("d")) // Route > /users/:id/files - h, _ = r.Find(GET, "/users/1/files", c) - if assert.NotNil(t, h) { - h(c) - assert.Equal(t, 5, c.Get("e")) - } + r.Find(GET, "/users/1/files", c) + c.Handle(c) + assert.Equal(t, 5, c.Get("e")) // Route > /users/:id - h, _ = r.Find(GET, "/users/news", c) - if assert.NotNil(t, h) { - h(c) - assert.Equal(t, 3, c.Get("c")) - } + r.Find(GET, "/users/news", c) + c.Handle(c) + assert.Equal(t, 3, c.Get("c")) // Route > /users/* - h, _ = r.Find(GET, "/users/joe/books", c) - if assert.NotNil(t, h) { - h(c) - assert.Equal(t, 7, c.Get("g")) - assert.Equal(t, "joe/books", c.Param("_*")) - } + r.Find(GET, "/users/joe/books", c) + c.Handle(c) + assert.Equal(t, 7, c.Get("g")) + assert.Equal(t, "joe/books", c.Param("_*")) } func TestRouterParamNames(t *testing.T) { @@ -507,40 +474,34 @@ func TestRouterParamNames(t *testing.T) { r := e.router // Routes - r.Add(GET, "/users", func(c Context) error { + r.Add(GET, "/users", HandlerFunc(func(c Context) error { c.Set("path", "/users") return nil - }, e) - r.Add(GET, "/users/:id", func(c Context) error { + }), e) + r.Add(GET, "/users/:id", HandlerFunc(func(c Context) error { return nil - }, e) - r.Add(GET, "/users/:uid/files/:fid", func(c Context) error { + }), e) + r.Add(GET, "/users/:uid/files/:fid", HandlerFunc(func(c Context) error { return nil - }, e) + }), e) c := NewContext(nil, nil, e) // Route > /users - h, _ := r.Find(GET, "/users", c) - if assert.NotNil(t, h) { - h(c) - assert.Equal(t, "/users", c.Get("path")) - } + r.Find(GET, "/users", c) + c.Handle(c) + assert.Equal(t, "/users", c.Get("path")) // Route > /users/:id - h, _ = r.Find(GET, "/users/1", c) - if assert.NotNil(t, h) { - assert.Equal(t, "id", c.Object().pnames[0]) - assert.Equal(t, "1", c.P(0)) - } + r.Find(GET, "/users/1", c) + assert.Equal(t, "id", c.Object().pnames[0]) + assert.Equal(t, "1", c.P(0)) // Route > /users/:uid/files/:fid - h, _ = r.Find(GET, "/users/1/files/1", c) - if assert.NotNil(t, h) { - assert.Equal(t, "uid", c.Object().pnames[0]) - assert.Equal(t, "1", c.P(0)) - assert.Equal(t, "fid", c.Object().pnames[1]) - assert.Equal(t, "1", c.P(1)) - } + r.Find(GET, "/users/1/files/1", c) + assert.Equal(t, "uid", c.Object().pnames[0]) + assert.Equal(t, "1", c.P(0)) + assert.Equal(t, "fid", c.Object().pnames[1]) + assert.Equal(t, "1", c.P(1)) } func TestRouterAPI(t *testing.T) { @@ -548,21 +509,19 @@ func TestRouterAPI(t *testing.T) { r := e.router for _, route := range api { - r.Add(route.Method, route.Path, func(c Context) error { + r.Add(route.Method, route.Path, HandlerFunc(func(c Context) error { return nil - }, e) + }), e) } c := NewContext(nil, nil, e) for _, route := range api { - h, _ := r.Find(route.Method, route.Path, c) - if assert.NotNil(t, h) { - for i, n := range c.Object().pnames { - if assert.NotEmpty(t, n) { - assert.Equal(t, ":"+n, c.P(i)) - } + r.Find(route.Method, route.Path, c) + for i, n := range c.Object().pnames { + if assert.NotEmpty(t, n) { + assert.Equal(t, ":"+n, c.P(i)) } - h(c) } + c.Handle(c) } }