1
0
mirror of https://github.com/labstack/echo.git synced 2025-01-07 23:01:56 +02:00

proxy middleware: reuse echo request context (#2537)

This commit is contained in:
Kai Ratzeburg 2023-11-05 17:01:01 +01:00 committed by GitHub
parent 69a0de8415
commit c7d6d4373f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 64 additions and 0 deletions

View File

@ -359,6 +359,10 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
c.Set("_error", nil) c.Set("_error", nil)
} }
// This is needed for ProxyConfig.ModifyResponse and/or ProxyConfig.Transport to be able to process the Request
// that Balancer may have replaced with c.SetRequest.
req = c.Request()
// Proxy // Proxy
switch { switch {
case c.IsWebSocket(): case c.IsWebSocket():

View File

@ -747,3 +747,63 @@ func TestProxyBalancerWithNoTargets(t *testing.T) {
rrb := NewRoundRobinBalancer([]*ProxyTarget{}) rrb := NewRoundRobinBalancer([]*ProxyTarget{})
assert.Nil(t, rrb.Next(nil)) assert.Nil(t, rrb.Next(nil))
} }
type testContextKey string
type customBalancer struct {
target *ProxyTarget
}
func (b *customBalancer) AddTarget(target *ProxyTarget) bool {
return false
}
func (b *customBalancer) RemoveTarget(name string) bool {
return false
}
func (b *customBalancer) Next(c echo.Context) *ProxyTarget {
ctx := context.WithValue(c.Request().Context(), testContextKey("FROM_BALANCER"), "CUSTOM_BALANCER")
c.SetRequest(c.Request().WithContext(ctx))
return b.target
}
func TestModifyResponseUseContext(t *testing.T) {
server := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}),
)
defer server.Close()
targetURL, _ := url.Parse(server.URL)
e := echo.New()
e.Use(ProxyWithConfig(
ProxyConfig{
Balancer: &customBalancer{
target: &ProxyTarget{
Name: "tst",
URL: targetURL,
},
},
RetryCount: 1,
ModifyResponse: func(res *http.Response) error {
val := res.Request.Context().Value(testContextKey("FROM_BALANCER"))
if valStr, ok := val.(string); ok {
res.Header.Set("FROM_BALANCER", valStr)
}
return nil
},
},
))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "OK", rec.Body.String())
assert.Equal(t, "CUSTOM_BALANCER", rec.Header().Get("FROM_BALANCER"))
}