diff --git a/middleware/middleware.go b/middleware/middleware.go index 60834b50..8381e3a5 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -38,7 +38,7 @@ func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string { rulesRegex := map[*regexp.Regexp]string{} for k, v := range rewrite { k = regexp.QuoteMeta(k) - k = strings.Replace(k, `\*`, "(.*)", -1) + k = strings.Replace(k, `\*`, "(.*?)", -1) if strings.HasPrefix(k, `\^`) { k = strings.Replace(k, `\^`, "^", -1) } diff --git a/middleware/proxy.go b/middleware/proxy.go index 1b972eb1..63eec5a2 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -36,6 +36,13 @@ type ( // "/users/*/orders/*": "/user/$1/order/$2", Rewrite map[string]string + // RegexRewrite defines rewrite rules using regexp.Rexexp with captures + // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. + // Example: + // "^/old/[0.9]+/": "/new", + // "^/api/.+?/(.*)": "/v2/$1", + RegexRewrite map[*regexp.Regexp]string + // Context key to store selected ProxyTarget into context. // Optional. Default value "target". ContextKey string @@ -46,8 +53,6 @@ type ( // ModifyResponse defines function to modify response from ProxyTarget. ModifyResponse func(*http.Response) error - - rewriteRegex map[*regexp.Regexp]string } // ProxyTarget defines the upstream target. @@ -206,7 +211,14 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { panic("echo: proxy middleware requires balancer") } - config.rewriteRegex = rewriteRulesRegex(config.Rewrite) + if config.Rewrite != nil { + if config.RegexRewrite == nil { + config.RegexRewrite = make(map[*regexp.Regexp]string) + } + for k, v := range rewriteRulesRegex(config.Rewrite) { + config.RegexRewrite[k] = v + } + } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) (err error) { @@ -220,7 +232,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { c.Set(config.ContextKey, tgt) // Set rewrite path and raw path - rewritePath(config.rewriteRegex, req) + rewritePath(config.RegexRewrite, req) // Fix header // Basically it's not good practice to unconditionally pass incoming x-real-ip header to upstream. @@ -251,5 +263,3 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { } } } - - diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 534e45f4..ec6f1925 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "regexp" "testing" "github.com/labstack/echo/v4" @@ -83,46 +84,6 @@ func TestProxy(t *testing.T) { body = rec.Body.String() assert.Equal(t, "target 2", body) - // Rewrite - e = echo.New() - e.Use(ProxyWithConfig(ProxyConfig{ - Balancer: rrb, - Rewrite: map[string]string{ - "/old": "/new", - "/api/*": "/$1", - "/js/*": "/public/javascripts/$1", - "/users/*/orders/*": "/user/$1/order/$2", - }, - })) - req.URL, _ = url.Parse("/api/users") - rec = httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, "/users", req.URL.EscapedPath()) - assert.Equal(t, http.StatusOK, rec.Code) - req.URL, _ = url.Parse( "/js/main.js") - rec = httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath()) - assert.Equal(t, http.StatusOK, rec.Code) - req.URL, _ = url.Parse("/old") - rec = httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, "/new", req.URL.EscapedPath()) - assert.Equal(t, http.StatusOK, rec.Code) - req.URL, _ = url.Parse( "/users/jack/orders/1") - rec = httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath()) - assert.Equal(t, http.StatusOK, rec.Code) - req.URL, _ = url.Parse("/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F") - rec = httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath()) - assert.Equal(t, http.StatusOK, rec.Code) - req.URL, _ = url.Parse("/api/new users") - rec = httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, "/new%20users", req.URL.EscapedPath()) // ModifyResponse e = echo.New() e.Use(ProxyWithConfig(ProxyConfig{ @@ -196,3 +157,104 @@ func TestProxyRealIPHeader(t *testing.T) { assert.Equal(t, tt.extectedXRealIP, req.Header.Get(echo.HeaderXRealIP), "hasRealIPheader: %t / hasIPExtractor: %t", tt.hasRealIPheader, tt.hasIPExtractor) } } + +func TestProxyRewrite(t *testing.T) { + // Setup + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + defer upstream.Close() + url, _ := url.Parse(upstream.URL) + rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: url}}) + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + // Rewrite + e := echo.New() + e.Use(ProxyWithConfig(ProxyConfig{ + Balancer: rrb, + Rewrite: map[string]string{ + "/old": "/new", + "/api/*": "/$1", + "/js/*": "/public/javascripts/$1", + "/users/*/orders/*": "/user/$1/order/$2", + }, + })) + req.URL, _ = url.Parse("/api/users") + rec = httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, "/users", req.URL.EscapedPath()) + assert.Equal(t, http.StatusOK, rec.Code) + req.URL, _ = url.Parse("/js/main.js") + rec = httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath()) + assert.Equal(t, http.StatusOK, rec.Code) + req.URL, _ = url.Parse("/old") + rec = httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, "/new", req.URL.EscapedPath()) + assert.Equal(t, http.StatusOK, rec.Code) + req.URL, _ = url.Parse("/users/jack/orders/1") + rec = httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath()) + assert.Equal(t, http.StatusOK, rec.Code) + req.URL, _ = url.Parse("/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F") + rec = httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath()) + assert.Equal(t, http.StatusOK, rec.Code) + req.URL, _ = url.Parse("/api/new users") + rec = httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, "/new%20users", req.URL.EscapedPath()) +} + +func TestProxyRewriteRegex(t *testing.T) { + // Setup + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + defer upstream.Close() + url, _ := url.Parse(upstream.URL) + rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: url}}) + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + // Rewrite + e := echo.New() + e.Use(ProxyWithConfig(ProxyConfig{ + Balancer: rrb, + Rewrite: map[string]string{ + "^/a/*": "/v1/$1", + "^/b/*/c/*": "/v2/$2/$1", + "^/c/*/*": "/v3/$2", + }, + RegexRewrite: map[*regexp.Regexp]string{ + regexp.MustCompile("^/x/.+?/(.*)"): "/v4/$1", + regexp.MustCompile("^/y/(.+?)/(.*)"): "/v5/$2/$1", + }, + })) + + testCases := []struct { + requestPath string + statusCode int + expectPath string + }{ + {"/unmatched", http.StatusOK, "/unmatched"}, + {"/a/test", http.StatusOK, "/v1/test"}, + {"/b/foo/c/bar/baz", http.StatusOK, "/v2/bar/baz/foo"}, + {"/c/ignore/test", http.StatusOK, "/v3/test"}, + {"/c/ignore1/test/this", http.StatusOK, "/v3/test/this"}, + {"/x/ignore/test", http.StatusOK, "/v4/test"}, + {"/y/foo/bar", http.StatusOK, "/v5/bar/foo"}, + } + + + for _, tc := range testCases { + t.Run(tc.requestPath, func(t *testing.T) { + req.URL, _ = url.Parse(tc.requestPath) + rec = httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, tc.expectPath, req.URL.EscapedPath()) + assert.Equal(t, tc.statusCode, rec.Code) + }) + } +} diff --git a/middleware/rewrite.go b/middleware/rewrite.go index 0965e313..c05d5d84 100644 --- a/middleware/rewrite.go +++ b/middleware/rewrite.go @@ -1,8 +1,9 @@ package middleware import ( - "github.com/labstack/echo/v4" "regexp" + + "github.com/labstack/echo/v4" ) type ( @@ -21,7 +22,12 @@ type ( // Required. Rules map[string]string `yaml:"rules"` - rulesRegex map[*regexp.Regexp]string + // RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures + // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. + // Example: + // "^/old/[0.9]+/": "/new", + // "^/api/.+?/(.*)": "/v2/$1", + RegexRules map[*regexp.Regexp]string `yaml:"regex_rules"` } ) @@ -45,14 +51,20 @@ func Rewrite(rules map[string]string) echo.MiddlewareFunc { // See: `Rewrite()`. func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { // Defaults - if config.Rules == nil { - panic("echo: rewrite middleware requires url path rewrite rules") + if config.Rules == nil && config.RegexRules == nil { + panic("echo: rewrite middleware requires url path rewrite rules or regex rules") } + if config.Skipper == nil { config.Skipper = DefaultBodyDumpConfig.Skipper } - config.rulesRegex = rewriteRulesRegex(config.Rules) + if config.RegexRules == nil { + config.RegexRules = make(map[*regexp.Regexp]string) + } + for k, v := range rewriteRulesRegex(config.Rules) { + config.RegexRules[k] = v + } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) (err error) { @@ -62,7 +74,7 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { req := c.Request() // Set rewrite path and raw path - rewritePath(config.rulesRegex, req) + rewritePath(config.RegexRules, req) return next(c) } } diff --git a/middleware/rewrite_test.go b/middleware/rewrite_test.go index abf11b2f..351b7313 100644 --- a/middleware/rewrite_test.go +++ b/middleware/rewrite_test.go @@ -5,6 +5,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "regexp" "testing" "github.com/labstack/echo/v4" @@ -55,8 +56,8 @@ func TestEchoRewritePreMiddleware(t *testing.T) { // Rewrite old url to new one e.Pre(Rewrite(map[string]string{ - "/old": "/new", - }, + "/old": "/new", + }, )) // Route @@ -129,3 +130,45 @@ func TestEchoRewriteWithCaret(t *testing.T) { e.ServeHTTP(rec, req) assert.Equal(t, "/v2/abc/test", req.URL.Path) } + +// Verify regex used with rewrite +func TestEchoRewriteWithRegexRules(t *testing.T) { + e := echo.New() + + e.Pre(RewriteWithConfig(RewriteConfig{ + Rules: map[string]string{ + "^/a/*": "/v1/$1", + "^/b/*/c/*": "/v2/$2/$1", + "^/c/*/*": "/v3/$2", + }, + RegexRules: map[*regexp.Regexp]string{ + regexp.MustCompile("^/x/.+?/(.*)"): "/v4/$1", + regexp.MustCompile("^/y/(.+?)/(.*)"): "/v5/$2/$1", + }, + })) + + var rec *httptest.ResponseRecorder + var req *http.Request + + testCases := []struct { + requestPath string + expectPath string + }{ + {"/unmatched", "/unmatched"}, + {"/a/test", "/v1/test"}, + {"/b/foo/c/bar/baz", "/v2/bar/baz/foo"}, + {"/c/ignore/test", "/v3/test"}, + {"/c/ignore1/test/this", "/v3/test/this"}, + {"/x/ignore/test", "/v4/test"}, + {"/y/foo/bar", "/v5/bar/foo"}, + } + + for _, tc := range testCases { + t.Run(tc.requestPath, func(t *testing.T) { + req = httptest.NewRequest(http.MethodGet, tc.requestPath, nil) + rec = httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, tc.expectPath, req.URL.EscapedPath()) + }) + } +}