diff --git a/router.go b/router.go index 5c32bb1a..c24fc35b 100644 --- a/router.go +++ b/router.go @@ -10,8 +10,9 @@ type ( node struct { label byte prefix string - handler HandlerFunc + parent *node edges edges + handler HandlerFunc echo *Echo } edges []*node @@ -74,7 +75,7 @@ func (r *router) insert(method, path string, h HandlerFunc, echo *Echo) { } } else if l < pl { // Split node - n := newNode(cn.prefix[l:], cn.handler, cn.edges, cn.echo) + n := newNode(cn.prefix[l:], cn, cn.edges, cn.handler, cn.echo) cn.edges = edges{n} // Add to parent // Reset parent node @@ -89,7 +90,7 @@ func (r *router) insert(method, path string, h HandlerFunc, echo *Echo) { cn.echo = echo } else { // Create child node - n = newNode(search[l:], h, edges{}, echo) + n = newNode(search[l:], cn, edges{}, h, echo) cn.edges = append(cn.edges, n) } } else if l < sl { @@ -101,7 +102,7 @@ func (r *router) insert(method, path string, h HandlerFunc, echo *Echo) { continue } // Create child node - n := newNode(search, h, edges{}, echo) + n := newNode(search, cn, edges{}, h, echo) cn.edges = append(cn.edges, n) } else { // Node already exists @@ -114,12 +115,13 @@ func (r *router) insert(method, path string, h HandlerFunc, echo *Echo) { } } -func newNode(pfx string, h HandlerFunc, e edges, echo *Echo) (n *node) { +func newNode(pfx string, p *node, e edges, h HandlerFunc, echo *Echo) (n *node) { n = &node{ label: pfx[0], prefix: pfx, - handler: h, + parent: p, edges: e, + handler: h, echo: echo, } return @@ -152,6 +154,7 @@ func (r *router) Find(method, path string, params Params) (h HandlerFunc, echo * n := 0 // Param count // Search order static > param > catch-all + // TODO: do we need continue??? for { if search == "" || search == cn.prefix { // Fix me // Found @@ -174,15 +177,15 @@ func (r *router) Find(method, path string, params Params) (h HandlerFunc, echo * } // Param node + param: e = cn.findEdge(':') if e != nil { cn = e i, l := 0, len(search) for ; i < l && search[i] != '/'; i++ { } - p := params[:n+1] - p[n].Name = cn.prefix[1:] - p[n].Value = search[:i] + params[n].Name = cn.prefix[1:] + params[n].Value = search[:i] n++ search = search[i:] continue @@ -199,8 +202,13 @@ func (r *router) Find(method, path string, params Params) (h HandlerFunc, echo * continue } - // Not found - return + cn = cn.parent + if cn == nil { + // Not found + return + } + // Search backwards + goto param } } diff --git a/router_test.go b/router_test.go index 09a32be9..aa757558 100644 --- a/router_test.go +++ b/router_test.go @@ -1,6 +1,7 @@ package echo import ( + "bytes" "fmt" "net/http" "net/http/httptest" @@ -280,91 +281,142 @@ var ( func TestRouterStatic(t *testing.T) { r := New().Router - r.Add(GET, "/folders/files/echo.gif", func(*Context) {}, nil) - h, _ := r.Find(GET, "/folders/files/echo.gif", params) + b := new(bytes.Buffer) + path := "/folders/a/files/echo.gif" + r.Add(GET, path, func(*Context) { + b.WriteString(path) + }, nil) + h, _ := r.Find(GET, path, params) if h == nil { - t.Fatal("handle not found") + t.Fatal("handler not found") + } + h(nil) + if b.String() != path { + t.Errorf("buffer should %s", path) } } func TestRouterParam(t *testing.T) { r := New().Router - r.Add(GET, "/users/:id", func(c *Context) { - if c.P(0) != "1" { - t.Error("param id should be 1") - } - }, nil) - h, _ := r.Find(GET, "/users/1", make(Params, 5)) + r.Add(GET, "/users/:id", func(c *Context) {}, nil) + + h, _ := r.Find(GET, "/users/1", params) if h == nil { - t.Fatal("handle not found") + t.Fatal("handler not found") + } + if params[0].Value != "1" { + t.Error("param id should be 1") } } func TestRouterTwoParam(t *testing.T) { r := New().Router - r.Add(GET, "/users/:uid/files/:fid", func(c *Context) { - if c.P(0) != "1" { - t.Error("param uid should be 1") - } - if c.P(1) != "1" { - t.Error("param fid should be 1") - } - }, nil) + r.Add(GET, "/users/:uid/files/:fid", func(*Context) {}, nil) h, _ := r.Find(GET, "/users/1/files/1", params) if h == nil { - t.Fatal("handle not found") + t.Fatal("handler not found") + } + if params[0].Value != "1" { + t.Error("param uid should be 1") + } + if params[1].Value != "1" { + t.Error("param fid should be 1") } } func TestRouterCatchAll(t *testing.T) { r := New().Router r.Add(GET, "/static/*", func(*Context) {}, nil) - h, _ := r.Find(GET, "/static/*", params) + h, _ := r.Find(GET, "/static/echo.gif", params) if h == nil { - t.Fatal("handle not found") + t.Fatal("handler not found") + } + if params[0].Value != "echo.gif" { + t.Error("value should be echo.gif") } } func TestRouterMicroParam(t *testing.T) { r := New().Router - r.Add(GET, "/:a/:b/:c", func(c *Context) { - if c.P(0) != "1" { - t.Error("param a should be 1") - } - if c.P(1) != "2" { - t.Error("param b should be 2") - } - if c.P(2) != "3" { - t.Error("param c should be 3") - } - }, nil) + r.Add(GET, "/:a/:b/:c", func(c *Context) {}, nil) h, _ := r.Find(GET, "/1/2/3", params) if h == nil { - t.Fatal("handle not found") + t.Fatal("handler not found") + } + if params[0].Value != "1" { + t.Error("param a should be 1") + } + if params[1].Value != "2" { + t.Error("param b should be 2") + } + if params[2].Value != "3" { + t.Error("param c should be 3") } } -func TestRouterConflict(t *testing.T) { +func TestRouterConflictingRoute(t *testing.T) { r := New().Router - r.Add(GET, "/new", func(*Context) { - println("/new") + b := new(bytes.Buffer) + path := "/new" + + r.Add(GET, path, func(*Context) { + b.WriteString(path) }, nil) - r.Add(GET, "/new/:id", func(*Context) { - println("/new/:id") + h, _ := r.Find(GET, path, params) + if h == nil { + t.Fatal("handler not found") + } + h(nil) + if b.String() != path { + t.Errorf("buffer should be %s", path) + } + + name := "joe" + r.Add(GET, "/new/:id", func(c *Context) {}, nil) + h, _ = r.Find(GET, "/new/"+name, params) + if h == nil { + t.Fatal("handler not found") + } + if params[0].Value != name { + t.Errorf("param id should be %s", name) + } + + path = "/new/name" + r.Add(GET, path, func(*Context) { + b.Reset() + b.WriteString(path) }, nil) - r.Add(GET, "/new/name", func(*Context) { - println("/new/name") + h, _ = r.Find(GET, path, params) + if h == nil { + t.Fatal("handler not found") + } + h(nil) + if b.String() != path { + t.Errorf("buffer should be %s", path) + } + + r.Add(GET, "/new/name/:id", func(c *Context) {}, nil) + h, _ = r.Find(GET, "/new/name/"+name, params) + if h == nil { + t.Fatal("handler not found") + } + if params[0].Value != name { + t.Errorf("param id should be %s", name) + } + + path = "/new/name/joe" + r.Add(GET, path, func(c *Context) { + b.Reset() + b.WriteString(path) }, nil) - r.Add(GET, "/new/name/joe", func(*Context) { - println("/new/name/joe") - }, nil) - r.Add(GET, "/new/name/:id", func(*Context) { - println("/new/name/:id") - }, nil) - // h, _ := r.Find(GET, "/users/new", params) - // h(&Context{}) - n := r.trees[GET] - n.printTree("", true) + h, _ = r.Find(GET, "/new/name/joe", params) + if h == nil { + t.Fatal("handler not found") + } + h(nil) + if b.String() != path { + t.Errorf("buffer should be %s", path) + } } func TestRouterAPI(t *testing.T) { @@ -379,6 +431,7 @@ func TestRouterAPI(t *testing.T) { } } }, nil) + h, _ := r.Find(route.method, route.path, params) if h == nil { t.Errorf("handler not found, method=%s, path=%s", route.method, route.path) @@ -403,7 +456,7 @@ func TestRouterServeHTTP(t *testing.T) { func (n *node) printTree(pfx string, tail bool) { p := prefix(tail, pfx, "└── ", "├── ") - fmt.Printf("%s%s has=%d, echo=%v\n", p, n.prefix, n.handler, n.echo) + fmt.Printf("%s%s, %p: parent=%p, handler=%v, echo=%v\n", p, n.prefix, n, n.parent, n.handler, n.echo) nodes := n.edges l := len(nodes)