From 0ce73028d0815e0ecec80964cc2da42d98fafa33 Mon Sep 17 00:00:00 2001 From: Hristo Hristov Date: Sat, 29 Oct 2022 21:54:23 +0300 Subject: [PATCH] [suggestion] Add helper interface for ProxyBalancer interface (#2316) * [suggestion] Add helper interface for ProxyBalancer interface * Update proxy_test.go * addressed code review comments * address pr comments * clean up * return error --- middleware/proxy.go | 17 ++++++++++++- middleware/proxy_test.go | 52 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 66 insertions(+), 3 deletions(-) diff --git a/middleware/proxy.go b/middleware/proxy.go index 6cfd6731..d2cd2aa6 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -72,6 +72,11 @@ type ( Next(echo.Context) *ProxyTarget } + // TargetProvider defines an interface that gives the opportunity for balancer to return custom errors when selecting target. + TargetProvider interface { + NextTarget(echo.Context) (*ProxyTarget, error) + } + commonBalancer struct { targets []*ProxyTarget mutex sync.RWMutex @@ -223,6 +228,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { } } + provider, isTargetProvider := config.Balancer.(TargetProvider) return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) (err error) { if config.Skipper(c) { @@ -231,7 +237,16 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { req := c.Request() res := c.Response() - tgt := config.Balancer.Next(c) + + var tgt *ProxyTarget + if isTargetProvider { + tgt, err = provider.NextTarget(c) + if err != nil { + return err + } + } else { + tgt = config.Balancer.Next(c) + } c.Set(config.ContextKey, tgt) if err := rewriteURL(config.RegexRewrite, req); err != nil { diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 7939fc5c..0ded50a1 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -18,7 +18,7 @@ import ( "github.com/stretchr/testify/assert" ) -//Assert expected with url.EscapedPath method to obtain the path. +// Assert expected with url.EscapedPath method to obtain the path. func TestProxy(t *testing.T) { // Setup t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -31,7 +31,6 @@ func TestProxy(t *testing.T) { })) defer t2.Close() url2, _ := url.Parse(t2.URL) - targets := []*ProxyTarget{ { Name: "target 1", @@ -122,6 +121,55 @@ func TestProxy(t *testing.T) { e.ServeHTTP(rec, req) } +type testProvider struct { + *commonBalancer + target *ProxyTarget + err error +} + +func (p *testProvider) Next(c echo.Context) *ProxyTarget { + return &ProxyTarget{} +} + +func (p *testProvider) NextTarget(c echo.Context) (*ProxyTarget, error) { + return p.target, p.err +} + +func TestTargetProvider(t *testing.T) { + t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "target 1") + })) + defer t1.Close() + url1, _ := url.Parse(t1.URL) + + e := echo.New() + tp := &testProvider{commonBalancer: new(commonBalancer)} + tp.target = &ProxyTarget{Name: "target 1", URL: url1} + e.Use(Proxy(tp)) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + e.ServeHTTP(rec, req) + body := rec.Body.String() + assert.Equal(t, "target 1", body) +} + +func TestFailNextTarget(t *testing.T) { + url1, err := url.Parse("http://dummy:8080") + assert.Nil(t, err) + + e := echo.New() + tp := &testProvider{commonBalancer: new(commonBalancer)} + tp.target = &ProxyTarget{Name: "target 1", URL: url1} + tp.err = echo.NewHTTPError(http.StatusInternalServerError, "method could not select target") + + e.Use(Proxy(tp)) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + e.ServeHTTP(rec, req) + body := rec.Body.String() + assert.Equal(t, "{\"message\":\"method could not select target\"}\n", body) +} + func TestProxyRealIPHeader(t *testing.T) { // Setup upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))