mirror of
https://github.com/labstack/echo.git
synced 2024-12-24 20:14:31 +02:00
Fix adding route with host overwrites default host route with same method+path in list of routes.
This commit is contained in:
parent
895121d178
commit
f1cf1ec930
44
echo.go
44
echo.go
@ -37,7 +37,6 @@ Learn more at https://echo.labstack.com
|
||||
package echo
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
stdContext "context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
@ -528,20 +527,13 @@ func (e *Echo) File(path, file string, m ...MiddlewareFunc) *Route {
|
||||
}
|
||||
|
||||
func (e *Echo) add(host, method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route {
|
||||
name := handlerName(handler)
|
||||
router := e.findRouter(host)
|
||||
// FIXME: when handler+middleware are both nil ... make it behave like handler removal
|
||||
router.Add(method, path, func(c Context) error {
|
||||
//FIXME: when handler+middleware are both nil ... make it behave like handler removal
|
||||
name := handlerName(handler)
|
||||
return router.add(method, path, name, func(c Context) error {
|
||||
h := applyMiddleware(handler, middleware...)
|
||||
return h(c)
|
||||
})
|
||||
r := &Route{
|
||||
Method: method,
|
||||
Path: path,
|
||||
Name: name,
|
||||
}
|
||||
e.router.routes[method+path] = r
|
||||
return r
|
||||
}
|
||||
|
||||
// Add registers a new route for an HTTP method and path with matching handler
|
||||
@ -578,35 +570,13 @@ func (e *Echo) URL(h HandlerFunc, params ...interface{}) string {
|
||||
|
||||
// Reverse generates an URL from route name and provided parameters.
|
||||
func (e *Echo) Reverse(name string, params ...interface{}) string {
|
||||
uri := new(bytes.Buffer)
|
||||
ln := len(params)
|
||||
n := 0
|
||||
for _, r := range e.router.routes {
|
||||
if r.Name == name {
|
||||
for i, l := 0, len(r.Path); i < l; i++ {
|
||||
if (r.Path[i] == ':' || r.Path[i] == '*') && n < ln {
|
||||
for ; i < l && r.Path[i] != '/'; i++ {
|
||||
}
|
||||
uri.WriteString(fmt.Sprintf("%v", params[n]))
|
||||
n++
|
||||
}
|
||||
if i < l {
|
||||
uri.WriteByte(r.Path[i])
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
return uri.String()
|
||||
return e.router.Reverse(name, params...)
|
||||
}
|
||||
|
||||
// Routes returns the registered routes.
|
||||
// Routes returns the registered routes for default router.
|
||||
// In case when Echo serves multiple hosts/domains use `e.Routers()["domain2.site"].Routes()` to get specific host routes.
|
||||
func (e *Echo) Routes() []*Route {
|
||||
routes := make([]*Route, 0, len(e.router.routes))
|
||||
for _, v := range e.router.routes {
|
||||
routes = append(routes, v)
|
||||
}
|
||||
return routes
|
||||
return e.router.Routes()
|
||||
}
|
||||
|
||||
// AcquireContext returns an empty `Context` instance from the pool.
|
||||
|
90
echo_test.go
90
echo_test.go
@ -530,9 +530,9 @@ func TestEchoRoutes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEchoRoutesHandleHostsProperly(t *testing.T) {
|
||||
func TestEchoRoutesHandleAdditionalHosts(t *testing.T) {
|
||||
e := New()
|
||||
h := e.Host("route.com")
|
||||
domain2Router := e.Host("domain2.router.com")
|
||||
routes := []*Route{
|
||||
{http.MethodGet, "/users/:user/events", ""},
|
||||
{http.MethodGet, "/users/:user/events/public", ""},
|
||||
@ -540,24 +540,61 @@ func TestEchoRoutesHandleHostsProperly(t *testing.T) {
|
||||
{http.MethodPost, "/repos/:owner/:repo/git/tags", ""},
|
||||
}
|
||||
for _, r := range routes {
|
||||
h.Add(r.Method, r.Path, func(c Context) error {
|
||||
domain2Router.Add(r.Method, r.Path, func(c Context) error {
|
||||
return c.String(http.StatusOK, "OK")
|
||||
})
|
||||
}
|
||||
e.Add(http.MethodGet, "/api", func(c Context) error {
|
||||
return c.String(http.StatusOK, "OK")
|
||||
})
|
||||
|
||||
if assert.Equal(t, len(routes), len(e.Routes())) {
|
||||
for _, r := range e.Routes() {
|
||||
found := false
|
||||
for _, rr := range routes {
|
||||
if r.Method == rr.Method && r.Path == rr.Path {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
domain2Routes := e.Routers()["domain2.router.com"].Routes()
|
||||
|
||||
assert.Len(t, domain2Routes, len(routes))
|
||||
for _, r := range domain2Routes {
|
||||
found := false
|
||||
for _, rr := range routes {
|
||||
if r.Method == rr.Method && r.Path == rr.Path {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Route %s %s not found", r.Method, r.Path)
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Route %s %s not found", r.Method, r.Path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEchoRoutesHandleDefaultHost(t *testing.T) {
|
||||
e := New()
|
||||
routes := []*Route{
|
||||
{http.MethodGet, "/users/:user/events", ""},
|
||||
{http.MethodGet, "/users/:user/events/public", ""},
|
||||
{http.MethodPost, "/repos/:owner/:repo/git/refs", ""},
|
||||
{http.MethodPost, "/repos/:owner/:repo/git/tags", ""},
|
||||
}
|
||||
for _, r := range routes {
|
||||
e.Add(r.Method, r.Path, func(c Context) error {
|
||||
return c.String(http.StatusOK, "OK")
|
||||
})
|
||||
}
|
||||
e.Host("subdomain.mysite.site").Add(http.MethodGet, "/api", func(c Context) error {
|
||||
return c.String(http.StatusOK, "OK")
|
||||
})
|
||||
|
||||
defaultRouterRoutes := e.Routes()
|
||||
assert.Len(t, defaultRouterRoutes, len(routes))
|
||||
for _, r := range defaultRouterRoutes {
|
||||
found := false
|
||||
for _, rr := range routes {
|
||||
if r.Method == rr.Method && r.Path == rr.Path {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Route %s %s not found", r.Method, r.Path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1468,14 +1505,27 @@ func TestEchoReverseHandleHostProperly(t *testing.T) {
|
||||
dummyHandler := func(Context) error { return nil }
|
||||
|
||||
e := New()
|
||||
h := e.Host("the_host")
|
||||
h.GET("/static", dummyHandler).Name = "/static"
|
||||
h.GET("/static/*", dummyHandler).Name = "/static/*"
|
||||
|
||||
assert.Equal(t, "/static", e.Reverse("/static"))
|
||||
assert.Equal(t, "/static", e.Reverse("/static", "missing param"))
|
||||
assert.Equal(t, "/static/*", e.Reverse("/static/*"))
|
||||
assert.Equal(t, "/static/foo.txt", e.Reverse("/static/*", "foo.txt"))
|
||||
// routes added to the default router are different form different hosts
|
||||
e.GET("/static", dummyHandler).Name = "default-host /static"
|
||||
e.GET("/static/*", dummyHandler).Name = "xxx"
|
||||
|
||||
// different host
|
||||
h := e.Host("the_host")
|
||||
h.GET("/static", dummyHandler).Name = "host2 /static"
|
||||
h.GET("/static/v2/*", dummyHandler).Name = "xxx"
|
||||
|
||||
assert.Equal(t, "/static", e.Reverse("default-host /static"))
|
||||
// when actual route does not have params and we provide some to Reverse we should get that route url back
|
||||
assert.Equal(t, "/static", e.Reverse("default-host /static", "missing param"))
|
||||
|
||||
host2Router := e.Routers()["the_host"]
|
||||
assert.Equal(t, "/static", host2Router.Reverse("host2 /static"))
|
||||
assert.Equal(t, "/static", host2Router.Reverse("host2 /static", "missing param"))
|
||||
|
||||
assert.Equal(t, "/static/v2/*", host2Router.Reverse("xxx"))
|
||||
assert.Equal(t, "/static/v2/foo.txt", host2Router.Reverse("xxx", "foo.txt"))
|
||||
|
||||
}
|
||||
|
||||
func TestEcho_ListenerAddr(t *testing.T) {
|
||||
|
46
router.go
46
router.go
@ -2,6 +2,7 @@ package echo
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
@ -141,6 +142,51 @@ func NewRouter(e *Echo) *Router {
|
||||
}
|
||||
}
|
||||
|
||||
// Routes returns the registered routes.
|
||||
func (r *Router) Routes() []*Route {
|
||||
routes := make([]*Route, 0, len(r.routes))
|
||||
for _, v := range r.routes {
|
||||
routes = append(routes, v)
|
||||
}
|
||||
return routes
|
||||
}
|
||||
|
||||
// Reverse generates an URL from route name and provided parameters.
|
||||
func (r *Router) Reverse(name string, params ...interface{}) string {
|
||||
uri := new(bytes.Buffer)
|
||||
ln := len(params)
|
||||
n := 0
|
||||
for _, route := range r.routes {
|
||||
if route.Name == name {
|
||||
for i, l := 0, len(route.Path); i < l; i++ {
|
||||
if (route.Path[i] == ':' || route.Path[i] == '*') && n < ln {
|
||||
for ; i < l && route.Path[i] != '/'; i++ {
|
||||
}
|
||||
uri.WriteString(fmt.Sprintf("%v", params[n]))
|
||||
n++
|
||||
}
|
||||
if i < l {
|
||||
uri.WriteByte(route.Path[i])
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
return uri.String()
|
||||
}
|
||||
|
||||
func (r *Router) add(method, path, name string, h HandlerFunc) *Route {
|
||||
r.Add(method, path, h)
|
||||
|
||||
route := &Route{
|
||||
Method: method,
|
||||
Path: path,
|
||||
Name: name,
|
||||
}
|
||||
r.routes[method+path] = route
|
||||
return route
|
||||
}
|
||||
|
||||
// Add registers a new route for method and path with matching handler.
|
||||
func (r *Router) Add(method, path string, h HandlerFunc) {
|
||||
// Validate path
|
||||
|
136
router_test.go
136
router_test.go
@ -914,19 +914,22 @@ func TestRouterParamWithSlash(t *testing.T) {
|
||||
// Searching route for "/a/c/f" should match "/a/*/f"
|
||||
// When route `4) /a/*/f` is not added then request for "/a/c/f" should match "/:e/c/f"
|
||||
//
|
||||
// +----------+
|
||||
// +-----+ "/" root +--------------------+--------------------------+
|
||||
// | +----------+ | |
|
||||
// | | |
|
||||
// +-------v-------+ +---v---------+ +-------v---+
|
||||
// | "a/" (static) +---------------+ | ":" (param) | | "*" (any) |
|
||||
// +-+----------+--+ | +-----------+-+ +-----------+
|
||||
// | | | |
|
||||
// +----------+
|
||||
// +-----+ "/" root +--------------------+--------------------------+
|
||||
// | +----------+ | |
|
||||
// | | |
|
||||
// +-------v-------+ +---v---------+ +-------v---+
|
||||
// | "a/" (static) +---------------+ | ":" (param) | | "*" (any) |
|
||||
// +-+----------+--+ | +-----------+-+ +-----------+
|
||||
// | | | |
|
||||
//
|
||||
// +---------------v+ +-- ---v------+ +------v----+ +-----v-----------+
|
||||
// | "c/d" (static) | | ":" (param) | | "*" (any) | | "/c/f" (static) |
|
||||
// +---------+------+ +--------+----+ +----------++ +-----------------+
|
||||
// | | |
|
||||
// | | |
|
||||
//
|
||||
// | | |
|
||||
// | | |
|
||||
//
|
||||
// +---------v----+ +------v--------+ +------v--------+
|
||||
// | "f" (static) | | "/c" (static) | | "/f" (static) |
|
||||
// +--------------+ +---------------+ +---------------+
|
||||
@ -998,22 +1001,22 @@ func TestRouteMultiLevelBacktracking(t *testing.T) {
|
||||
//
|
||||
// Request for "/a/c/f" should match "/:e/c/f"
|
||||
//
|
||||
// +-0,7--------+
|
||||
// | "/" (root) |----------------------------------+
|
||||
// +------------+ |
|
||||
// | | |
|
||||
// | | |
|
||||
// +-1,6-----------+ | | +-8-----------+ +------v----+
|
||||
// | "a/" (static) +<--+ +--------->+ ":" (param) | | "*" (any) |
|
||||
// +---------------+ +-------------+ +-----------+
|
||||
// | | |
|
||||
// +-2--------v-----+ +v-3,5--------+ +-9------v--------+
|
||||
// | "c/d" (static) | | ":" (param) | | "/c/f" (static) |
|
||||
// +----------------+ +-------------+ +-----------------+
|
||||
// |
|
||||
// +-4--v----------+
|
||||
// | "/c" (static) |
|
||||
// +---------------+
|
||||
// +-0,7--------+
|
||||
// | "/" (root) |----------------------------------+
|
||||
// +------------+ |
|
||||
// | | |
|
||||
// | | |
|
||||
// +-1,6-----------+ | | +-8-----------+ +------v----+
|
||||
// | "a/" (static) +<--+ +--------->+ ":" (param) | | "*" (any) |
|
||||
// +---------------+ +-------------+ +-----------+
|
||||
// | | |
|
||||
// +-2--------v-----+ +v-3,5--------+ +-9------v--------+
|
||||
// | "c/d" (static) | | ":" (param) | | "/c/f" (static) |
|
||||
// +----------------+ +-------------+ +-----------------+
|
||||
// |
|
||||
// +-4--v----------+
|
||||
// | "/c" (static) |
|
||||
// +---------------+
|
||||
func TestRouteMultiLevelBacktracking2(t *testing.T) {
|
||||
e := New()
|
||||
r := e.router
|
||||
@ -2695,6 +2698,87 @@ func TestRouterHandleMethodOptions(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_Routes(t *testing.T) {
|
||||
type rr struct {
|
||||
method string
|
||||
path string
|
||||
name string
|
||||
}
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenRoutes []rr
|
||||
expect []rr
|
||||
}{
|
||||
{
|
||||
name: "ok, multiple",
|
||||
givenRoutes: []rr{
|
||||
{method: http.MethodGet, path: "/static", name: "/static"},
|
||||
{method: http.MethodGet, path: "/static/*", name: "/static/*"},
|
||||
},
|
||||
expect: []rr{
|
||||
{method: http.MethodGet, path: "/static", name: "/static"},
|
||||
{method: http.MethodGet, path: "/static/*", name: "/static/*"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ok, no routes",
|
||||
givenRoutes: []rr{},
|
||||
expect: []rr{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
dummyHandler := func(Context) error { return nil }
|
||||
|
||||
e := New()
|
||||
route := e.router
|
||||
|
||||
for _, tmp := range tc.givenRoutes {
|
||||
route.add(tmp.method, tmp.path, tmp.name, dummyHandler)
|
||||
}
|
||||
|
||||
// Add does not add route. because of backwards compatibility we can not change this method signature
|
||||
route.Add("LOCK", "/users", handlerFunc)
|
||||
|
||||
result := route.Routes()
|
||||
assert.Len(t, result, len(tc.expect))
|
||||
for _, r := range result {
|
||||
for _, tmp := range tc.expect {
|
||||
if tmp.name == r.Name {
|
||||
assert.Equal(t, tmp.method, r.Method)
|
||||
assert.Equal(t, tmp.path, r.Path)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_Reverse(t *testing.T) {
|
||||
e := New()
|
||||
r := e.router
|
||||
dummyHandler := func(Context) error { return nil }
|
||||
|
||||
r.add(http.MethodGet, "/static", "/static", dummyHandler)
|
||||
r.add(http.MethodGet, "/static/*", "/static/*", dummyHandler)
|
||||
r.add(http.MethodGet, "/params/:foo", "/params/:foo", dummyHandler)
|
||||
r.add(http.MethodGet, "/params/:foo/bar/:qux", "/params/:foo/bar/:qux", dummyHandler)
|
||||
r.add(http.MethodGet, "/params/:foo/bar/:qux/*", "/params/:foo/bar/:qux/*", dummyHandler)
|
||||
|
||||
assert.Equal(t, "/static", r.Reverse("/static"))
|
||||
assert.Equal(t, "/static", r.Reverse("/static", "missing param"))
|
||||
assert.Equal(t, "/static/*", r.Reverse("/static/*"))
|
||||
assert.Equal(t, "/static/foo.txt", r.Reverse("/static/*", "foo.txt"))
|
||||
|
||||
assert.Equal(t, "/params/:foo", r.Reverse("/params/:foo"))
|
||||
assert.Equal(t, "/params/one", r.Reverse("/params/:foo", "one"))
|
||||
assert.Equal(t, "/params/:foo/bar/:qux", r.Reverse("/params/:foo/bar/:qux"))
|
||||
assert.Equal(t, "/params/one/bar/:qux", r.Reverse("/params/:foo/bar/:qux", "one"))
|
||||
assert.Equal(t, "/params/one/bar/two", r.Reverse("/params/:foo/bar/:qux", "one", "two"))
|
||||
assert.Equal(t, "/params/one/bar/two/three", r.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three"))
|
||||
}
|
||||
|
||||
func TestRouterAllowHeaderForAnyOtherMethodType(t *testing.T) {
|
||||
e := New()
|
||||
r := e.router
|
||||
|
Loading…
Reference in New Issue
Block a user