From f1cf1ec930e388333798c133629afa18dad00241 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sat, 12 Nov 2022 18:35:19 +0200 Subject: [PATCH] Fix adding route with host overwrites default host route with same method+path in list of routes. --- echo.go | 44 +++------------- echo_test.go | 90 ++++++++++++++++++++++++-------- router.go | 46 +++++++++++++++++ router_test.go | 136 +++++++++++++++++++++++++++++++++++++++---------- 4 files changed, 233 insertions(+), 83 deletions(-) diff --git a/echo.go b/echo.go index 2f54d771..fc2f556e 100644 --- a/echo.go +++ b/echo.go @@ -37,7 +37,6 @@ Learn more at https://echo.labstack.com package echo import ( - "bytes" stdContext "context" "crypto/tls" "errors" @@ -528,20 +527,13 @@ func (e *Echo) File(path, file string, m ...MiddlewareFunc) *Route { } func (e *Echo) add(host, method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route { - name := handlerName(handler) router := e.findRouter(host) - // FIXME: when handler+middleware are both nil ... make it behave like handler removal - router.Add(method, path, func(c Context) error { + //FIXME: when handler+middleware are both nil ... make it behave like handler removal + name := handlerName(handler) + return router.add(method, path, name, func(c Context) error { h := applyMiddleware(handler, middleware...) return h(c) }) - r := &Route{ - Method: method, - Path: path, - Name: name, - } - e.router.routes[method+path] = r - return r } // Add registers a new route for an HTTP method and path with matching handler @@ -578,35 +570,13 @@ func (e *Echo) URL(h HandlerFunc, params ...interface{}) string { // Reverse generates an URL from route name and provided parameters. func (e *Echo) Reverse(name string, params ...interface{}) string { - uri := new(bytes.Buffer) - ln := len(params) - n := 0 - for _, r := range e.router.routes { - if r.Name == name { - for i, l := 0, len(r.Path); i < l; i++ { - if (r.Path[i] == ':' || r.Path[i] == '*') && n < ln { - for ; i < l && r.Path[i] != '/'; i++ { - } - uri.WriteString(fmt.Sprintf("%v", params[n])) - n++ - } - if i < l { - uri.WriteByte(r.Path[i]) - } - } - break - } - } - return uri.String() + return e.router.Reverse(name, params...) } -// Routes returns the registered routes. +// Routes returns the registered routes for default router. +// In case when Echo serves multiple hosts/domains use `e.Routers()["domain2.site"].Routes()` to get specific host routes. func (e *Echo) Routes() []*Route { - routes := make([]*Route, 0, len(e.router.routes)) - for _, v := range e.router.routes { - routes = append(routes, v) - } - return routes + return e.router.Routes() } // AcquireContext returns an empty `Context` instance from the pool. diff --git a/echo_test.go b/echo_test.go index b0d1ccd2..25039692 100644 --- a/echo_test.go +++ b/echo_test.go @@ -530,9 +530,9 @@ func TestEchoRoutes(t *testing.T) { } } -func TestEchoRoutesHandleHostsProperly(t *testing.T) { +func TestEchoRoutesHandleAdditionalHosts(t *testing.T) { e := New() - h := e.Host("route.com") + domain2Router := e.Host("domain2.router.com") routes := []*Route{ {http.MethodGet, "/users/:user/events", ""}, {http.MethodGet, "/users/:user/events/public", ""}, @@ -540,24 +540,61 @@ func TestEchoRoutesHandleHostsProperly(t *testing.T) { {http.MethodPost, "/repos/:owner/:repo/git/tags", ""}, } for _, r := range routes { - h.Add(r.Method, r.Path, func(c Context) error { + domain2Router.Add(r.Method, r.Path, func(c Context) error { return c.String(http.StatusOK, "OK") }) } + e.Add(http.MethodGet, "/api", func(c Context) error { + return c.String(http.StatusOK, "OK") + }) - if assert.Equal(t, len(routes), len(e.Routes())) { - for _, r := range e.Routes() { - found := false - for _, rr := range routes { - if r.Method == rr.Method && r.Path == rr.Path { - found = true - break - } + domain2Routes := e.Routers()["domain2.router.com"].Routes() + + assert.Len(t, domain2Routes, len(routes)) + for _, r := range domain2Routes { + found := false + for _, rr := range routes { + if r.Method == rr.Method && r.Path == rr.Path { + found = true + break } - if !found { - t.Errorf("Route %s %s not found", r.Method, r.Path) + } + if !found { + t.Errorf("Route %s %s not found", r.Method, r.Path) + } + } +} + +func TestEchoRoutesHandleDefaultHost(t *testing.T) { + e := New() + routes := []*Route{ + {http.MethodGet, "/users/:user/events", ""}, + {http.MethodGet, "/users/:user/events/public", ""}, + {http.MethodPost, "/repos/:owner/:repo/git/refs", ""}, + {http.MethodPost, "/repos/:owner/:repo/git/tags", ""}, + } + for _, r := range routes { + e.Add(r.Method, r.Path, func(c Context) error { + return c.String(http.StatusOK, "OK") + }) + } + e.Host("subdomain.mysite.site").Add(http.MethodGet, "/api", func(c Context) error { + return c.String(http.StatusOK, "OK") + }) + + defaultRouterRoutes := e.Routes() + assert.Len(t, defaultRouterRoutes, len(routes)) + for _, r := range defaultRouterRoutes { + found := false + for _, rr := range routes { + if r.Method == rr.Method && r.Path == rr.Path { + found = true + break } } + if !found { + t.Errorf("Route %s %s not found", r.Method, r.Path) + } } } @@ -1468,14 +1505,27 @@ func TestEchoReverseHandleHostProperly(t *testing.T) { dummyHandler := func(Context) error { return nil } e := New() - h := e.Host("the_host") - h.GET("/static", dummyHandler).Name = "/static" - h.GET("/static/*", dummyHandler).Name = "/static/*" - assert.Equal(t, "/static", e.Reverse("/static")) - assert.Equal(t, "/static", e.Reverse("/static", "missing param")) - assert.Equal(t, "/static/*", e.Reverse("/static/*")) - assert.Equal(t, "/static/foo.txt", e.Reverse("/static/*", "foo.txt")) + // routes added to the default router are different form different hosts + e.GET("/static", dummyHandler).Name = "default-host /static" + e.GET("/static/*", dummyHandler).Name = "xxx" + + // different host + h := e.Host("the_host") + h.GET("/static", dummyHandler).Name = "host2 /static" + h.GET("/static/v2/*", dummyHandler).Name = "xxx" + + assert.Equal(t, "/static", e.Reverse("default-host /static")) + // when actual route does not have params and we provide some to Reverse we should get that route url back + assert.Equal(t, "/static", e.Reverse("default-host /static", "missing param")) + + host2Router := e.Routers()["the_host"] + assert.Equal(t, "/static", host2Router.Reverse("host2 /static")) + assert.Equal(t, "/static", host2Router.Reverse("host2 /static", "missing param")) + + assert.Equal(t, "/static/v2/*", host2Router.Reverse("xxx")) + assert.Equal(t, "/static/v2/foo.txt", host2Router.Reverse("xxx", "foo.txt")) + } func TestEcho_ListenerAddr(t *testing.T) { diff --git a/router.go b/router.go index 23c5bd3b..86a986a2 100644 --- a/router.go +++ b/router.go @@ -2,6 +2,7 @@ package echo import ( "bytes" + "fmt" "net/http" ) @@ -141,6 +142,51 @@ func NewRouter(e *Echo) *Router { } } +// Routes returns the registered routes. +func (r *Router) Routes() []*Route { + routes := make([]*Route, 0, len(r.routes)) + for _, v := range r.routes { + routes = append(routes, v) + } + return routes +} + +// Reverse generates an URL from route name and provided parameters. +func (r *Router) Reverse(name string, params ...interface{}) string { + uri := new(bytes.Buffer) + ln := len(params) + n := 0 + for _, route := range r.routes { + if route.Name == name { + for i, l := 0, len(route.Path); i < l; i++ { + if (route.Path[i] == ':' || route.Path[i] == '*') && n < ln { + for ; i < l && route.Path[i] != '/'; i++ { + } + uri.WriteString(fmt.Sprintf("%v", params[n])) + n++ + } + if i < l { + uri.WriteByte(route.Path[i]) + } + } + break + } + } + return uri.String() +} + +func (r *Router) add(method, path, name string, h HandlerFunc) *Route { + r.Add(method, path, h) + + route := &Route{ + Method: method, + Path: path, + Name: name, + } + r.routes[method+path] = route + return route +} + // Add registers a new route for method and path with matching handler. func (r *Router) Add(method, path string, h HandlerFunc) { // Validate path diff --git a/router_test.go b/router_test.go index a9542101..825170a3 100644 --- a/router_test.go +++ b/router_test.go @@ -914,19 +914,22 @@ func TestRouterParamWithSlash(t *testing.T) { // Searching route for "/a/c/f" should match "/a/*/f" // When route `4) /a/*/f` is not added then request for "/a/c/f" should match "/:e/c/f" // -// +----------+ -// +-----+ "/" root +--------------------+--------------------------+ -// | +----------+ | | -// | | | -// +-------v-------+ +---v---------+ +-------v---+ -// | "a/" (static) +---------------+ | ":" (param) | | "*" (any) | -// +-+----------+--+ | +-----------+-+ +-----------+ -// | | | | +// +----------+ +// +-----+ "/" root +--------------------+--------------------------+ +// | +----------+ | | +// | | | +// +-------v-------+ +---v---------+ +-------v---+ +// | "a/" (static) +---------------+ | ":" (param) | | "*" (any) | +// +-+----------+--+ | +-----------+-+ +-----------+ +// | | | | +// // +---------------v+ +-- ---v------+ +------v----+ +-----v-----------+ // | "c/d" (static) | | ":" (param) | | "*" (any) | | "/c/f" (static) | // +---------+------+ +--------+----+ +----------++ +-----------------+ -// | | | -// | | | +// +// | | | +// | | | +// // +---------v----+ +------v--------+ +------v--------+ // | "f" (static) | | "/c" (static) | | "/f" (static) | // +--------------+ +---------------+ +---------------+ @@ -998,22 +1001,22 @@ func TestRouteMultiLevelBacktracking(t *testing.T) { // // Request for "/a/c/f" should match "/:e/c/f" // -// +-0,7--------+ -// | "/" (root) |----------------------------------+ -// +------------+ | -// | | | -// | | | -// +-1,6-----------+ | | +-8-----------+ +------v----+ -// | "a/" (static) +<--+ +--------->+ ":" (param) | | "*" (any) | -// +---------------+ +-------------+ +-----------+ -// | | | -// +-2--------v-----+ +v-3,5--------+ +-9------v--------+ -// | "c/d" (static) | | ":" (param) | | "/c/f" (static) | -// +----------------+ +-------------+ +-----------------+ -// | -// +-4--v----------+ -// | "/c" (static) | -// +---------------+ +// +-0,7--------+ +// | "/" (root) |----------------------------------+ +// +------------+ | +// | | | +// | | | +// +-1,6-----------+ | | +-8-----------+ +------v----+ +// | "a/" (static) +<--+ +--------->+ ":" (param) | | "*" (any) | +// +---------------+ +-------------+ +-----------+ +// | | | +// +-2--------v-----+ +v-3,5--------+ +-9------v--------+ +// | "c/d" (static) | | ":" (param) | | "/c/f" (static) | +// +----------------+ +-------------+ +-----------------+ +// | +// +-4--v----------+ +// | "/c" (static) | +// +---------------+ func TestRouteMultiLevelBacktracking2(t *testing.T) { e := New() r := e.router @@ -2695,6 +2698,87 @@ func TestRouterHandleMethodOptions(t *testing.T) { } } +func TestRouter_Routes(t *testing.T) { + type rr struct { + method string + path string + name string + } + var testCases = []struct { + name string + givenRoutes []rr + expect []rr + }{ + { + name: "ok, multiple", + givenRoutes: []rr{ + {method: http.MethodGet, path: "/static", name: "/static"}, + {method: http.MethodGet, path: "/static/*", name: "/static/*"}, + }, + expect: []rr{ + {method: http.MethodGet, path: "/static", name: "/static"}, + {method: http.MethodGet, path: "/static/*", name: "/static/*"}, + }, + }, + { + name: "ok, no routes", + givenRoutes: []rr{}, + expect: []rr{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + dummyHandler := func(Context) error { return nil } + + e := New() + route := e.router + + for _, tmp := range tc.givenRoutes { + route.add(tmp.method, tmp.path, tmp.name, dummyHandler) + } + + // Add does not add route. because of backwards compatibility we can not change this method signature + route.Add("LOCK", "/users", handlerFunc) + + result := route.Routes() + assert.Len(t, result, len(tc.expect)) + for _, r := range result { + for _, tmp := range tc.expect { + if tmp.name == r.Name { + assert.Equal(t, tmp.method, r.Method) + assert.Equal(t, tmp.path, r.Path) + } + } + } + }) + } +} + +func TestRouter_Reverse(t *testing.T) { + e := New() + r := e.router + dummyHandler := func(Context) error { return nil } + + r.add(http.MethodGet, "/static", "/static", dummyHandler) + r.add(http.MethodGet, "/static/*", "/static/*", dummyHandler) + r.add(http.MethodGet, "/params/:foo", "/params/:foo", dummyHandler) + r.add(http.MethodGet, "/params/:foo/bar/:qux", "/params/:foo/bar/:qux", dummyHandler) + r.add(http.MethodGet, "/params/:foo/bar/:qux/*", "/params/:foo/bar/:qux/*", dummyHandler) + + assert.Equal(t, "/static", r.Reverse("/static")) + assert.Equal(t, "/static", r.Reverse("/static", "missing param")) + assert.Equal(t, "/static/*", r.Reverse("/static/*")) + assert.Equal(t, "/static/foo.txt", r.Reverse("/static/*", "foo.txt")) + + assert.Equal(t, "/params/:foo", r.Reverse("/params/:foo")) + assert.Equal(t, "/params/one", r.Reverse("/params/:foo", "one")) + assert.Equal(t, "/params/:foo/bar/:qux", r.Reverse("/params/:foo/bar/:qux")) + assert.Equal(t, "/params/one/bar/:qux", r.Reverse("/params/:foo/bar/:qux", "one")) + assert.Equal(t, "/params/one/bar/two", r.Reverse("/params/:foo/bar/:qux", "one", "two")) + assert.Equal(t, "/params/one/bar/two/three", r.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three")) +} + func TestRouterAllowHeaderForAnyOtherMethodType(t *testing.T) { e := New() r := e.router