diff --git a/router.go b/router.go index 8aa304f9..03a99fc4 100644 --- a/router.go +++ b/router.go @@ -13,11 +13,11 @@ type ( prefix string parent *node children children - // pchild *node // Param child - // cchild *node // Catch-all child - handler HandlerFunc - pnames []string - echo *Echo + pchild *node // Param child + cchild *node // Catch-all child + handler HandlerFunc + pnames []string + echo *Echo } ntype uint8 children []*node @@ -96,11 +96,18 @@ func (r *router) insert(method, path string, h HandlerFunc, t ntype, pnames []st // Split node n := newNode(t, cn.prefix[l:], cn, cn.children, cn.handler, cn.pnames, cn.echo) cn.children = children{n} // Add to parent + // if n.typ == ptype { + // cn.pchild = n + // } else if n.typ == ctype { + // cn.cchild = n + // } // Reset parent node cn.typ = stype cn.label = cn.prefix[0] cn.prefix = cn.prefix[:l] + // cn.pchild = nil + // cn.cchild = nil cn.handler = nil cn.pnames = nil cn.echo = nil @@ -115,6 +122,11 @@ func (r *router) insert(method, path string, h HandlerFunc, t ntype, pnames []st // Create child node n = newNode(t, search[l:], cn, children{}, h, pnames, echo) cn.children = append(cn.children, n) + // if n.typ == ptype { + // cn.pchild = n + // } else if n.typ == ctype { + // cn.cchild = n + // } } } else if l < sl { search = search[l:] @@ -127,6 +139,11 @@ func (r *router) insert(method, path string, h HandlerFunc, t ntype, pnames []st // Create child node n := newNode(t, search, cn, children{}, h, pnames, echo) cn.children = append(cn.children, n) + // if n.typ == ptype { + // cn.pchild = n + // } else if n.typ == ctype { + // cn.cchild = n + // } } else { // Node already exists if h != nil { @@ -139,8 +156,8 @@ func (r *router) insert(method, path string, h HandlerFunc, t ntype, pnames []st } } -func newNode(t ntype, pfx string, p *node, c children, h HandlerFunc, pnames []string, echo *Echo) (n *node) { - n = &node{ +func newNode(t ntype, pfx string, p *node, c children, h HandlerFunc, pnames []string, echo *Echo) *node { + return &node{ typ: t, label: pfx[0], prefix: pfx, @@ -150,7 +167,9 @@ func newNode(t ntype, pfx string, p *node, c children, h HandlerFunc, pnames []s pnames: pnames, echo: echo, } - return +} + +func (n *node) addChild(c *node) { } func (n *node) findChild(l byte) *node { @@ -201,20 +220,22 @@ func lcp(a, b string) (i int) { return } -func (r *router) Find(method, path string, c *Context) (h HandlerFunc, echo *Echo) { +func (r *router) Find(method, path string, ctx *Context) (h HandlerFunc, echo *Echo) { cn := r.trees[method] // Current node as root search := path - chn := new(node) // Child node - n := 0 // Param counter + c := new(node) // Child node + n := 0 // Param counter // Search order static > param > catch-all for { if search == "" || search == cn.prefix { - // Found - h = cn.handler - c.pnames = cn.pnames - echo = cn.echo - return + if cn.handler != nil { + // Found + h = cn.handler + ctx.pnames = cn.pnames + echo = cn.echo + return + } } pl := len(cn.prefix) @@ -226,32 +247,40 @@ func (r *router) Find(method, path string, c *Context) (h HandlerFunc, echo *Ech goto Up } + // Check for catch-all with empty string + if len(search) == 0 { + goto CatchAll + } + // Static node - chn = cn.findSchild(search[0]) - if chn != nil { - cn = chn + c = cn.findSchild(search[0]) + if c != nil { + cn = c continue } // Param node Param: - chn = cn.findPchild() - if chn != nil { - cn = chn + c = cn.findPchild() + // c = cn.pchild + if c != nil { + cn = c i, l := 0, len(search) for ; i < l && search[i] != '/'; i++ { } - c.pvalues[n] = search[:i] + ctx.pvalues[n] = search[:i] n++ search = search[i:] continue } // Catch-all node - chn = cn.findCchild() - if chn != nil { - cn = chn - c.pvalues[n] = search + CatchAll: + // c = cn.cchild + c = cn.findCchild() + if c != nil { + cn = c + ctx.pvalues[n] = search search = "" // End search continue } diff --git a/router_test.go b/router_test.go index d0a32b65..74867202 100644 --- a/router_test.go +++ b/router_test.go @@ -344,7 +344,15 @@ func TestRouterCatchAll(t *testing.T) { return nil }, nil) - h, _ := r.Find(GET, "/users/joe", context) + h, _ := r.Find(GET, "/users/", context) + if h == nil { + t.Fatal("handler not found") + } + if context.pvalues[0] != "" { + t.Error("value should be joe") + } + + h, _ = r.Find(GET, "/users/joe", context) if h == nil { t.Fatal("handler not found") }