mirror of
https://github.com/labstack/echo.git
synced 2025-01-12 01:22:21 +02:00
bugfix proxy and rewrite, updated test with actual call settings
This commit is contained in:
parent
151ed6b3f1
commit
f6dfcbe774
@ -2,7 +2,6 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@ -34,15 +33,29 @@ func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {
|
|||||||
return strings.NewReplacer(replace...)
|
return strings.NewReplacer(replace...)
|
||||||
}
|
}
|
||||||
|
|
||||||
//rewritePath sets request url path and raw path
|
func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string {
|
||||||
func rewritePath(replacer *strings.Replacer, target string, req *http.Request) error {
|
// Initialize
|
||||||
replacerRawPath := replacer.Replace(target)
|
rulesRegex := map[*regexp.Regexp]string{}
|
||||||
replacerPath, err := url.PathUnescape(replacerRawPath)
|
for k, v := range rewrite {
|
||||||
if err != nil {
|
k = regexp.QuoteMeta(k)
|
||||||
return err
|
k = strings.Replace(k, `\*`, "(.*)", -1)
|
||||||
|
if strings.HasPrefix(k, `\^`) {
|
||||||
|
k = strings.Replace(k, `\^`, "^", -1)
|
||||||
|
}
|
||||||
|
k = k + "$"
|
||||||
|
rulesRegex[regexp.MustCompile(k)] = v
|
||||||
|
}
|
||||||
|
return rulesRegex
|
||||||
|
}
|
||||||
|
|
||||||
|
func rewritePath(rewriteRegex map[*regexp.Regexp]string, req *http.Request) {
|
||||||
|
for k, v := range rewriteRegex {
|
||||||
|
replacerRawPath := captureTokens(k, req.URL.EscapedPath())
|
||||||
|
if replacerRawPath != nil {
|
||||||
|
replacerPath := captureTokens(k, req.URL.Path)
|
||||||
|
req.URL.RawPath, req.URL.Path = replacerRawPath.Replace(v), replacerPath.Replace(v)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
req.URL.Path, req.URL.RawPath = replacerPath, replacerRawPath
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultSkipper returns false which processes the middleware.
|
// DefaultSkipper returns false which processes the middleware.
|
||||||
|
@ -8,7 +8,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@ -206,13 +205,8 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
|
|||||||
if config.Balancer == nil {
|
if config.Balancer == nil {
|
||||||
panic("echo: proxy middleware requires balancer")
|
panic("echo: proxy middleware requires balancer")
|
||||||
}
|
}
|
||||||
config.rewriteRegex = map[*regexp.Regexp]string{}
|
|
||||||
|
|
||||||
// Initialize
|
config.rewriteRegex = rewriteRulesRegex(config.Rewrite)
|
||||||
for k, v := range config.Rewrite {
|
|
||||||
k = strings.Replace(k, "*", "(\\S*)", -1)
|
|
||||||
config.rewriteRegex[regexp.MustCompile(k)] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
return func(c echo.Context) (err error) {
|
return func(c echo.Context) (err error) {
|
||||||
@ -225,16 +219,8 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
|
|||||||
tgt := config.Balancer.Next(c)
|
tgt := config.Balancer.Next(c)
|
||||||
c.Set(config.ContextKey, tgt)
|
c.Set(config.ContextKey, tgt)
|
||||||
|
|
||||||
// Rewrite
|
// Set rewrite path and raw path
|
||||||
for k, v := range config.rewriteRegex {
|
rewritePath(config.rewriteRegex, req)
|
||||||
//use req.URL.Path here or else we will have double escaping
|
|
||||||
replacer := captureTokens(k, req.URL.Path)
|
|
||||||
if replacer != nil {
|
|
||||||
if err := rewritePath(replacer, v, req); err != nil {
|
|
||||||
return echo.NewHTTPError(http.StatusBadRequest, "invalid url")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fix header
|
// Fix header
|
||||||
// Basically it's not good practice to unconditionally pass incoming x-real-ip header to upstream.
|
// Basically it's not good practice to unconditionally pass incoming x-real-ip header to upstream.
|
||||||
@ -265,3 +251,5 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -94,36 +94,35 @@ func TestProxy(t *testing.T) {
|
|||||||
"/users/*/orders/*": "/user/$1/order/$2",
|
"/users/*/orders/*": "/user/$1/order/$2",
|
||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
req.URL.Path = "/api/users"
|
req.URL, _ = url.Parse("/api/users")
|
||||||
rec = httptest.NewRecorder()
|
rec = httptest.NewRecorder()
|
||||||
e.ServeHTTP(rec, req)
|
e.ServeHTTP(rec, req)
|
||||||
assert.Equal(t, "/users", req.URL.EscapedPath())
|
assert.Equal(t, "/users", req.URL.EscapedPath())
|
||||||
assert.Equal(t, http.StatusOK, rec.Code)
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
req.URL.Path = "/js/main.js"
|
req.URL, _ = url.Parse( "/js/main.js")
|
||||||
rec = httptest.NewRecorder()
|
rec = httptest.NewRecorder()
|
||||||
e.ServeHTTP(rec, req)
|
e.ServeHTTP(rec, req)
|
||||||
assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath())
|
assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath())
|
||||||
assert.Equal(t, http.StatusOK, rec.Code)
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
req.URL.Path = "/old"
|
req.URL, _ = url.Parse("/old")
|
||||||
rec = httptest.NewRecorder()
|
rec = httptest.NewRecorder()
|
||||||
e.ServeHTTP(rec, req)
|
e.ServeHTTP(rec, req)
|
||||||
assert.Equal(t, "/new", req.URL.EscapedPath())
|
assert.Equal(t, "/new", req.URL.EscapedPath())
|
||||||
assert.Equal(t, http.StatusOK, rec.Code)
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
req.URL.Path = "/users/jack/orders/1"
|
req.URL, _ = url.Parse( "/users/jack/orders/1")
|
||||||
rec = httptest.NewRecorder()
|
rec = httptest.NewRecorder()
|
||||||
e.ServeHTTP(rec, req)
|
e.ServeHTTP(rec, req)
|
||||||
assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath())
|
assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath())
|
||||||
assert.Equal(t, http.StatusOK, rec.Code)
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
req.URL.Path = "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F"
|
req.URL, _ = url.Parse("/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F")
|
||||||
rec = httptest.NewRecorder()
|
rec = httptest.NewRecorder()
|
||||||
e.ServeHTTP(rec, req)
|
e.ServeHTTP(rec, req)
|
||||||
assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath())
|
assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath())
|
||||||
assert.Equal(t, http.StatusOK, rec.Code)
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
req.URL.Path = "/users/jill/orders/%%%%"
|
req.URL, _ = url.Parse("/api/new users")
|
||||||
rec = httptest.NewRecorder()
|
rec = httptest.NewRecorder()
|
||||||
e.ServeHTTP(rec, req)
|
e.ServeHTTP(rec, req)
|
||||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
assert.Equal(t, "/new%20users", req.URL.EscapedPath())
|
||||||
|
|
||||||
// ModifyResponse
|
// ModifyResponse
|
||||||
e = echo.New()
|
e = echo.New()
|
||||||
e.Use(ProxyWithConfig(ProxyConfig{
|
e.Use(ProxyWithConfig(ProxyConfig{
|
||||||
|
@ -1,11 +1,8 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
|
"regexp"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
@ -54,18 +51,8 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
|
|||||||
if config.Skipper == nil {
|
if config.Skipper == nil {
|
||||||
config.Skipper = DefaultBodyDumpConfig.Skipper
|
config.Skipper = DefaultBodyDumpConfig.Skipper
|
||||||
}
|
}
|
||||||
config.rulesRegex = map[*regexp.Regexp]string{}
|
|
||||||
|
|
||||||
// Initialize
|
config.rulesRegex = rewriteRulesRegex(config.Rules)
|
||||||
for k, v := range config.Rules {
|
|
||||||
k = regexp.QuoteMeta(k)
|
|
||||||
k = strings.Replace(k, `\*`, "(.*)", -1)
|
|
||||||
if strings.HasPrefix(k, `\^`) {
|
|
||||||
k = strings.Replace(k, `\^`, "^", -1)
|
|
||||||
}
|
|
||||||
k = k + "$"
|
|
||||||
config.rulesRegex[regexp.MustCompile(k)] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
return func(c echo.Context) (err error) {
|
return func(c echo.Context) (err error) {
|
||||||
@ -74,17 +61,8 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
req := c.Request()
|
req := c.Request()
|
||||||
// Rewrite
|
// Set rewrite path and raw path
|
||||||
for k, v := range config.rulesRegex {
|
rewritePath(config.rulesRegex, req)
|
||||||
//use req.URL.Path here or else we will have double escaping
|
|
||||||
replacer := captureTokens(k, req.URL.Path)
|
|
||||||
if replacer != nil {
|
|
||||||
if err := rewritePath(replacer, v, req); err != nil {
|
|
||||||
return echo.NewHTTPError(http.StatusBadRequest, "invalid url")
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return next(c)
|
return next(c)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
@ -23,33 +24,28 @@ func TestRewrite(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
req.URL.Path = "/api/users"
|
req.URL, _ = url.Parse("/api/users")
|
||||||
e.ServeHTTP(rec, req)
|
e.ServeHTTP(rec, req)
|
||||||
assert.Equal(t, "/users", req.URL.EscapedPath())
|
assert.Equal(t, "/users", req.URL.EscapedPath())
|
||||||
req.URL.Path = "/js/main.js"
|
req.URL, _ = url.Parse("/js/main.js")
|
||||||
rec = httptest.NewRecorder()
|
rec = httptest.NewRecorder()
|
||||||
e.ServeHTTP(rec, req)
|
e.ServeHTTP(rec, req)
|
||||||
assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath())
|
assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath())
|
||||||
req.URL.Path = "/old"
|
req.URL, _ = url.Parse("/old")
|
||||||
rec = httptest.NewRecorder()
|
rec = httptest.NewRecorder()
|
||||||
e.ServeHTTP(rec, req)
|
e.ServeHTTP(rec, req)
|
||||||
assert.Equal(t, "/new", req.URL.EscapedPath())
|
assert.Equal(t, "/new", req.URL.EscapedPath())
|
||||||
req.URL.Path = "/users/jack/orders/1"
|
req.URL, _ = url.Parse("/users/jack/orders/1")
|
||||||
rec = httptest.NewRecorder()
|
rec = httptest.NewRecorder()
|
||||||
e.ServeHTTP(rec, req)
|
e.ServeHTTP(rec, req)
|
||||||
assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath())
|
assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath())
|
||||||
req.URL.Path = "/api/new users"
|
req.URL, _ = url.Parse("/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F")
|
||||||
rec = httptest.NewRecorder()
|
|
||||||
e.ServeHTTP(rec, req)
|
|
||||||
assert.Equal(t, "/new%20users", req.URL.EscapedPath())
|
|
||||||
req.URL.Path = "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F"
|
|
||||||
rec = httptest.NewRecorder()
|
rec = httptest.NewRecorder()
|
||||||
e.ServeHTTP(rec, req)
|
e.ServeHTTP(rec, req)
|
||||||
assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath())
|
assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath())
|
||||||
req.URL.Path = "/users/jill/orders/%%%%"
|
req.URL, _ = url.Parse("/api/new users")
|
||||||
rec = httptest.NewRecorder()
|
|
||||||
e.ServeHTTP(rec, req)
|
e.ServeHTTP(rec, req)
|
||||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
assert.Equal(t, "/new%20users", req.URL.EscapedPath())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Issue #1086
|
// Issue #1086
|
||||||
@ -58,11 +54,10 @@ func TestEchoRewritePreMiddleware(t *testing.T) {
|
|||||||
r := e.Router()
|
r := e.Router()
|
||||||
|
|
||||||
// Rewrite old url to new one
|
// Rewrite old url to new one
|
||||||
e.Pre(RewriteWithConfig(RewriteConfig{
|
e.Pre(Rewrite(map[string]string{
|
||||||
Rules: map[string]string{
|
|
||||||
"/old": "/new",
|
"/old": "/new",
|
||||||
},
|
},
|
||||||
}))
|
))
|
||||||
|
|
||||||
// Route
|
// Route
|
||||||
r.Add(http.MethodGet, "/new", func(c echo.Context) error {
|
r.Add(http.MethodGet, "/new", func(c echo.Context) error {
|
||||||
|
Loading…
Reference in New Issue
Block a user