From 690e3392d984dcbdb9f41a7915ddc0d383311974 Mon Sep 17 00:00:00 2001 From: Martti T Date: Tue, 12 Jul 2022 21:53:41 +0300 Subject: [PATCH] Add support for registering handlers for 404 routes (#2217) --- echo.go | 13 ++++ echo_test.go | 64 +++++++++++++++++ group.go | 7 ++ group_test.go | 65 +++++++++++++++++ router.go | 59 +++++++++++----- router_test.go | 185 +++++++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 375 insertions(+), 18 deletions(-) diff --git a/echo.go b/echo.go index 8829619c..5b10d586 100644 --- a/echo.go +++ b/echo.go @@ -183,6 +183,8 @@ const ( PROPFIND = "PROPFIND" // REPORT Method can be used to get information about a resource, see rfc 3253 REPORT = "REPORT" + // RouteNotFound is special method type for routes handling "route not found" (404) cases + RouteNotFound = "echo_route_not_found" ) // Headers @@ -480,6 +482,16 @@ func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { return e.Add(http.MethodTrace, path, h, m...) } +// RouteNotFound registers a special-case route which is executed when no other route is found (i.e. HTTP 404 cases) +// for current request URL. +// Path supports static and named/any parameters just like other http method is defined. Generally path is ended with +// wildcard/match-any character (`/*`, `/download/*` etc). +// +// Example: `e.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) })` +func (e *Echo) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { + return e.Add(RouteNotFound, path, h, m...) +} + // Any registers a new route for all HTTP methods and path with matching handler // in the router with optional route-level middleware. func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { @@ -515,6 +527,7 @@ func (e *Echo) File(path, file string, m ...MiddlewareFunc) *Route { func (e *Echo) add(host, method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route { name := handlerName(handler) router := e.findRouter(host) + // FIXME: when handler+middleware are both nil ... make it behave like handler removal router.Add(method, path, func(c Context) error { h := applyMiddleware(handler, middleware...) return h(c) diff --git a/echo_test.go b/echo_test.go index 0e1e42be..64796b3b 100644 --- a/echo_test.go +++ b/echo_test.go @@ -766,6 +766,70 @@ func TestEchoNotFound(t *testing.T) { assert.Equal(t, http.StatusNotFound, rec.Code) } +func TestEcho_RouteNotFound(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectRoute interface{} + expectCode int + }{ + { + name: "404, route to static not found handler /a/c/xx", + whenURL: "/a/c/xx", + expectRoute: "GET /a/c/xx", + expectCode: http.StatusNotFound, + }, + { + name: "404, route to path param not found handler /a/:file", + whenURL: "/a/echo.exe", + expectRoute: "GET /a/:file", + expectCode: http.StatusNotFound, + }, + { + name: "404, route to any not found handler /*", + whenURL: "/b/echo.exe", + expectRoute: "GET /*", + expectCode: http.StatusNotFound, + }, + { + name: "200, route /a/c/df to /a/c/df", + whenURL: "/a/c/df", + expectRoute: "GET /a/c/df", + expectCode: http.StatusOK, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + + okHandler := func(c Context) error { + return c.String(http.StatusOK, c.Request().Method+" "+c.Path()) + } + notFoundHandler := func(c Context) error { + return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path()) + } + + e.GET("/", okHandler) + e.GET("/a/c/df", okHandler) + e.GET("/a/b*", okHandler) + e.PUT("/*", okHandler) + + e.RouteNotFound("/a/c/xx", notFoundHandler) // static + e.RouteNotFound("/a/:file", notFoundHandler) // param + e.RouteNotFound("/*", notFoundHandler) // any + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectCode, rec.Code) + assert.Equal(t, tc.expectRoute, rec.Body.String()) + }) + } +} + func TestEchoMethodNotAllowed(t *testing.T) { e := New() diff --git a/group.go b/group.go index bba470ce..28ce0dd9 100644 --- a/group.go +++ b/group.go @@ -107,6 +107,13 @@ func (g *Group) File(path, file string) { g.file(path, file, g.GET) } +// RouteNotFound implements `Echo#RouteNotFound()` for sub-routes within the Group. +// +// Example: `g.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) })` +func (g *Group) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { + return g.Add(RouteNotFound, path, h, m...) +} + // Add implements `Echo#Add()` for sub-routes within the Group. func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route { // Combine into a new slice to avoid accidentally passing the same slice for diff --git a/group_test.go b/group_test.go index c51fd91e..24f19167 100644 --- a/group_test.go +++ b/group_test.go @@ -119,3 +119,68 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) { assert.Equal(t, "/*", m) } + +func TestGroup_RouteNotFound(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectRoute interface{} + expectCode int + }{ + { + name: "404, route to static not found handler /group/a/c/xx", + whenURL: "/group/a/c/xx", + expectRoute: "GET /group/a/c/xx", + expectCode: http.StatusNotFound, + }, + { + name: "404, route to path param not found handler /group/a/:file", + whenURL: "/group/a/echo.exe", + expectRoute: "GET /group/a/:file", + expectCode: http.StatusNotFound, + }, + { + name: "404, route to any not found handler /group/*", + whenURL: "/group/b/echo.exe", + expectRoute: "GET /group/*", + expectCode: http.StatusNotFound, + }, + { + name: "200, route /group/a/c/df to /group/a/c/df", + whenURL: "/group/a/c/df", + expectRoute: "GET /group/a/c/df", + expectCode: http.StatusOK, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + g := e.Group("/group") + + okHandler := func(c Context) error { + return c.String(http.StatusOK, c.Request().Method+" "+c.Path()) + } + notFoundHandler := func(c Context) error { + return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path()) + } + + g.GET("/", okHandler) + g.GET("/a/c/df", okHandler) + g.GET("/a/b*", okHandler) + g.PUT("/*", okHandler) + + g.RouteNotFound("/a/c/xx", notFoundHandler) // static + g.RouteNotFound("/a/:file", notFoundHandler) // param + g.RouteNotFound("/*", notFoundHandler) // any + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectCode, rec.Code) + assert.Equal(t, tc.expectRoute, rec.Body.String()) + }) + } +} diff --git a/router.go b/router.go index 6a8615d8..74bc7659 100644 --- a/router.go +++ b/router.go @@ -28,6 +28,9 @@ type ( isLeaf bool // isHandler indicates that node has at least one handler registered to it isHandler bool + + // notFoundHandler is handler registered with RouteNotFound method and is executed for 404 cases + notFoundHandler *routeMethod } kind uint8 children []*node @@ -73,6 +76,7 @@ func (m *routeMethods) isHandler() bool { m.put != nil || m.trace != nil || m.report != nil + // RouteNotFound/404 is not considered as a handler } func (m *routeMethods) updateAllowHeader() { @@ -382,6 +386,9 @@ func (n *node) addMethod(method string, h *routeMethod) { n.methods.trace = h case REPORT: n.methods.report = h + case RouteNotFound: + n.notFoundHandler = h + return // RouteNotFound/404 is not considered as a handler so no further logic needs to be executed } n.methods.updateAllowHeader() @@ -412,7 +419,7 @@ func (n *node) findMethod(method string) *routeMethod { return n.methods.trace case REPORT: return n.methods.report - default: + default: // RouteNotFound/404 is not considered as a handler return nil } } @@ -515,7 +522,7 @@ func (r *Router) Find(method, path string, c Context) { // No matching prefix, let's backtrack to the first possible alternative node of the decision path nk, ok := backtrackToNextNodeKind(staticKind) if !ok { - return // No other possibilities on the decision path + return // No other possibilities on the decision path, handler will be whatever context is reset to. } else if nk == paramKind { goto Param // NOTE: this case (backtracking from static node to previous any node) can not happen by current any matching logic. Any node is end of search currently @@ -531,15 +538,21 @@ func (r *Router) Find(method, path string, c Context) { search = search[lcpLen:] searchIndex = searchIndex + lcpLen - // Finish routing if no remaining search and we are on a node with handler and matching method type - if search == "" && currentNode.isHandler { - // check if current node has handler registered for http method we are looking for. we store currentNode as - // best matching in case we do no find no more routes matching this path+method - if previousBestMatchNode == nil { - previousBestMatchNode = currentNode - } - if h := currentNode.findMethod(method); h != nil { - matchedRouteMethod = h + // Finish routing if is no request path remaining to search + if search == "" { + // in case of node that is handler we have exact method type match or something for 405 to use + if currentNode.isHandler { + // check if current node has handler registered for http method we are looking for. we store currentNode as + // best matching in case we do no find no more routes matching this path+method + if previousBestMatchNode == nil { + previousBestMatchNode = currentNode + } + if h := currentNode.findMethod(method); h != nil { + matchedRouteMethod = h + break + } + } else if currentNode.notFoundHandler != nil { + matchedRouteMethod = currentNode.notFoundHandler break } } @@ -559,7 +572,8 @@ func (r *Router) Find(method, path string, c Context) { i := 0 l := len(search) if currentNode.isLeaf { - // when param node does not have any children then param node should act similarly to any node - consider all remaining search as match + // when param node does not have any children (path param is last piece of route path) then param node should + // act similarly to any node - consider all remaining search as match i = l } else { for ; i < l && search[i] != '/'; i++ { @@ -585,13 +599,16 @@ func (r *Router) Find(method, path string, c Context) { searchIndex += +len(search) search = "" - // check if current node has handler registered for http method we are looking for. we store currentNode as - // best matching in case we do no find no more routes matching this path+method + if h := currentNode.findMethod(method); h != nil { + matchedRouteMethod = h + break + } + // we store currentNode as best matching in case we do not find more routes matching this path+method. Needed for 405 if previousBestMatchNode == nil { previousBestMatchNode = currentNode } - if h := currentNode.findMethod(method); h != nil { - matchedRouteMethod = h + if currentNode.notFoundHandler != nil { + matchedRouteMethod = currentNode.notFoundHandler break } } @@ -614,12 +631,14 @@ func (r *Router) Find(method, path string, c Context) { return // nothing matched at all } + // matchedHandler could be method+path handler that we matched or notFoundHandler from node with matching path + // user provided not found (404) handler has priority over generic method not found (405) handler or global 404 handler var rPath string var rPNames []string if matchedRouteMethod != nil { - ctx.handler = matchedRouteMethod.handler rPath = matchedRouteMethod.ppath rPNames = matchedRouteMethod.pnames + ctx.handler = matchedRouteMethod.handler } else { // use previous match as basis. although we have no matching handler we have path match. // so we can send http.StatusMethodNotAllowed (405) instead of http.StatusNotFound (404) @@ -628,7 +647,11 @@ func (r *Router) Find(method, path string, c Context) { rPath = currentNode.originalPath rPNames = nil // no params here ctx.handler = NotFoundHandler - if currentNode.isHandler { + if currentNode.notFoundHandler != nil { + rPath = currentNode.notFoundHandler.ppath + rPNames = currentNode.notFoundHandler.pnames + ctx.handler = currentNode.notFoundHandler.handler + } else if currentNode.isHandler { ctx.Set(ContextKeyHeaderAllow, currentNode.methods.allowHeader) ctx.handler = MethodNotAllowedHandler if method == http.MethodOptions { diff --git a/router_test.go b/router_test.go index 8645a26c..34f325d3 100644 --- a/router_test.go +++ b/router_test.go @@ -1101,6 +1101,191 @@ func TestRouterBacktrackingFromMultipleParamKinds(t *testing.T) { } } +func TestNotFoundRouteAnyKind(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectRoute interface{} + expectID int + expectParam map[string]string + }{ + { + name: "route not existent /xx to not found handler /*", + whenURL: "/xx", + expectRoute: "/*", + expectID: 4, + expectParam: map[string]string{"*": "xx"}, + }, + { + name: "route not existent /a/xx to not found handler /a/*", + whenURL: "/a/xx", + expectRoute: "/a/*", + expectID: 5, + expectParam: map[string]string{"*": "xx"}, + }, + { + name: "route not existent /a/c/dxxx to not found handler /a/c/d*", + whenURL: "/a/c/dxxx", + expectRoute: "/a/c/d*", + expectID: 6, + expectParam: map[string]string{"*": "xxx"}, + }, + { + name: "route /a/c/df to /a/c/df", + whenURL: "/a/c/df", + expectRoute: "/a/c/df", + expectID: 1, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + r := e.router + + r.Add(http.MethodGet, "/", handlerHelper("ID", 0)) + r.Add(http.MethodGet, "/a/c/df", handlerHelper("ID", 1)) + r.Add(http.MethodGet, "/a/b*", handlerHelper("ID", 2)) + r.Add(http.MethodPut, "/*", handlerHelper("ID", 3)) + + r.Add(RouteNotFound, "/a/c/d*", handlerHelper("ID", 6)) + r.Add(RouteNotFound, "/a/*", handlerHelper("ID", 5)) + r.Add(RouteNotFound, "/*", handlerHelper("ID", 4)) + + c := e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, tc.whenURL, c) + + c.handler(c) + + testValue, _ := c.Get("ID").(int) + assert.Equal(t, tc.expectID, testValue) + assert.Equal(t, tc.expectRoute, c.Path()) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } +} + +func TestNotFoundRouteParamKind(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectRoute interface{} + expectID int + expectParam map[string]string + }{ + { + name: "route not existent /xx to not found handler /:file", + whenURL: "/xx", + expectRoute: "/:file", + expectID: 4, + expectParam: map[string]string{"file": "xx"}, + }, + { + name: "route not existent /a/xx to not found handler /a/:file", + whenURL: "/a/xx", + expectRoute: "/a/:file", + expectID: 5, + expectParam: map[string]string{"file": "xx"}, + }, + { + name: "route not existent /a/c/dxxx to not found handler /a/c/d:file", + whenURL: "/a/c/dxxx", + expectRoute: "/a/c/d:file", + expectID: 6, + expectParam: map[string]string{"file": "xxx"}, + }, + { + name: "route /a/c/df to /a/c/df", + whenURL: "/a/c/df", + expectRoute: "/a/c/df", + expectID: 1, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + r := e.router + + r.Add(http.MethodGet, "/", handlerHelper("ID", 0)) + r.Add(http.MethodGet, "/a/c/df", handlerHelper("ID", 1)) + r.Add(http.MethodGet, "/a/b*", handlerHelper("ID", 2)) + r.Add(http.MethodPut, "/*", handlerHelper("ID", 3)) + + r.Add(RouteNotFound, "/a/c/d:file", handlerHelper("ID", 6)) + r.Add(RouteNotFound, "/a/:file", handlerHelper("ID", 5)) + r.Add(RouteNotFound, "/:file", handlerHelper("ID", 4)) + + c := e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, tc.whenURL, c) + + c.handler(c) + + testValue, _ := c.Get("ID").(int) + assert.Equal(t, tc.expectID, testValue) + assert.Equal(t, tc.expectRoute, c.Path()) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } +} + +func TestNotFoundRouteStaticKind(t *testing.T) { + // note: static not found handler is quite silly thing to have but we still support it + var testCases = []struct { + name string + whenURL string + expectRoute interface{} + expectID int + expectParam map[string]string + }{ + { + name: "route not existent / to not found handler /", + whenURL: "/", + expectRoute: "/", + expectID: 3, + expectParam: map[string]string{}, + }, + { + name: "route /a to /a", + whenURL: "/a", + expectRoute: "/a", + expectID: 1, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + r := e.router + + r.Add(http.MethodPut, "/", handlerHelper("ID", 0)) + r.Add(http.MethodGet, "/a", handlerHelper("ID", 1)) + r.Add(http.MethodPut, "/*", handlerHelper("ID", 2)) + + r.Add(RouteNotFound, "/", handlerHelper("ID", 3)) + + c := e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, tc.whenURL, c) + + c.handler(c) + + testValue, _ := c.Get("ID").(int) + assert.Equal(t, tc.expectID, testValue) + assert.Equal(t, tc.expectRoute, c.Path()) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } +} + // Issue #1509 func TestRouterParamStaticConflict(t *testing.T) { e := New()