1
0
mirror of https://github.com/labstack/echo.git synced 2025-01-12 01:22:21 +02:00

Fix adding route with host overwrites default host route with same method+path in list of routes.

This commit is contained in:
toimtoimtoim 2022-11-12 18:35:19 +02:00 committed by Martti T
parent 895121d178
commit f1cf1ec930
4 changed files with 233 additions and 83 deletions

42
echo.go
View File

@ -37,7 +37,6 @@ Learn more at https://echo.labstack.com
package echo package echo
import ( import (
"bytes"
stdContext "context" stdContext "context"
"crypto/tls" "crypto/tls"
"errors" "errors"
@ -528,20 +527,13 @@ func (e *Echo) File(path, file string, m ...MiddlewareFunc) *Route {
} }
func (e *Echo) add(host, method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route { func (e *Echo) add(host, method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route {
name := handlerName(handler)
router := e.findRouter(host) router := e.findRouter(host)
//FIXME: when handler+middleware are both nil ... make it behave like handler removal //FIXME: when handler+middleware are both nil ... make it behave like handler removal
router.Add(method, path, func(c Context) error { name := handlerName(handler)
return router.add(method, path, name, func(c Context) error {
h := applyMiddleware(handler, middleware...) h := applyMiddleware(handler, middleware...)
return h(c) return h(c)
}) })
r := &Route{
Method: method,
Path: path,
Name: name,
}
e.router.routes[method+path] = r
return r
} }
// Add registers a new route for an HTTP method and path with matching handler // Add registers a new route for an HTTP method and path with matching handler
@ -578,35 +570,13 @@ func (e *Echo) URL(h HandlerFunc, params ...interface{}) string {
// Reverse generates an URL from route name and provided parameters. // Reverse generates an URL from route name and provided parameters.
func (e *Echo) Reverse(name string, params ...interface{}) string { func (e *Echo) Reverse(name string, params ...interface{}) string {
uri := new(bytes.Buffer) return e.router.Reverse(name, params...)
ln := len(params)
n := 0
for _, r := range e.router.routes {
if r.Name == name {
for i, l := 0, len(r.Path); i < l; i++ {
if (r.Path[i] == ':' || r.Path[i] == '*') && n < ln {
for ; i < l && r.Path[i] != '/'; i++ {
}
uri.WriteString(fmt.Sprintf("%v", params[n]))
n++
}
if i < l {
uri.WriteByte(r.Path[i])
}
}
break
}
}
return uri.String()
} }
// Routes returns the registered routes. // Routes returns the registered routes for default router.
// In case when Echo serves multiple hosts/domains use `e.Routers()["domain2.site"].Routes()` to get specific host routes.
func (e *Echo) Routes() []*Route { func (e *Echo) Routes() []*Route {
routes := make([]*Route, 0, len(e.router.routes)) return e.router.Routes()
for _, v := range e.router.routes {
routes = append(routes, v)
}
return routes
} }
// AcquireContext returns an empty `Context` instance from the pool. // AcquireContext returns an empty `Context` instance from the pool.

View File

@ -530,9 +530,9 @@ func TestEchoRoutes(t *testing.T) {
} }
} }
func TestEchoRoutesHandleHostsProperly(t *testing.T) { func TestEchoRoutesHandleAdditionalHosts(t *testing.T) {
e := New() e := New()
h := e.Host("route.com") domain2Router := e.Host("domain2.router.com")
routes := []*Route{ routes := []*Route{
{http.MethodGet, "/users/:user/events", ""}, {http.MethodGet, "/users/:user/events", ""},
{http.MethodGet, "/users/:user/events/public", ""}, {http.MethodGet, "/users/:user/events/public", ""},
@ -540,13 +540,18 @@ func TestEchoRoutesHandleHostsProperly(t *testing.T) {
{http.MethodPost, "/repos/:owner/:repo/git/tags", ""}, {http.MethodPost, "/repos/:owner/:repo/git/tags", ""},
} }
for _, r := range routes { for _, r := range routes {
h.Add(r.Method, r.Path, func(c Context) error { domain2Router.Add(r.Method, r.Path, func(c Context) error {
return c.String(http.StatusOK, "OK") return c.String(http.StatusOK, "OK")
}) })
} }
e.Add(http.MethodGet, "/api", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
if assert.Equal(t, len(routes), len(e.Routes())) { domain2Routes := e.Routers()["domain2.router.com"].Routes()
for _, r := range e.Routes() {
assert.Len(t, domain2Routes, len(routes))
for _, r := range domain2Routes {
found := false found := false
for _, rr := range routes { for _, rr := range routes {
if r.Method == rr.Method && r.Path == rr.Path { if r.Method == rr.Method && r.Path == rr.Path {
@ -559,6 +564,38 @@ func TestEchoRoutesHandleHostsProperly(t *testing.T) {
} }
} }
} }
func TestEchoRoutesHandleDefaultHost(t *testing.T) {
e := New()
routes := []*Route{
{http.MethodGet, "/users/:user/events", ""},
{http.MethodGet, "/users/:user/events/public", ""},
{http.MethodPost, "/repos/:owner/:repo/git/refs", ""},
{http.MethodPost, "/repos/:owner/:repo/git/tags", ""},
}
for _, r := range routes {
e.Add(r.Method, r.Path, func(c Context) error {
return c.String(http.StatusOK, "OK")
})
}
e.Host("subdomain.mysite.site").Add(http.MethodGet, "/api", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
defaultRouterRoutes := e.Routes()
assert.Len(t, defaultRouterRoutes, len(routes))
for _, r := range defaultRouterRoutes {
found := false
for _, rr := range routes {
if r.Method == rr.Method && r.Path == rr.Path {
found = true
break
}
}
if !found {
t.Errorf("Route %s %s not found", r.Method, r.Path)
}
}
} }
func TestEchoServeHTTPPathEncoding(t *testing.T) { func TestEchoServeHTTPPathEncoding(t *testing.T) {
@ -1468,14 +1505,27 @@ func TestEchoReverseHandleHostProperly(t *testing.T) {
dummyHandler := func(Context) error { return nil } dummyHandler := func(Context) error { return nil }
e := New() e := New()
h := e.Host("the_host")
h.GET("/static", dummyHandler).Name = "/static"
h.GET("/static/*", dummyHandler).Name = "/static/*"
assert.Equal(t, "/static", e.Reverse("/static")) // routes added to the default router are different form different hosts
assert.Equal(t, "/static", e.Reverse("/static", "missing param")) e.GET("/static", dummyHandler).Name = "default-host /static"
assert.Equal(t, "/static/*", e.Reverse("/static/*")) e.GET("/static/*", dummyHandler).Name = "xxx"
assert.Equal(t, "/static/foo.txt", e.Reverse("/static/*", "foo.txt"))
// different host
h := e.Host("the_host")
h.GET("/static", dummyHandler).Name = "host2 /static"
h.GET("/static/v2/*", dummyHandler).Name = "xxx"
assert.Equal(t, "/static", e.Reverse("default-host /static"))
// when actual route does not have params and we provide some to Reverse we should get that route url back
assert.Equal(t, "/static", e.Reverse("default-host /static", "missing param"))
host2Router := e.Routers()["the_host"]
assert.Equal(t, "/static", host2Router.Reverse("host2 /static"))
assert.Equal(t, "/static", host2Router.Reverse("host2 /static", "missing param"))
assert.Equal(t, "/static/v2/*", host2Router.Reverse("xxx"))
assert.Equal(t, "/static/v2/foo.txt", host2Router.Reverse("xxx", "foo.txt"))
} }
func TestEcho_ListenerAddr(t *testing.T) { func TestEcho_ListenerAddr(t *testing.T) {

View File

@ -2,6 +2,7 @@ package echo
import ( import (
"bytes" "bytes"
"fmt"
"net/http" "net/http"
) )
@ -141,6 +142,51 @@ func NewRouter(e *Echo) *Router {
} }
} }
// Routes returns the registered routes.
func (r *Router) Routes() []*Route {
routes := make([]*Route, 0, len(r.routes))
for _, v := range r.routes {
routes = append(routes, v)
}
return routes
}
// Reverse generates an URL from route name and provided parameters.
func (r *Router) Reverse(name string, params ...interface{}) string {
uri := new(bytes.Buffer)
ln := len(params)
n := 0
for _, route := range r.routes {
if route.Name == name {
for i, l := 0, len(route.Path); i < l; i++ {
if (route.Path[i] == ':' || route.Path[i] == '*') && n < ln {
for ; i < l && route.Path[i] != '/'; i++ {
}
uri.WriteString(fmt.Sprintf("%v", params[n]))
n++
}
if i < l {
uri.WriteByte(route.Path[i])
}
}
break
}
}
return uri.String()
}
func (r *Router) add(method, path, name string, h HandlerFunc) *Route {
r.Add(method, path, h)
route := &Route{
Method: method,
Path: path,
Name: name,
}
r.routes[method+path] = route
return route
}
// Add registers a new route for method and path with matching handler. // Add registers a new route for method and path with matching handler.
func (r *Router) Add(method, path string, h HandlerFunc) { func (r *Router) Add(method, path string, h HandlerFunc) {
// Validate path // Validate path

View File

@ -922,11 +922,14 @@ func TestRouterParamWithSlash(t *testing.T) {
// | "a/" (static) +---------------+ | ":" (param) | | "*" (any) | // | "a/" (static) +---------------+ | ":" (param) | | "*" (any) |
// +-+----------+--+ | +-----------+-+ +-----------+ // +-+----------+--+ | +-----------+-+ +-----------+
// | | | | // | | | |
//
// +---------------v+ +-- ---v------+ +------v----+ +-----v-----------+ // +---------------v+ +-- ---v------+ +------v----+ +-----v-----------+
// | "c/d" (static) | | ":" (param) | | "*" (any) | | "/c/f" (static) | // | "c/d" (static) | | ":" (param) | | "*" (any) | | "/c/f" (static) |
// +---------+------+ +--------+----+ +----------++ +-----------------+ // +---------+------+ +--------+----+ +----------++ +-----------------+
//
// | | | // | | |
// | | | // | | |
//
// +---------v----+ +------v--------+ +------v--------+ // +---------v----+ +------v--------+ +------v--------+
// | "f" (static) | | "/c" (static) | | "/f" (static) | // | "f" (static) | | "/c" (static) | | "/f" (static) |
// +--------------+ +---------------+ +---------------+ // +--------------+ +---------------+ +---------------+
@ -2695,6 +2698,87 @@ func TestRouterHandleMethodOptions(t *testing.T) {
} }
} }
func TestRouter_Routes(t *testing.T) {
type rr struct {
method string
path string
name string
}
var testCases = []struct {
name string
givenRoutes []rr
expect []rr
}{
{
name: "ok, multiple",
givenRoutes: []rr{
{method: http.MethodGet, path: "/static", name: "/static"},
{method: http.MethodGet, path: "/static/*", name: "/static/*"},
},
expect: []rr{
{method: http.MethodGet, path: "/static", name: "/static"},
{method: http.MethodGet, path: "/static/*", name: "/static/*"},
},
},
{
name: "ok, no routes",
givenRoutes: []rr{},
expect: []rr{},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
dummyHandler := func(Context) error { return nil }
e := New()
route := e.router
for _, tmp := range tc.givenRoutes {
route.add(tmp.method, tmp.path, tmp.name, dummyHandler)
}
// Add does not add route. because of backwards compatibility we can not change this method signature
route.Add("LOCK", "/users", handlerFunc)
result := route.Routes()
assert.Len(t, result, len(tc.expect))
for _, r := range result {
for _, tmp := range tc.expect {
if tmp.name == r.Name {
assert.Equal(t, tmp.method, r.Method)
assert.Equal(t, tmp.path, r.Path)
}
}
}
})
}
}
func TestRouter_Reverse(t *testing.T) {
e := New()
r := e.router
dummyHandler := func(Context) error { return nil }
r.add(http.MethodGet, "/static", "/static", dummyHandler)
r.add(http.MethodGet, "/static/*", "/static/*", dummyHandler)
r.add(http.MethodGet, "/params/:foo", "/params/:foo", dummyHandler)
r.add(http.MethodGet, "/params/:foo/bar/:qux", "/params/:foo/bar/:qux", dummyHandler)
r.add(http.MethodGet, "/params/:foo/bar/:qux/*", "/params/:foo/bar/:qux/*", dummyHandler)
assert.Equal(t, "/static", r.Reverse("/static"))
assert.Equal(t, "/static", r.Reverse("/static", "missing param"))
assert.Equal(t, "/static/*", r.Reverse("/static/*"))
assert.Equal(t, "/static/foo.txt", r.Reverse("/static/*", "foo.txt"))
assert.Equal(t, "/params/:foo", r.Reverse("/params/:foo"))
assert.Equal(t, "/params/one", r.Reverse("/params/:foo", "one"))
assert.Equal(t, "/params/:foo/bar/:qux", r.Reverse("/params/:foo/bar/:qux"))
assert.Equal(t, "/params/one/bar/:qux", r.Reverse("/params/:foo/bar/:qux", "one"))
assert.Equal(t, "/params/one/bar/two", r.Reverse("/params/:foo/bar/:qux", "one", "two"))
assert.Equal(t, "/params/one/bar/two/three", r.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three"))
}
func TestRouterAllowHeaderForAnyOtherMethodType(t *testing.T) { func TestRouterAllowHeaderForAnyOtherMethodType(t *testing.T) {
e := New() e := New()
r := e.router r := e.router