1
0
mirror of https://github.com/labstack/echo.git synced 2025-01-12 01:22:21 +02:00

bugfix proxy and rewrite, updated test with actual call settings

This commit is contained in:
Arun Gopalpuri 2020-09-03 00:39:57 -07:00
parent 151ed6b3f1
commit f6dfcbe774
5 changed files with 48 additions and 75 deletions

View File

@ -2,7 +2,6 @@ package middleware
import ( import (
"net/http" "net/http"
"net/url"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
@ -34,15 +33,29 @@ func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {
return strings.NewReplacer(replace...) return strings.NewReplacer(replace...)
} }
//rewritePath sets request url path and raw path func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string {
func rewritePath(replacer *strings.Replacer, target string, req *http.Request) error { // Initialize
replacerRawPath := replacer.Replace(target) rulesRegex := map[*regexp.Regexp]string{}
replacerPath, err := url.PathUnescape(replacerRawPath) for k, v := range rewrite {
if err != nil { k = regexp.QuoteMeta(k)
return err k = strings.Replace(k, `\*`, "(.*)", -1)
if strings.HasPrefix(k, `\^`) {
k = strings.Replace(k, `\^`, "^", -1)
}
k = k + "$"
rulesRegex[regexp.MustCompile(k)] = v
}
return rulesRegex
}
func rewritePath(rewriteRegex map[*regexp.Regexp]string, req *http.Request) {
for k, v := range rewriteRegex {
replacerRawPath := captureTokens(k, req.URL.EscapedPath())
if replacerRawPath != nil {
replacerPath := captureTokens(k, req.URL.Path)
req.URL.RawPath, req.URL.Path = replacerRawPath.Replace(v), replacerPath.Replace(v)
}
} }
req.URL.Path, req.URL.RawPath = replacerPath, replacerRawPath
return nil
} }
// DefaultSkipper returns false which processes the middleware. // DefaultSkipper returns false which processes the middleware.

View File

@ -8,7 +8,6 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"regexp" "regexp"
"strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -206,13 +205,8 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
if config.Balancer == nil { if config.Balancer == nil {
panic("echo: proxy middleware requires balancer") panic("echo: proxy middleware requires balancer")
} }
config.rewriteRegex = map[*regexp.Regexp]string{}
// Initialize config.rewriteRegex = rewriteRulesRegex(config.Rewrite)
for k, v := range config.Rewrite {
k = strings.Replace(k, "*", "(\\S*)", -1)
config.rewriteRegex[regexp.MustCompile(k)] = v
}
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) { return func(c echo.Context) (err error) {
@ -225,16 +219,8 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
tgt := config.Balancer.Next(c) tgt := config.Balancer.Next(c)
c.Set(config.ContextKey, tgt) c.Set(config.ContextKey, tgt)
// Rewrite // Set rewrite path and raw path
for k, v := range config.rewriteRegex { rewritePath(config.rewriteRegex, req)
//use req.URL.Path here or else we will have double escaping
replacer := captureTokens(k, req.URL.Path)
if replacer != nil {
if err := rewritePath(replacer, v, req); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "invalid url")
}
}
}
// Fix header // Fix header
// Basically it's not good practice to unconditionally pass incoming x-real-ip header to upstream. // Basically it's not good practice to unconditionally pass incoming x-real-ip header to upstream.
@ -265,3 +251,5 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
} }
} }
} }

View File

@ -94,36 +94,35 @@ func TestProxy(t *testing.T) {
"/users/*/orders/*": "/user/$1/order/$2", "/users/*/orders/*": "/user/$1/order/$2",
}, },
})) }))
req.URL.Path = "/api/users" req.URL, _ = url.Parse("/api/users")
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, "/users", req.URL.EscapedPath()) assert.Equal(t, "/users", req.URL.EscapedPath())
assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
req.URL.Path = "/js/main.js" req.URL, _ = url.Parse( "/js/main.js")
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath()) assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath())
assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
req.URL.Path = "/old" req.URL, _ = url.Parse("/old")
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, "/new", req.URL.EscapedPath()) assert.Equal(t, "/new", req.URL.EscapedPath())
assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
req.URL.Path = "/users/jack/orders/1" req.URL, _ = url.Parse( "/users/jack/orders/1")
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath()) assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath())
assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
req.URL.Path = "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F" req.URL, _ = url.Parse("/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F")
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath()) assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath())
assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
req.URL.Path = "/users/jill/orders/%%%%" req.URL, _ = url.Parse("/api/new users")
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusBadRequest, rec.Code) assert.Equal(t, "/new%20users", req.URL.EscapedPath())
// ModifyResponse // ModifyResponse
e = echo.New() e = echo.New()
e.Use(ProxyWithConfig(ProxyConfig{ e.Use(ProxyWithConfig(ProxyConfig{

View File

@ -1,11 +1,8 @@
package middleware package middleware
import ( import (
"net/http"
"regexp"
"strings"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"regexp"
) )
type ( type (
@ -54,18 +51,8 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultBodyDumpConfig.Skipper config.Skipper = DefaultBodyDumpConfig.Skipper
} }
config.rulesRegex = map[*regexp.Regexp]string{}
// Initialize config.rulesRegex = rewriteRulesRegex(config.Rules)
for k, v := range config.Rules {
k = regexp.QuoteMeta(k)
k = strings.Replace(k, `\*`, "(.*)", -1)
if strings.HasPrefix(k, `\^`) {
k = strings.Replace(k, `\^`, "^", -1)
}
k = k + "$"
config.rulesRegex[regexp.MustCompile(k)] = v
}
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) { return func(c echo.Context) (err error) {
@ -74,17 +61,8 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
} }
req := c.Request() req := c.Request()
// Rewrite // Set rewrite path and raw path
for k, v := range config.rulesRegex { rewritePath(config.rulesRegex, req)
//use req.URL.Path here or else we will have double escaping
replacer := captureTokens(k, req.URL.Path)
if replacer != nil {
if err := rewritePath(replacer, v, req); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "invalid url")
}
break
}
}
return next(c) return next(c)
} }
} }

View File

@ -4,6 +4,7 @@ import (
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"testing" "testing"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
@ -23,33 +24,28 @@ func TestRewrite(t *testing.T) {
})) }))
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
req.URL.Path = "/api/users" req.URL, _ = url.Parse("/api/users")
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, "/users", req.URL.EscapedPath()) assert.Equal(t, "/users", req.URL.EscapedPath())
req.URL.Path = "/js/main.js" req.URL, _ = url.Parse("/js/main.js")
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath()) assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath())
req.URL.Path = "/old" req.URL, _ = url.Parse("/old")
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, "/new", req.URL.EscapedPath()) assert.Equal(t, "/new", req.URL.EscapedPath())
req.URL.Path = "/users/jack/orders/1" req.URL, _ = url.Parse("/users/jack/orders/1")
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath()) assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath())
req.URL.Path = "/api/new users" req.URL, _ = url.Parse("/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/new%20users", req.URL.EscapedPath())
req.URL.Path = "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F"
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath()) assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath())
req.URL.Path = "/users/jill/orders/%%%%" req.URL, _ = url.Parse("/api/new users")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusBadRequest, rec.Code) assert.Equal(t, "/new%20users", req.URL.EscapedPath())
} }
// Issue #1086 // Issue #1086
@ -58,11 +54,10 @@ func TestEchoRewritePreMiddleware(t *testing.T) {
r := e.Router() r := e.Router()
// Rewrite old url to new one // Rewrite old url to new one
e.Pre(RewriteWithConfig(RewriteConfig{ e.Pre(Rewrite(map[string]string{
Rules: map[string]string{
"/old": "/new", "/old": "/new",
}, },
})) ))
// Route // Route
r.Add(http.MethodGet, "/new", func(c echo.Context) error { r.Add(http.MethodGet, "/new", func(c echo.Context) error {