diff --git a/middleware/middleware.go b/middleware/middleware.go index efcbab91..71f95db7 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -1,6 +1,12 @@ package middleware -import "github.com/labstack/echo" +import ( + "regexp" + "strconv" + "strings" + + "github.com/labstack/echo" +) type ( // Skipper defines a function to skip middleware. Returning true skips processing @@ -8,6 +14,21 @@ type ( Skipper func(c echo.Context) bool ) +func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer { + groups := pattern.FindAllStringSubmatch(input, -1) + if groups == nil { + return nil + } + values := groups[0][1:] + replace := make([]string, 2*len(values)) + for i, v := range values { + j := 2 * i + replace[j] = "$" + strconv.Itoa(i+1) + replace[j+1] = v + } + return strings.NewReplacer(replace...) +} + // DefaultSkipper returns false which processes the middleware. func DefaultSkipper(echo.Context) bool { return false diff --git a/middleware/proxy.go b/middleware/proxy.go index 0f8ca07e..ae3ff527 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -8,6 +8,8 @@ import ( "net/http" "net/http/httputil" "net/url" + "regexp" + "strings" "sync" "sync/atomic" "time" @@ -26,6 +28,17 @@ type ( // Balancer defines a load balancing technique. // Required. Balancer ProxyBalancer + + // Rewrite defines URL path rewrite rules. The values captured in asterisk can be + // retrieved by index e.g. $1, $2 and so on. + // Examples: + // "/old": "/new", + // "/api/*": "/$1", + // "/js/*": "/public/javascripts/$1", + // "/users/*/orders/*": "/user/$1/order/$2", + Rewrite map[string]string + + rewriteRegex map[*regexp.Regexp]string } // ProxyTarget defines the upstream target. @@ -187,6 +200,13 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { if config.Balancer == nil { panic("echo: proxy middleware requires balancer") } + config.rewriteRegex = map[*regexp.Regexp]string{} + + // Initialize + for k, v := range config.Rewrite { + k = strings.Replace(k, "*", "(\\S*)", -1) + config.rewriteRegex[regexp.MustCompile(k)] = v + } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) (err error) { @@ -198,6 +218,14 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { res := c.Response() tgt := config.Balancer.Next() + // Rewrite + for k, v := range config.rewriteRegex { + replacer := captureTokens(k, req.URL.Path) + if replacer != nil { + req.URL.Path = replacer.Replace(v) + } + } + // Fix header if req.Header.Get(echo.HeaderXRealIP) == "" { req.Header.Set(echo.HeaderXRealIP, c.RealIP()) diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 52034629..017bb5eb 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -84,4 +84,28 @@ func TestProxy(t *testing.T) { e.ServeHTTP(rec, req) body = rec.Body.String() assert.Equal(t, "target 2", body) + + // Rewrite + e = echo.New() + e.Pre(ProxyWithConfig(ProxyConfig{ + Balancer: rrb, + Rewrite: map[string]string{ + "/old": "/new", + "/api/*": "/$1", + "/js/*": "/public/javascripts/$1", + "/users/*/orders/*": "/user/$1/order/$2", + }, + })) + req.URL.Path = "/api/users" + e.ServeHTTP(rec, req) + assert.Equal(t, "/users", req.URL.Path) + req.URL.Path = "/js/main.js" + e.ServeHTTP(rec, req) + assert.Equal(t, "/public/javascripts/main.js", req.URL.Path) + req.URL.Path = "/old" + e.ServeHTTP(rec, req) + assert.Equal(t, "/new", req.URL.Path) + req.URL.Path = "/users/jack/orders/1" + e.ServeHTTP(rec, req) + assert.Equal(t, "/user/jack/order/1", req.URL.Path) }