diff --git a/router.go b/router.go index 74bc7659..90102a29 100644 --- a/router.go +++ b/router.go @@ -224,7 +224,12 @@ func (r *Router) insert(method, path string, t kind, rm routeMethod) { } currentNode.isLeaf = currentNode.staticChildren == nil && currentNode.paramChild == nil && currentNode.anyChild == nil } else if lcpLen < prefixLen { - // Split node + // Split node into two before we insert new node. + // This happens when we are inserting path that is submatch of any existing inserted paths. + // For example, we have node `/test` and now are about to insert `/te/*`. In that case + // 1. overlapping part is `/te` that is used as parent node + // 2. `st` is part from existing node that is not matching - it gets its own node (child to `/te`) + // 3. `/*` is the new part we are about to insert (child to `/te`) n := newNode( currentNode.kind, currentNode.prefix[lcpLen:], @@ -235,6 +240,7 @@ func (r *Router) insert(method, path string, t kind, rm routeMethod) { currentNode.paramsCount, currentNode.paramChild, currentNode.anyChild, + currentNode.notFoundHandler, ) // Update parent path for all children to new node for _, child := range currentNode.staticChildren { @@ -259,6 +265,7 @@ func (r *Router) insert(method, path string, t kind, rm routeMethod) { currentNode.anyChild = nil currentNode.isLeaf = false currentNode.isHandler = false + currentNode.notFoundHandler = nil // Only Static children could reach here currentNode.addStaticChild(n) @@ -273,7 +280,7 @@ func (r *Router) insert(method, path string, t kind, rm routeMethod) { } } else { // Create child node - n = newNode(t, search[lcpLen:], currentNode, nil, "", new(routeMethods), 0, nil, nil) + n = newNode(t, search[lcpLen:], currentNode, nil, "", new(routeMethods), 0, nil, nil, nil) if rm.handler != nil { n.addMethod(method, &rm) n.paramsCount = len(rm.pnames) @@ -292,7 +299,7 @@ func (r *Router) insert(method, path string, t kind, rm routeMethod) { continue } // Create child node - n := newNode(t, search, currentNode, nil, rm.ppath, new(routeMethods), 0, nil, nil) + n := newNode(t, search, currentNode, nil, rm.ppath, new(routeMethods), 0, nil, nil, nil) if rm.handler != nil { n.addMethod(method, &rm) n.paramsCount = len(rm.pnames) @@ -319,20 +326,32 @@ func (r *Router) insert(method, path string, t kind, rm routeMethod) { } } -func newNode(t kind, pre string, p *node, sc children, originalPath string, mh *routeMethods, paramsCount int, paramChildren, anyChildren *node) *node { +func newNode( + t kind, + pre string, + p *node, + sc children, + originalPath string, + methods *routeMethods, + paramsCount int, + paramChildren, + anyChildren *node, + notFoundHandler *routeMethod, +) *node { return &node{ - kind: t, - label: pre[0], - prefix: pre, - parent: p, - staticChildren: sc, - originalPath: originalPath, - methods: mh, - paramsCount: paramsCount, - paramChild: paramChildren, - anyChild: anyChildren, - isLeaf: sc == nil && paramChildren == nil && anyChildren == nil, - isHandler: mh.isHandler(), + kind: t, + label: pre[0], + prefix: pre, + parent: p, + staticChildren: sc, + originalPath: originalPath, + methods: methods, + paramsCount: paramsCount, + paramChild: paramChildren, + anyChild: anyChildren, + isLeaf: sc == nil && paramChildren == nil && anyChildren == nil, + isHandler: methods.isHandler(), + notFoundHandler: notFoundHandler, } } diff --git a/router_test.go b/router_test.go index 34f325d3..1b0c409b 100644 --- a/router_test.go +++ b/router_test.go @@ -1286,6 +1286,43 @@ func TestNotFoundRouteStaticKind(t *testing.T) { } } +func TestRouter_notFoundRouteWithNodeSplitting(t *testing.T) { + e := New() + r := e.router + + r.Add(http.MethodGet, "/test*", handlerHelper("ID", 0)) + r.Add(RouteNotFound, "/*", handlerHelper("ID", 1)) + r.Add(RouteNotFound, "/test", handlerHelper("ID", 2)) + + // Tree before: + // 1 `/` + // 1.1 `*` (any) ID=1 + // 1.2 `test` (static) ID=2 + // 1.2.1 `*` (any) ID=0 + + // node with path `test` has routeNotFound handler from previous Add call. Now when we insert `/te/st*` into router tree + // This means that node `test` is split into `te` and `st` nodes and new node `/st*` is inserted. + // On that split `/test` routeNotFound handler must not be lost. + r.Add(http.MethodGet, "/te/st*", handlerHelper("ID", 3)) + // Tree after: + // 1 `/` + // 1.1 `*` (any) ID=1 + // 1.2 `te` (static) + // 1.2.1 `st` (static) ID=2 + // 1.2.1.1 `*` (any) ID=0 + // 1.2.2 `/st` (static) + // 1.2.2.1 `*` (any) ID=3 + + c := e.NewContext(nil, nil).(*context) + r.Find(http.MethodPut, "/test", c) + + c.handler(c) + + testValue, _ := c.Get("ID").(int) + assert.Equal(t, 2, testValue) + assert.Equal(t, "/test", c.Path()) +} + // Issue #1509 func TestRouterParamStaticConflict(t *testing.T) { e := New()